from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import torch
import numpy as np


class PaperClassifier:
# Available models with their configurations
AVAILABLE_MODELS = {
'distilbert': {
'name': 'distilbert-base-cased',
'max_length': 512,
'description': 'Lightweight and fast model, good for testing',
'force_slow': False,
'tokenizer_class': None # Use default
},
'deberta-v3': {
'name': 'microsoft/deberta-v3-base',
'max_length': 512,
'description': 'Advanced model with better performance',
            'force_slow': True,  # Only consulted when tokenizer_class is None
            'tokenizer_class': 'DebertaV2TokenizerFast'  # Explicit tokenizer class for DeBERTa-v3
},
        't5': {
            'name': 'google/t5-v1_1-base',
            'max_length': 512,
            'description': 'Versatile text-to-text model',
            'force_slow': False,
            'tokenizer_class': None  # Use default
        },
'roberta': {
'name': 'roberta-base',
'max_length': 512,
'description': 'Advanced model with strong performance',
'force_slow': False,
'tokenizer_class': None # Use default
},
'scibert': {
'name': 'allenai/scibert_scivocab_uncased',
'max_length': 512,
'description': 'Specialized for scientific text',
'force_slow': False,
'tokenizer_class': None # Use default
},
'bert': {
'name': 'bert-base-uncased',
'max_length': 512,
'description': 'Classic BERT model, good all-round performance',
'force_slow': False,
'tokenizer_class': None # Use default
}
}
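    # Illustrative lookup: each entry maps a short key to a Hugging Face
    # checkpoint plus tokenizer options, e.g.
    #     PaperClassifier.AVAILABLE_MODELS['scibert']['name']
    #     -> 'allenai/scibert_scivocab_uncased'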
def __init__(self, model_type='distilbert'):
"""
Initialize the classifier with a specific model type
Args:
            model_type (str): One of 'distilbert', 'deberta-v3', 't5', 'roberta', 'scibert', 'bert'
"""
if model_type not in self.AVAILABLE_MODELS:
raise ValueError(f"Model type must be one of {list(self.AVAILABLE_MODELS.keys())}")
self.model_type = model_type
self.model_config = self.AVAILABLE_MODELS[model_type]
self.model_name = self.model_config['name']
# ArXiv main categories with descriptions
self.categories = [
"cs", # Computer Science
"math", # Mathematics
"physics", # Physics
"q-bio", # Quantitative Biology
"q-fin", # Quantitative Finance
"stat", # Statistics
"eess", # Electrical Engineering and Systems Science
"econ" # Economics
]
# Human readable category names
self.category_names = {
"cs": "Computer Science",
"math": "Mathematics",
"physics": "Physics",
"q-bio": "Biology",
"q-fin": "Finance",
"stat": "Statistics",
"eess": "Electrical Engineering",
"econ": "Economics"
}
# Initialize tokenizer with proper error handling
self._initialize_tokenizer()
# Initialize model with proper error handling
self._initialize_model()
# Print model info
print(f"Initialized {model_type} model: {self.model_name}")
print(f"Description: {self.model_config['description']}")
print("Note: This model needs to be fine-tuned on ArXiv data for accurate predictions.")
def _initialize_tokenizer(self):
"""Initialize the tokenizer with proper error handling"""
        try:
            # Loading the model config first is a cheap check that the model id resolves
            AutoConfig.from_pretrained(self.model_name)
            # Use an explicitly configured tokenizer class when one is set;
            # .get() avoids a KeyError for entries without the key
            tokenizer_class_name = self.model_config.get('tokenizer_class')
            if tokenizer_class_name:
                import transformers
                tokenizer_cls = getattr(transformers, tokenizer_class_name)
                self.tokenizer = tokenizer_cls.from_pretrained(
                    self.model_name,
                    model_max_length=self.model_config['max_length']
                )
            else:
                # Otherwise let AutoTokenizer choose, honouring the force_slow flag
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    model_max_length=self.model_config['max_length'],
                    use_fast=not self.model_config.get('force_slow', False),
                    trust_remote_code=True
                )
print(f"Successfully initialized tokenizer for {self.model_type}")
except Exception as e:
print(f"Error initializing tokenizer: {str(e)}")
print("Falling back to basic tokenizer...")
# Try one more time with minimal settings
try:
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
use_fast=False,
trust_remote_code=True
)
            except Exception:
                # If all else fails, fall back to the BERT tokenizer as a last resort
                print("Falling back to BERT tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(
'bert-base-uncased',
model_max_length=self.model_config['max_length']
)
def _initialize_model(self):
"""Initialize the model with proper error handling"""
try:
self.model = AutoModelForSequenceClassification.from_pretrained(
self.model_name,
num_labels=len(self.categories),
id2label={i: label for i, label in enumerate(self.categories)},
label2id={label: i for i, label in enumerate(self.categories)},
trust_remote_code=True # Allow custom code from hub
)
except Exception as e:
raise RuntimeError(f"Failed to initialize model: {str(e)}")
@classmethod
def list_available_models(cls):
"""List all available models with their descriptions"""
print("Available models:")
for model_type, config in cls.AVAILABLE_MODELS.items():
print(f"\n{model_type}:")
print(f" Model: {config['name']}")
print(f" Description: {config['description']}")
def preprocess_text(self, title, abstract=None):
"""
Preprocess title and abstract
Args:
title (str): Paper title
abstract (str, optional): Paper abstract
"""
        if abstract:
            text = f"Title: {title}\nAbstract: {abstract}"
        else:
            text = f"Title: {title}"
        # T5 is a text-to-text model and expects a task prefix
        if self.model_type == 't5':
            text = "classify: " + text
        # Token-level truncation to max_length happens at tokenization time;
        # slicing characters here would cut abstracts far more aggressively than intended
        return text
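    # For example, preprocess_text("Attention Is All You Need") returns
    # "Title: Attention Is All You Need"; with model_type == 't5' the same
    # call returns "classify: Title: Attention Is All You Need".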
    def get_top_categories(self, probabilities, threshold=0.95):
        """
        Get the smallest set of top categories whose cumulative probability reaches the threshold
        Args:
            probabilities (torch.Tensor): Model predictions
            threshold (float): Cumulative probability threshold (default: 0.95)
        Returns:
            list: List of dicts with human-readable name, arXiv code, and probability
        """
        # Convert to numpy for easier manipulation
        probs = probabilities.cpu().numpy()
        # Sort indices by probability, descending
        sorted_indices = np.argsort(probs)[::-1]
        # Cumulative sum over the sorted probabilities
        cumsum = np.cumsum(probs[sorted_indices])
        # Keep every category up to and including the first one whose cumulative
        # probability reaches the threshold (slicing clamps safely at the end)
        cutoff = int(np.searchsorted(cumsum, threshold)) + 1
        selected_indices = sorted_indices[:cutoff]
# Return categories and their probabilities
return [
{
'category': self.category_names.get(self.categories[idx], self.categories[idx]),
'arxiv_category': self.categories[idx],
'probability': float(probs[idx])
}
for idx in selected_indices
]
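    # Worked example: for sorted probabilities (0.60, 0.30, 0.08, 0.02, ...)
    # and the default threshold of 0.95, the cumulative sums are
    # 0.60, 0.90, 0.98, ... so the first three categories are returned.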
def classify_paper(self, title, abstract=None):
"""
Classify a paper based on its title and optional abstract
Args:
title (str): Paper title
abstract (str, optional): Paper abstract
"""
# Preprocess the text
processed_text = self.preprocess_text(title, abstract)
# Tokenize
inputs = self.tokenizer(
processed_text,
return_tensors="pt",
truncation=True,
max_length=self.model_config['max_length'],
padding=True
)
        # Get model predictions (eval mode disables dropout at inference)
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = torch.softmax(outputs.logits, dim=1)[0]
# Get top categories that sum to 95% probability
top_categories = self.get_top_categories(predictions)
# Return predictions
return {
'top_categories': top_categories,
'model_used': self.model_type,
'input_type': 'title_and_abstract' if abstract else 'title_only'
}
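    # Illustrative call (an untuned classification head yields near-uniform
    # probabilities until train_on_arxiv has been run):
    #     result = clf.classify_paper("Deep Residual Learning for Image Recognition")
    #     result['top_categories'][0]
    #     -> {'category': ..., 'arxiv_category': ..., 'probability': ...}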
def train_on_arxiv(self, train_texts, train_labels, validation_texts=None, validation_labels=None,
epochs=3, batch_size=16, learning_rate=2e-5):
"""
Function to fine-tune the model on ArXiv data
Args:
train_texts (list): List of paper texts (title + abstract)
train_labels (list): List of corresponding ArXiv categories
validation_texts (list, optional): Validation texts
validation_labels (list, optional): Validation labels
epochs (int): Number of training epochs
batch_size (int): Training batch size
learning_rate (float): Learning rate for training
"""
from transformers import TrainingArguments, Trainer
import datasets
# Prepare datasets
train_encodings = self.tokenizer(
train_texts,
truncation=True,
padding=True,
max_length=self.model_config['max_length']
)
# Convert labels to ids
train_label_ids = [self.categories.index(label) for label in train_labels]
# Create training dataset
train_dataset = datasets.Dataset.from_dict({
'input_ids': train_encodings['input_ids'],
'attention_mask': train_encodings['attention_mask'],
'labels': train_label_ids
})
# Create validation dataset if provided
if validation_texts and validation_labels:
val_encodings = self.tokenizer(
validation_texts,
truncation=True,
padding=True,
max_length=self.model_config['max_length']
)
val_label_ids = [self.categories.index(label) for label in validation_labels]
validation_dataset = datasets.Dataset.from_dict({
'input_ids': val_encodings['input_ids'],
'attention_mask': val_encodings['attention_mask'],
'labels': val_label_ids
})
else:
validation_dataset = None
# Training arguments
training_args = TrainingArguments(
output_dir=f"./results_{self.model_type}",
num_train_epochs=epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
warmup_steps=500,
weight_decay=0.01,
logging_dir=f"./logs_{self.model_type}",
logging_steps=10,
learning_rate=learning_rate,
evaluation_strategy="epoch" if validation_dataset else "no",
save_strategy="epoch",
            load_best_model_at_end=validation_dataset is not None,
)
# Initialize trainer
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
)
# Train the model
trainer.train()
# Save the fine-tuned model
save_dir = f"./fine_tuned_{self.model_type}"
self.model.save_pretrained(save_dir)
self.tokenizer.save_pretrained(save_dir)
print(f"Model saved to {save_dir}")
@classmethod
def load_fine_tuned(cls, model_type, model_path):
"""
Load a fine-tuned model from disk
Args:
model_type (str): The type of model that was fine-tuned
model_path (str): Path to the saved model
"""
        classifier = cls(model_type)
        # Replace the freshly initialized base model/tokenizer with the fine-tuned ones
        classifier.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        classifier.tokenizer = AutoTokenizer.from_pretrained(model_path)
return classifier
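

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the first run downloads the base
    # checkpoint, and because the classification head is untrained the printed
    # probabilities will be close to uniform.
    PaperClassifier.list_available_models()
    clf = PaperClassifier(model_type='distilbert')
    result = clf.classify_paper(
        "A Survey of Transformer Architectures",
        abstract="We review attention-based models for natural language processing."
    )
    for entry in result['top_categories']:
        print(f"{entry['category']} ({entry['arxiv_category']}): {entry['probability']:.3f}")
    # To reload a fine-tuned checkpoint later (path shown is illustrative):
    #     clf = PaperClassifier.load_fine_tuned('distilbert', './fine_tuned_distilbert')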