piyush333 commited on
Commit
dc8459d
·
verified ·
1 Parent(s): 7bbd9ff

Upload MultiTaskBiasModel (Mach-1, DPO epoch 5)

Browse files
Files changed (2) hide show
  1. model_dpo_epoch_5.pt +3 -0
  2. modeling_multitask_bias.py +25 -0
model_dpo_epoch_5.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99020debae15f6e6ed1aeca19ab7c369bea6de797c6a5acca678b031e8c8910a
3
+ size 742526859
modeling_multitask_bias.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import DebertaV2Model
4
+
5
+ class MultiTaskBiasModel(nn.Module):
6
+ def __init__(self, model_name='microsoft/deberta-v3-base'):
7
+ super().__init__()
8
+ self.bert = DebertaV2Model.from_pretrained(model_name)
9
+ hidden = self.bert.config.hidden_size
10
+ self.heads = nn.ModuleDict({
11
+ task: nn.Sequential(
12
+ nn.Linear(hidden, hidden),
13
+ nn.ReLU(),
14
+ nn.Dropout(0.2),
15
+ nn.Linear(hidden, 3)
16
+ )
17
+ for task in ['political', 'gender', 'immigration']
18
+ })
19
+
20
+ def forward(self, input_ids, attention_mask, tasks):
21
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
22
+ logits = []
23
+ for i in range(len(tasks)):
24
+ logits.append(self.heads[tasks[i]](outputs[i].unsqueeze(0)))
25
+ return torch.cat(logits, dim=0)