from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import torch
import numpy as np
import logging


class PaperClassifier:

    AVAILABLE_MODELS = {
        'distilbert': {
            'name': 'distilbert-base-cased',
            'max_length': 512,
            'description': 'Lightweight and fast model, good for testing',
            'force_slow': False,
            'tokenizer_class': None
        },
        'deberta-v3': {
            'name': 'microsoft/deberta-v3-base',
            'max_length': 512,
            'description': 'Advanced model with better performance',
            'force_slow': True,
            'tokenizer_class': 'DebertaV2TokenizerFast'
        },
        't5': {
            'name': 'google/t5-v1_1-base',
            'max_length': 512,
            'description': 'Versatile text-to-text model',
            'force_slow': False,
            'tokenizer_class': None
        },
        'roberta': {
            'name': 'roberta-base',
            'max_length': 512,
            'description': 'Advanced model with strong performance',
            'force_slow': False,
            'tokenizer_class': None
        },
        'scibert': {
            'name': 'allenai/scibert_scivocab_uncased',
            'max_length': 512,
            'description': 'Specialized for scientific text',
            'force_slow': False,
            'tokenizer_class': None
        },
        'bert': {
            'name': 'bert-base-uncased',
            'max_length': 512,
            'description': 'Classic BERT model, good all-round performance',
            'force_slow': False,
            'tokenizer_class': None
        }
    }

    def __init__(self, model_type='distilbert'):
        """
        Initialize the classifier with a specific model type

        Args:
            model_type (str): One of 'distilbert', 'deberta-v3', 't5', 'roberta', 'scibert', 'bert'
        """
        if model_type not in self.AVAILABLE_MODELS:
            raise ValueError(f"Model type must be one of {list(self.AVAILABLE_MODELS.keys())}")

        self.model_type = model_type
        self.model_config = self.AVAILABLE_MODELS[model_type]
        self.model_name = self.model_config['name']

        # Top-level ArXiv subject groups used as classification labels
        self.categories = [
            "cs",
            "math",
            "physics",
            "q-bio",
            "q-fin",
            "stat",
            "eess",
            "econ"
        ]

        # Human-readable names for the ArXiv groups
        self.category_names = {
            "cs": "Computer Science",
            "math": "Mathematics",
            "physics": "Physics",
            "q-bio": "Quantitative Biology",
            "q-fin": "Quantitative Finance",
            "stat": "Statistics",
            "eess": "Electrical Engineering and Systems Science",
            "econ": "Economics"
        }

        self._initialize_tokenizer()
        self._initialize_model()

        print(f"Initialized {model_type} model: {self.model_name}")
        print(f"Description: {self.model_config['description']}")
        print("Note: This model needs to be fine-tuned on ArXiv data for accurate predictions.")

    def _initialize_tokenizer(self):
        """Initialize the tokenizer with proper error handling"""
        try:
            # Fetch the config first so an invalid model name fails early
            config = AutoConfig.from_pretrained(self.model_name)

            if self.model_config.get('tokenizer_class'):
                # Currently only DeBERTa-v3 requests an explicit tokenizer class
                from transformers import DebertaV2TokenizerFast
                self.tokenizer = DebertaV2TokenizerFast.from_pretrained(
                    self.model_name,
                    model_max_length=self.model_config['max_length']
                )
            else:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    model_max_length=self.model_config['max_length'],
                    use_fast=not self.model_config['force_slow'],
                    trust_remote_code=True
                )

            print(f"Successfully initialized tokenizer for {self.model_type}")

        except Exception as e:
            print(f"Error initializing tokenizer: {str(e)}")
            print("Falling back to basic tokenizer...")

            try:
                # First fallback: slow tokenizer for the same model
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    use_fast=False,
                    trust_remote_code=True
                )
            except Exception as e:
                # Last resort: a plain BERT tokenizer
                print(f"Error initializing slow tokenizer: {str(e)}")
                print("Falling back to BERT tokenizer...")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    'bert-base-uncased',
                    model_max_length=self.model_config['max_length']
                )

    def _initialize_model(self):
        """Initialize the model with proper error handling"""
        try:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=len(self.categories),
                id2label={i: label for i, label in enumerate(self.categories)},
                label2id={label: i for i, label in enumerate(self.categories)},
                trust_remote_code=True
            )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize model: {str(e)}")

    @classmethod
    def list_available_models(cls):
        """List all available models with their descriptions"""
        print("Available models:")
        for model_type, config in cls.AVAILABLE_MODELS.items():
            print(f"\n{model_type}:")
            print(f"  Model: {config['name']}")
            print(f"  Description: {config['description']}")

    def preprocess_text(self, title, abstract=None):
        """
        Preprocess title and abstract into a single input string

        Args:
            title (str): Paper title
            abstract (str, optional): Paper abstract

        Returns:
            str: Formatted input text
        """
        if abstract:
            text = f"Title: {title}\nAbstract: {abstract}"
        else:
            text = f"Title: {title}"

        max_length = self.model_config['max_length']

        # T5 expects a task prefix on its inputs
        if self.model_type == 't5':
            text = "classify: " + text

        # Rough character-level cut; the tokenizer applies the real token-level truncation
        return text[:max_length]

    def get_top_categories(self, probabilities, threshold=0.95):
        """
        Get the top categories whose cumulative probability stays within the threshold

        Args:
            probabilities (torch.Tensor): Model predictions
            threshold (float): Cumulative probability threshold (default: 0.95)

        Returns:
            list: List of dicts with 'category', 'arxiv_category', and 'probability' keys
        """
        probs = probabilities.cpu().numpy()

        # Sort category indices by descending probability
        sorted_indices = np.argsort(probs)[::-1]

        # Cumulative probability mass over the sorted categories
        cumsum = np.cumsum(probs[sorted_indices])

        # Keep categories while the running total stays within the threshold;
        # always keep at least the single most likely category
        mask = cumsum <= threshold
        if not any(mask):
            mask[0] = True

        selected_indices = sorted_indices[mask]

        return [
            {
                'category': self.category_names.get(self.categories[idx], self.categories[idx]),
                'arxiv_category': self.categories[idx],
                'probability': float(probs[idx])
            }
            for idx in selected_indices
        ]

    def classify_paper(self, title, abstract=None):
        """
        Classify a paper based on its title and optional abstract

        Args:
            title (str): Paper title
            abstract (str, optional): Paper abstract

        Returns:
            dict: Top categories, the model used, and the input type
        """
        processed_text = self.preprocess_text(title, abstract)

        inputs = self.tokenizer(
            processed_text,
            return_tensors="pt",
            truncation=True,
            max_length=self.model_config['max_length'],
            padding=True
        )

        # Run inference without tracking gradients
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = torch.softmax(outputs.logits, dim=1)[0]

        top_categories = self.get_top_categories(predictions)

        return {
            'top_categories': top_categories,
            'model_used': self.model_type,
            'input_type': 'title_and_abstract' if abstract else 'title_only'
        }

    def train_on_arxiv(self, train_texts, train_labels, validation_texts=None, validation_labels=None,
                       epochs=3, batch_size=16, learning_rate=2e-5):
        """
        Fine-tune the model on ArXiv data

        Args:
            train_texts (list): List of paper texts (title + abstract)
            train_labels (list): List of corresponding ArXiv categories
            validation_texts (list, optional): Validation texts
            validation_labels (list, optional): Validation labels
            epochs (int): Number of training epochs
            batch_size (int): Training batch size
            learning_rate (float): Learning rate for training
        """
        from transformers import TrainingArguments, Trainer
        import datasets

        # Tokenize the training texts
        train_encodings = self.tokenizer(
            train_texts,
            truncation=True,
            padding=True,
            max_length=self.model_config['max_length']
        )

        # Map category strings to integer label ids
        train_label_ids = [self.categories.index(label) for label in train_labels]

        train_dataset = datasets.Dataset.from_dict({
            'input_ids': train_encodings['input_ids'],
            'attention_mask': train_encodings['attention_mask'],
            'labels': train_label_ids
        })

        # Build the validation dataset only if validation data was provided
        if validation_texts and validation_labels:
            val_encodings = self.tokenizer(
                validation_texts,
                truncation=True,
                padding=True,
                max_length=self.model_config['max_length']
            )
            val_label_ids = [self.categories.index(label) for label in validation_labels]
            validation_dataset = datasets.Dataset.from_dict({
                'input_ids': val_encodings['input_ids'],
                'attention_mask': val_encodings['attention_mask'],
                'labels': val_label_ids
            })
        else:
            validation_dataset = None

        training_args = TrainingArguments(
            output_dir=f"./results_{self.model_type}",
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir=f"./logs_{self.model_type}",
            logging_steps=10,
            learning_rate=learning_rate,
            evaluation_strategy="epoch" if validation_dataset else "no",
            save_strategy="epoch",
            load_best_model_at_end=True if validation_dataset else False,
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=validation_dataset,
        )

        trainer.train()

        # Persist the fine-tuned model and tokenizer
        save_dir = f"./fine_tuned_{self.model_type}"
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)
        print(f"Model saved to {save_dir}")

    @classmethod
    def load_fine_tuned(cls, model_type, model_path):
        """
        Load a fine-tuned model from disk

        Args:
            model_type (str): The type of model that was fine-tuned
            model_path (str): Path to the saved model

        Returns:
            PaperClassifier: A classifier backed by the fine-tuned weights
        """
        classifier = cls(model_type)
        classifier.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        classifier.tokenizer = AutoTokenizer.from_pretrained(model_path)
        return classifier
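

# ---------------------------------------------------------------------------
# Example usage (a minimal sketch, not part of the class above). The model
# types and save paths come from AVAILABLE_MODELS and train_on_arxiv; the
# title/abstract strings below are made up purely for illustration, and
# predictions are only meaningful after fine-tuning on ArXiv data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Show which backbones are configured
    PaperClassifier.list_available_models()

    # Classify a paper from its title and abstract (hypothetical example text)
    classifier = PaperClassifier(model_type='distilbert')
    result = classifier.classify_paper(
        title="A Survey of Transformer Architectures",
        abstract="We review attention-based models for natural language processing."
    )
    for entry in result['top_categories']:
        print(f"{entry['category']} ({entry['arxiv_category']}): {entry['probability']:.3f}")

    # Fine-tuning sketch: labels must be drawn from classifier.categories
    # (e.g. "cs", "math"); real training needs far more examples than this.
    # train_texts = ["Title: ...\nAbstract: ...", ...]
    # train_labels = ["cs", ...]
    # classifier.train_on_arxiv(train_texts, train_labels, epochs=1, batch_size=8)

    # Reloading a checkpoint saved by train_on_arxiv
    # tuned = PaperClassifier.load_fine_tuned('distilbert', './fine_tuned_distilbert')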