import pytorch_lightning as pl
import torch
import torchmetrics
# transformers' AdamW is deprecated (and removed in recent releases);
# the torch implementation is a drop-in replacement here.
from torch.optim import AdamW


class DualEncoderModule(pl.LightningModule):
    def __init__(self, tokenizer, model, learning_rate=1e-3):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate
        # Keep a separate accuracy metric per stage so their running
        # states never mix across the train/validation/test loops.
        self.train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.val_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.test_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):
        pos_ids, pos_mask, neg_ids, neg_mask = batch
        # Flatten [batch, n_negatives, seq_len] so every negative is
        # scored as an independent example.
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
        # Cross-entropy expects long labels on the same device as the inputs.
        pos_labels = torch.ones(
            pos_ids.shape[0], dtype=torch.long, device=pos_ids.device
        )
        neg_labels = torch.zeros(
            neg_ids.shape[0], dtype=torch.long, device=neg_ids.device
        )
        pos_outputs = self(pos_ids, attention_mask=pos_mask, labels=pos_labels)
        neg_outputs = self(neg_ids, attention_mask=neg_mask, labels=neg_labels)
        # loss_scale weights the negative loss against the positive loss;
        # 1.0 treats both halves of the batch equally.
        loss_scale = 1.0
        loss = pos_outputs.loss + loss_scale * neg_outputs.loss
        pos_preds = torch.argmax(pos_outputs.logits, dim=1)
        neg_preds = torch.argmax(neg_outputs.logits, dim=1)
        # The metrics live on the module's device, so update them with
        # on-device tensors rather than moving everything to the CPU.
        self.train_acc(pos_preds, pos_labels)
        self.train_acc(neg_preds, neg_labels)
        self.log("train_acc", self.train_acc)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        pos_ids, pos_mask, neg_ids, neg_mask = batch
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
        pos_labels = torch.ones(
            pos_ids.shape[0], dtype=torch.long, device=pos_ids.device
        )
        neg_labels = torch.zeros(
            neg_ids.shape[0], dtype=torch.long, device=neg_ids.device
        )
        pos_outputs = self(pos_ids, attention_mask=pos_mask, labels=pos_labels)
        neg_outputs = self(neg_ids, attention_mask=neg_mask, labels=neg_labels)
        loss_scale = 1.0
        loss = pos_outputs.loss + loss_scale * neg_outputs.loss
        pos_preds = torch.argmax(pos_outputs.logits, dim=1)
        neg_preds = torch.argmax(neg_outputs.logits, dim=1)
        self.val_acc(pos_preds, pos_labels)
        self.val_acc(neg_preds, neg_labels)
        self.log("val_acc", self.val_acc)
        self.log("val_loss", loss)
        return {"loss": loss}

    def test_step(self, batch, batch_idx):
        pos_ids, pos_mask, neg_ids, neg_mask = batch
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
        pos_labels = torch.ones(
            pos_ids.shape[0], dtype=torch.long, device=pos_ids.device
        )
        neg_labels = torch.zeros(
            neg_ids.shape[0], dtype=torch.long, device=neg_ids.device
        )
        pos_outputs = self(pos_ids, attention_mask=pos_mask, labels=pos_labels)
        neg_outputs = self(neg_ids, attention_mask=neg_mask, labels=neg_labels)
        pos_preds = torch.argmax(pos_outputs.logits, dim=1)
        neg_preds = torch.argmax(neg_outputs.logits, dim=1)
        self.test_acc(pos_preds, pos_labels)
        self.test_acc(neg_preds, neg_labels)
        self.log("test_acc", self.test_acc)