import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoConfig, AutoModel


class CustomModel(PreTrainedModel):
    config_class = AutoConfig  # Use AutoConfig to dynamically load the configuration class

    def __init__(self, config):
        super().__init__(config)
        # Implement your model architecture here
        self.encoder = AutoModel.from_config(config)  # Build the base model from the config
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None):
        # Pass inputs through the encoder
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Get the pooled output (e.g., CLS token for classification tasks)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        # Pass through the classifier
        logits = self.classifier(pooled_output)
        return logits

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        try:
            # Load the configuration
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
            # Initialize the model with the configuration (weights are randomly initialized)
            model = cls(config)
            # Optionally, you can load the state_dict here if needed
            # model.load_state_dict(torch.load(os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")))
            return model
        except Exception as e:
            print(f"Failed to load model from {pretrained_model_name_or_path}. Error: {e}")
            return None
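For reference, a minimal usage sketch of the class above. The checkpoint name "bert-base-uncased" and num_labels=2 are illustrative assumptions, not part of the original code; note that this custom from_pretrained only loads the configuration, so the encoder and classifier weights stay randomly initialized unless a state_dict is loaded.

import torch
from transformers import AutoTokenizer

# Build the model from the checkpoint's config (assumed checkpoint: "bert-base-uncased")
model = CustomModel.from_pretrained("bert-base-uncased", num_labels=2)
model.eval()

# Tokenize a sample input and run a forward pass
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello world", return_tensors="pt")

with torch.no_grad():
    logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

print(logits.shape)  # torch.Size([1, 2]) with the assumed num_labels=2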