# deepseekv3 / custom_model.py
# (Hugging Face page metadata: uploaded by sapthesh, commit "Update custom_model.py", c1e5327 verified)
import os

import torch
import torch.nn as nn

from transformers import PreTrainedModel, AutoConfig, AutoModel
class CustomModel(PreTrainedModel):
    """Sequence-classification wrapper around any ``AutoModel`` encoder.

    Builds an encoder from the checkpoint's config via ``AutoModel.from_config``
    and adds a linear classification head over the first token's hidden state
    ([CLS]-style pooling).
    """

    # AutoConfig resolves the concrete config class from the checkpoint.
    config_class = AutoConfig

    def __init__(self, config):
        """Initialize encoder and classifier from *config* (random weights)."""
        super().__init__(config)
        # Encoder is built from config only; pretrained weights (if any) are
        # loaded later in ``from_pretrained``.
        self.encoder = AutoModel.from_config(config)
        # NOTE(review): assumes config defines hidden_size and num_labels —
        # true for standard transformers configs; num_labels defaults to 2.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return classification logits of shape ``(batch, num_labels)``.

        Args:
            input_ids: Token id tensor of shape ``(batch, seq_len)``.
            attention_mask: Optional mask of the same shape; ``None`` means
                attend to every position.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool with the hidden state at position 0 (the [CLS] token for
        # BERT-style encoders).
        pooled_output = outputs.last_hidden_state[:, 0, :]
        return self.classifier(pooled_output)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load config (and local weights, when present) and return the model.

        Returns ``None`` on failure instead of raising, preserving the
        original best-effort contract — callers must check for ``None``.
        """
        try:
            config = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs
            )
            model = cls(config)
            # Bug fix: the original never loaded saved weights, so the
            # returned model was always randomly initialized. Load a local
            # checkpoint file when one exists.
            weights_path = os.path.join(
                str(pretrained_model_name_or_path), "pytorch_model.bin"
            )
            if os.path.isfile(weights_path):
                state_dict = torch.load(weights_path, map_location="cpu")
                # strict=False tolerates head/encoder key mismatches between
                # the checkpoint and this wrapper architecture.
                model.load_state_dict(state_dict, strict=False)
            return model
        except Exception as e:
            # Best-effort boundary: report the failure and signal it with
            # None rather than propagating (original behavior).
            print(f"Failed to load model from {pretrained_model_name_or_path}. Error: {e}")
            return None