import torch.nn as nn
from transformers import AutoModel


class CustomMPRNAForSequenceClassification(nn.Module):
    """Sequence-classification head on top of an arbitrary encoder model.

    The first sequence position of the encoder's primary output is taken as
    the pooled representation. Dropout is applied to it, and a linear layer
    projects it to ``num_labels`` logits. When ``labels`` are supplied, a
    cross-entropy loss is computed as well.

    Args:
        base_model: Encoder module. It must expose ``config.hidden_size`` and
            accept ``input_ids`` / ``attention_mask`` keyword arguments. Its
            output is indexed with ``[0]`` and then ``[:, 0, :]``. That
            presumably yields the last hidden state of shape
            (batch, seq_len, hidden_size); TODO confirm for the encoder in use.
        num_labels: Number of output classes (size of the logits' last dim).
        dropout_prob: Dropout probability applied to the pooled
            representation. Defaults to 0.1, the previously hard-coded value.

    Note:
        ``nn.CrossEntropyLoss`` over a single logit is always zero. With
        ``num_labels == 1`` the returned loss therefore carries no training
        signal.
    """

    def __init__(self, base_model, num_labels, dropout_prob=0.1):
        super().__init__()
        self.base_model = base_model
        self.num_labels = num_labels
        # Registration order (classifier, then dropout) is kept unchanged so
        # module/state_dict ordering matches previously saved checkpoints.
        self.classifier = nn.Linear(base_model.config.hidden_size, num_labels)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        token_type_ids=None,
        **base_model_kwargs,
    ):
        """Run the encoder and classifier; optionally compute the loss.

        Args:
            input_ids: Token ids passed to ``base_model``.
            attention_mask: Optional mask passed to ``base_model``.
            labels: Optional class-index targets. They are flattened to 1-D
                before the cross-entropy loss is computed.
            token_type_ids: Optional segment ids. They are forwarded to
                ``base_model`` only when provided, so encoders that do not
                accept this argument keep working.
            **base_model_kwargs: Any further keyword arguments, forwarded
                unchanged to ``base_model``.

        Returns:
            ``{"logits": logits, "loss": loss}`` when ``labels`` is given,
            otherwise ``{"logits": logits}``.
        """
        if token_type_ids is not None:
            base_model_kwargs["token_type_ids"] = token_type_ids
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **base_model_kwargs,
        )

        # First position of the primary output is used as the pooled vector.
        pooled_output = outputs[0][:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is None:
            return {"logits": logits}

        loss_fct = nn.CrossEntropyLoss()
        # reshape (not view): labels may be non-contiguous, e.g. a slice of a
        # larger batch tensor, and view() raises on non-contiguous input.
        loss = loss_fct(
            logits.reshape(-1, self.num_labels),
            labels.reshape(-1),
        )
        return {"logits": logits, "loss": loss}