import torch
import torch.nn as nn
from transformers import DebertaV2Model


class MultiTaskBiasModel(nn.Module):
    """Multi-task bias classifier: a shared DeBERTa-v3 encoder with one
    3-way classification head per bias dimension."""

    def __init__(self, model_name='microsoft/deberta-v3-base'):
        super().__init__()
        # Shared encoder (DeBERTa-v3 checkpoints use the DebertaV2 architecture).
        self.bert = DebertaV2Model.from_pretrained(model_name)
        hidden = self.bert.config.hidden_size
        # One independent classification head per task, each mapping the
        # [CLS] representation to 3 class logits.
        self.heads = nn.ModuleDict({
            task: nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden, 3)
            )
            for task in ['political', 'gender', 'immigration']
        })

    def forward(self, input_ids, attention_mask, tasks):
        # Encode the batch once and keep the [CLS] token representation.
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
        # Route each example through the head matching its task label.
        logits = []
        for i, task in enumerate(tasks):
            logits.append(self.heads[task](outputs[i].unsqueeze(0)))
        return torch.cat(logits, dim=0)
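

# Usage sketch (not part of the original file): this assumes each batch
# carries a per-example task label alongside the tokenized text, so every
# example is routed to its matching head. The tokenizer name matches the
# default encoder; the example sentences and task labels are illustrative
# placeholders only.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-base')
    texts = [
        "The senator's plan was praised by supporters.",
        "The new hiring policy affects applicants differently.",
    ]
    tasks = ['political', 'gender']  # one task label per example

    enc = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    model = MultiTaskBiasModel()
    model.eval()
    with torch.no_grad():
        logits = model(enc['input_ids'], enc['attention_mask'], tasks)
    print(logits.shape)  # torch.Size([2, 3]): one row of 3-class logits per example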