from typing import Optional, Union

import torch
from torch import nn
from torch.nn import functional as F

from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    BertModel,
    BertConfig,
)
from transformers.modeling_outputs import SequenceClassifierOutput
|
|
class ListConRankerConfig(PretrainedConfig):
    """Configuration class for the ListConRanker model."""

    model_type = "listconranker"

    def __init__(
        self,
        list_transformer_layers: int = 2,
        num_attention_heads: int = 8,
        hidden_size: int = 1792,
        base_hidden_size: int = 1024,
        num_labels: int = 1,
        **kwargs
    ):
        super().__init__(**kwargs)
        # hidden_size is the width of the ListTransformer that sits on top of the
        # encoder; base_hidden_size is the width of the underlying BERT encoder.
        self.list_transformer_layers = list_transformer_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_size = hidden_size
        self.base_hidden_size = base_hidden_size
        self.num_labels = num_labels

        # The remaining kwargs configure the underlying BERT encoder.
        # output_hidden_states is forced on because forward() reads
        # hidden_states[-1] when it chunks large batches.
        self.bert_config = BertConfig(**kwargs)
        self.bert_config.output_hidden_states = True
|
|
class QueryEmbedding(nn.Module):
    """Adds a learned query/passage type embedding before the ListTransformer."""

    def __init__(self, config) -> None:
        super().__init__()
        # Index 1 marks the query position, index 0 marks passage positions.
        self.query_embedding = nn.Embedding(2, config.hidden_size)
        self.layerNorm = nn.LayerNorm(config.hidden_size)

    def forward(self, x, tags):
        query_embeddings = self.query_embedding(tags)
        x = x + query_embeddings  # avoid in-place addition on the input tensor
        x = self.layerNorm(x)
        return x
|
|
class ListTransformer(nn.Module):
    """Transformer encoder that lets the query and its candidate passages attend to each other list-wise."""

    def __init__(self, num_layer, config) -> None:
        super().__init__()
        self.config = config
        self.list_transformer_layer = nn.TransformerEncoderLayer(
            config.hidden_size,  # was hardcoded to 1792, which only matched the default config
            self.config.num_attention_heads,
            batch_first=True,
            activation=F.gelu,
            norm_first=False,
        )
        self.list_transformer = nn.TransformerEncoder(self.list_transformer_layer, num_layer)
        self.relu = nn.ReLU()
        self.query_embedding = QueryEmbedding(config)

        # Scoring head: one branch over the pre-transformer features, one over the
        # post-transformer features, combined by a final linear layer.
        self.linear_score3 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.linear_score2 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.linear_score1 = nn.Linear(config.hidden_size * 2, 1)
|
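    # Illustrative input layout, using pair_nums = [3, 2] (values chosen only
    # for illustration): pair_features is expected to hold one row per encoded
    # text, grouped as
    #   [query_1, passage_1_1, passage_1_2, passage_1_3,
    #    query_2, passage_2_1, passage_2_2]
    # i.e. the first row of each group is the query representation and the
    # remaining rows are its candidate passages.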
    def forward(self, pair_features, pair_nums):
        # pair_nums[i] = number of passages for query i; each group contributes
        # one extra row for the query itself, hence the +1.
        pair_nums = [x + 1 for x in pair_nums]
        batch_pair_features = pair_features.split(pair_nums)

        # Branch 1: concatenate the (pre-transformer) query feature with each of
        # its passage features.
        pair_feature_query_passage_concat_list = []
        for i in range(len(batch_pair_features)):
            pair_feature_query = batch_pair_features[i][0].unsqueeze(0).repeat(pair_nums[i] - 1, 1)
            pair_feature_passage = batch_pair_features[i][1:]
            pair_feature_query_passage_concat_list.append(
                torch.cat([pair_feature_query, pair_feature_passage], dim=1)
            )
        pair_feature_query_passage_concat = torch.cat(pair_feature_query_passage_concat_list, dim=0)

        # Pad the groups to a common length so they form a regular batch.
        batch_pair_features = nn.utils.rnn.pad_sequence(batch_pair_features, batch_first=True)

        # Tag position 0 of every group as the query, the rest as passages.
        query_embedding_tags = torch.zeros(
            batch_pair_features.size(0), batch_pair_features.size(1),
            dtype=torch.long, device=self.device,
        )
        query_embedding_tags[:, 0] = 1
        batch_pair_features = self.query_embedding(batch_pair_features, query_embedding_tags)

        # Branch 2: run the list-wise transformer with padding and query masks.
        mask = self.generate_attention_mask(pair_nums)
        query_mask = self.generate_attention_mask_custom(pair_nums)
        pair_list_features = self.list_transformer(
            batch_pair_features, src_key_padding_mask=mask, mask=query_mask
        )

        output_pair_list_features = []
        output_query_list_features = []
        pair_features_after_transformer_list = []
        for idx, pair_num in enumerate(pair_nums):
            output_pair_list_features.append(pair_list_features[idx, 1:pair_num, :])
            output_query_list_features.append(pair_list_features[idx, 0, :])
            pair_features_after_transformer_list.append(pair_list_features[idx, :pair_num, :])

        # Concatenate the (post-transformer) query feature with each of its passage features.
        pair_features_after_transformer_cat_query_list = []
        for idx, pair_num in enumerate(pair_nums):
            query_ft = output_query_list_features[idx].unsqueeze(0).repeat(pair_num - 1, 1)
            pair_features_after_transformer_cat_query = torch.cat(
                [query_ft, output_pair_list_features[idx]], dim=1
            )
            pair_features_after_transformer_cat_query_list.append(pair_features_after_transformer_cat_query)
        pair_features_after_transformer_cat_query = torch.cat(pair_features_after_transformer_cat_query_list, dim=0)

        # Score each query-passage pair from both branches.
        pair_feature_query_passage_concat = self.relu(self.linear_score2(pair_feature_query_passage_concat))
        pair_features_after_transformer_cat_query = self.relu(self.linear_score3(pair_features_after_transformer_cat_query))
        final_ft = torch.cat([pair_feature_query_passage_concat, pair_features_after_transformer_cat_query], dim=1)
        # squeeze(-1) rather than squeeze() so a single-passage batch keeps a 1-D shape.
        logits = self.linear_score1(final_ft).squeeze(-1)

        return logits, torch.cat(pair_features_after_transformer_list, dim=0)
|
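    # Illustration of the masks the two helpers below produce for
    # pair_num = [4, 3] (the group sizes from the layout example above,
    # i.e. query + passages after the +1 applied in forward):
    #
    #   generate_attention_mask([4, 3]) -> shape (2, 4), True marks padding:
    #       [[False, False, False, False],
    #        [False, False, False, True ]]
    #
    #   generate_attention_mask_custom([4, 3]) -> shape (4, 4), all False except
    #       row 0, columns 1..3, so the query position cannot attend to the
    #       passage positions while the passages attend to everything
    #       (True blocks attention).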
    def generate_attention_mask(self, pair_num):
        # Key-padding mask: True marks padded positions beyond each group's length.
        max_len = max(pair_num)
        batch_size = len(pair_num)
        mask = torch.zeros(batch_size, max_len, dtype=torch.bool, device=self.device)
        for i, length in enumerate(pair_num):
            mask[i, length:] = True
        return mask

    def generate_attention_mask_custom(self, pair_num):
        # Attention mask shared by all groups: the query position (index 0) is
        # blocked from attending to the passage positions, while every passage
        # position may attend to all positions.
        max_len = max(pair_num)
        mask = torch.zeros(max_len, max_len, dtype=torch.bool, device=self.device)
        mask[0, 1:] = True
        return mask
|
|
class ListConRankerModel(PreTrainedModel):
    """
    ListConRanker model for sequence classification, compatible with
    AutoModelForSequenceClassification.
    """
    config_class = ListConRankerConfig
    base_model_prefix = "listconranker"

    def __init__(self, config: ListConRankerConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        # The encoder uses the nested BertConfig (base_hidden_size wide, with
        # output_hidden_states enabled), not the outer ListConRanker config.
        self.hf_model = BertModel(config.bert_config)

        self.sigmoid = nn.Sigmoid()

        # Projects the pooled encoder features up to the ListTransformer width.
        self.linear_in_embedding = nn.Linear(config.base_hidden_size, config.hidden_size)
        self.list_transformer = ListTransformer(
            config.list_transformer_layers,
            config,
        )
|
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pair_num: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Union[SequenceClassifierOutput, tuple]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pair_num is not None:
            # pair_num[i] is the number of candidate passages for query i; the rows
            # of input_ids for group i are expected to be [query_i, passage_i_1, ...].
            pair_nums = pair_num.tolist()
        else:
            # Assumption: without pair_num, treat the whole batch as a single group
            # whose first row is the query and whose remaining rows are its passages.
            batch_size = input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
            pair_nums = [batch_size - 1]

        device = input_ids.device if input_ids is not None else inputs_embeds.device
        # ListTransformer reads this attribute when it builds its masks and tags.
        self.list_transformer.device = device

        if self.training:
            # Only the inference path is implemented in this file.
            raise NotImplementedError("ListConRankerModel only supports inference; call model.eval().")
        else:
            split_batch = 400
            if sum(pair_nums) > split_batch:
                # Chunk large batches through the encoder to bound peak memory.
                # Only input_ids and attention_mask are chunked here.
                last_hidden_state_list = []
                input_ids_list = input_ids.split(split_batch)
                attention_mask_list = attention_mask.split(split_batch)
                for i in range(len(input_ids_list)):
                    last_hidden_state = self.hf_model(
                        input_ids=input_ids_list[i],
                        attention_mask=attention_mask_list[i],
                        return_dict=True).hidden_states[-1]
                    last_hidden_state_list.append(last_hidden_state)
                last_hidden_state = torch.cat(last_hidden_state_list, dim=0)
            else:
                ranker_out = self.hf_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    position_ids=position_ids,
                    head_mask=head_mask,
                    inputs_embeds=inputs_embeds,
                    output_attentions=output_attentions,
                    return_dict=True)
                last_hidden_state = ranker_out.last_hidden_state

        # Mean-pool each sequence, project to the ListTransformer width, then score list-wise.
        pair_features = self.average_pooling(last_hidden_state, attention_mask)
        pair_features = self.linear_in_embedding(pair_features)

        logits, _ = self.list_transformer(pair_features, pair_nums)
        # The scores are already squashed to [0, 1] by the sigmoid below.
        logits = self.sigmoid(logits)

        if not return_dict:
            return (logits,)
        return SequenceClassifierOutput(logits=logits)
|
    def average_pooling(self, hidden_state, attention_mask):
        # Masked mean over the sequence dimension: zero out padding positions,
        # sum, and divide by the number of real tokens.
        extended_attention_mask = attention_mask.unsqueeze(-1).expand(hidden_state.size()).to(dtype=hidden_state.dtype)
        masked_hidden_state = hidden_state * extended_attention_mask
        sum_embeddings = torch.sum(masked_hidden_state, dim=1)
        sum_mask = extended_attention_mask.sum(dim=1).clamp(min=1e-9)  # guard against all-padding rows
        return sum_embeddings / sum_mask
|
    @classmethod
    def from_pretrained(cls, model_name_or_path, config: Optional[ListConRankerConfig] = None, **kwargs):
        model = super().from_pretrained(
            model_name_or_path, config=config, **kwargs)

        # The projection layer and the ListTransformer are stored as separate weight
        # files next to the encoder; this assumes model_name_or_path is a local
        # directory that contains them.
        linear_path = f"{model_name_or_path}/linear_in_embedding.pt"
        transformer_path = f"{model_name_or_path}/list_transformer.pt"

        try:
            # map_location="cpu" lets weights saved on GPU load on CPU-only hosts;
            # the modules move with the rest of the model on .to(device).
            model.linear_in_embedding.load_state_dict(torch.load(linear_path, map_location="cpu"))
            model.list_transformer.load_state_dict(torch.load(transformer_path, map_location="cpu"))
        except FileNotFoundError:
            print(f"Warning: Could not load custom weights from {model_name_or_path}")

        return model
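

if __name__ == "__main__":
    # Minimal usage sketch (inference only). "path/to/listconranker-checkpoint"
    # is a placeholder for a local directory that holds the encoder weights plus
    # linear_in_embedding.pt and list_transformer.pt (see from_pretrained above).
    # The input convention assumed here -- each group is the query text followed
    # by its candidate passage texts, with pair_num giving the passage count --
    # is inferred from ListTransformer.forward; check the released checkpoint's
    # documentation for the exact convention it was trained with.
    from transformers import AutoTokenizer

    checkpoint = "path/to/listconranker-checkpoint"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    config = ListConRankerConfig.from_pretrained(checkpoint)
    model = ListConRankerModel.from_pretrained(checkpoint, config=config).eval()

    query = "what does a rerank model do"
    passages = [
        "A reranker scores each candidate passage against the query.",
        "Pandas live in the mountain ranges of central China.",
    ]

    # One group: the query row followed by its passage rows.
    inputs = tokenizer([query] + passages, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        out = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            pair_num=torch.tensor([len(passages)]),
        )
    print(out.logits)  # one sigmoid score per passage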