import torch | |
import torch.nn as nn | |
from transformers import DebertaV2Model | |
class MultiTaskBiasModel(nn.Module):
    """Multi-task bias classifier with a shared DeBERTa-v3 encoder.

    One encoder is shared across all tasks; each task gets its own small
    MLP classification head. At forward time, each sample in the batch is
    routed to the head named by the corresponding entry in ``tasks``.
    """

    def __init__(self, model_name='microsoft/deberta-v3-base',
                 task_names=('political', 'gender', 'immigration'),
                 num_labels=3, dropout=0.2):
        """Build the shared encoder and one classification head per task.

        Args:
            model_name: Hugging Face checkpoint for the shared encoder.
            task_names: names of the classification heads to create
                (defaults reproduce the original hard-coded task set).
            num_labels: output classes per head (default 3, as before).
            dropout: dropout probability inside each head (default 0.2).
        """
        super().__init__()
        self.bert = DebertaV2Model.from_pretrained(model_name)
        hidden = self.bert.config.hidden_size
        # One two-layer MLP head per task; all heads share the encoder above.
        self.heads = nn.ModuleDict({
            task: nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, num_labels),
            )
            for task in task_names
        })

    def forward(self, input_ids, attention_mask, tasks):
        """Encode the batch and route each sample through its task head.

        Args:
            input_ids: (batch, seq) token-id tensor.
            attention_mask: (batch, seq) attention mask.
            tasks: sequence of batch-many task-name strings; each must be a
                key of ``self.heads``.

        Returns:
            (batch, num_labels) logits tensor, rows in input order.

        Raises:
            ValueError: if ``len(tasks)`` does not match the batch size
                (the original code silently truncated or raised IndexError).
            KeyError: if a task name has no corresponding head.
        """
        # First-token ([CLS]-position) hidden state as the sentence vector.
        cls = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        if len(tasks) != cls.size(0):
            raise ValueError(
                f"got {len(tasks)} task labels for a batch of {cls.size(0)}"
            )
        # Per-sample dispatch; zip keeps (vector, task) pairs aligned —
        # idiomatic replacement for range(len(...)) indexing.
        logits = [
            self.heads[task](vec.unsqueeze(0))
            for vec, task in zip(cls, tasks)
        ]
        return torch.cat(logits, dim=0)