```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer


class GPTRewardModel(nn.Module):
    def __init__(self):
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained("pvduy/vicuna-13b-v1.1")
        self.config = model.config
        # LLaMA-style configs expose `hidden_size`; fall back to the
        # GPT-style `n_embd` attribute otherwise.
        self.config.n_embd = (
            self.config.hidden_size
            if hasattr(self.config, "hidden_size")
            else self.config.n_embd
        )
        # Use the bare decoder stack (the LM head is discarded) as the
        # feature extractor.
        self.transformer = model.model
        # Scalar value head: maps each hidden state to a single reward.
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        self.tokenizer = AutoTokenizer.from_pretrained("pvduy/vicuna-13b-v1.1")
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = "right"
        self.PAD_ID = self.tokenizer.pad_token_id

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        return_dict=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        # Most arguments are accepted only for drop-in compatibility with the
        # Hugging Face forward signature; the reward head needs just
        # `input_ids` and `attention_mask`.
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = transformer_outputs[0]

        # Project every hidden state to a scalar: (batch, seq_len).
        rewards = self.v_head(hidden_states).squeeze(-1)
        # Index of the first PAD token in each sequence, i.e. one past the
        # last real token (padding_side is "right").
        pad_mask = input_ids == self.PAD_ID
        ends = torch.argmax(pad_mask.float(), dim=1)
        # argmax returns 0 when a sequence has no padding at all; fall back
        # to the final position in that case.
        ends = torch.where(
            pad_mask.any(dim=1), ends, torch.full_like(ends, input_ids.shape[1] - 1)
        ).view(-1, 1)
        # Take the reward at the end-of-sequence position of each example.
        rewards = torch.gather(rewards, 1, ends)
        return rewards
```
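
For reference, here is a minimal scoring sketch (not part of the original code): it assumes a CUDA device, fp16 weights, and illustrative prompt texts and `max_length`, and reads out one scalar reward per padded sequence.

```python
# Hypothetical usage: the texts, max_length, and device below are
# placeholders, not values from the original setup.
reward_model = GPTRewardModel().eval().half().cuda()

texts = [
    "USER: What is the capital of France? ASSISTANT: Paris.",
    "USER: What is the capital of France? ASSISTANT: I don't know.",
]
batch = reward_model.tokenizer(
    texts, padding=True, truncation=True, max_length=512, return_tensors="pt"
).to("cuda")

with torch.no_grad():
    rewards = reward_model(
        input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
    )
print(rewards.squeeze(-1))  # one reward per transcript
```

Because the pad token doubles as EOS and padding is on the right, the gathered index lands on the first EOS/PAD slot after the response, so each reward summarizes the full transcript up to that point.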