from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import torch
import numpy as np
import logging

class PaperClassifier:
    # Available models with their configurations
    AVAILABLE_MODELS = {
        'distilbert': {
            'name': 'distilbert-base-cased',
            'max_length': 512,
            'description': 'Lightweight and fast model, good for testing',
            'force_slow': False,
            'tokenizer_class': None  # Use default
        },
        'deberta-v3': {
            'name': 'microsoft/deberta-v3-base',
            'max_length': 512,
            'description': 'Advanced model with better performance',
            'force_slow': True,  # Only consulted when no explicit tokenizer class is set
            'tokenizer_class': 'DebertaV2TokenizerFast'  # Load this tokenizer class explicitly
        },
        't5': {
            'name': 'google/t5-v1_1-base',
            'max_length': 512,
            'description': 'Versatile text-to-text model',
            'force_slow': False,
            'tokenizer_class': None  # Use default
        },
        'roberta': {
            'name': 'roberta-base',
            'max_length': 512,
            'description': 'Advanced model with strong performance',
            'force_slow': False,
            'tokenizer_class': None  # Use default
        },
        'scibert': {
            'name': 'allenai/scibert_scivocab_uncased',
            'max_length': 512,
            'description': 'Specialized for scientific text',
            'force_slow': False,
            'tokenizer_class': None  # Use default
        },
        'bert': {
            'name': 'bert-base-uncased',
            'max_length': 512,
            'description': 'Classic BERT model, good all-round performance',
            'force_slow': False,
            'tokenizer_class': None  # Use default
        }
    }

    def __init__(self, model_type='distilbert'):
        """
        Initialize the classifier with a specific model type
        
        Args:
            model_type (str): One of 'distilbert', 'deberta-v3', 't5', 'roberta', 'scibert', 'bert'
        """
        if model_type not in self.AVAILABLE_MODELS:
            raise ValueError(f"Model type must be one of {list(self.AVAILABLE_MODELS.keys())}")
        
        self.model_type = model_type
        self.model_config = self.AVAILABLE_MODELS[model_type]
        self.model_name = self.model_config['name']
        
        # ArXiv main categories with descriptions
        self.categories = [
            "cs",      # Computer Science
            "math",    # Mathematics
            "physics", # Physics
            "q-bio",  # Quantitative Biology
            "q-fin",  # Quantitative Finance
            "stat",   # Statistics
            "eess",   # Electrical Engineering and Systems Science
            "econ"    # Economics
        ]
        
        # Human readable category names
        self.category_names = {
            "cs": "Computer Science",
            "math": "Mathematics",
            "physics": "Physics",
            "q-bio": "Biology",
            "q-fin": "Finance",
            "stat": "Statistics",
            "eess": "Electrical Engineering",
            "econ": "Economics"
        }
        
        # Initialize tokenizer with proper error handling
        self._initialize_tokenizer()
        
        # Initialize model with proper error handling
        self._initialize_model()
        
        # Print model info
        print(f"Initialized {model_type} model: {self.model_name}")
        print(f"Description: {self.model_config['description']}")
        print("Note: This model needs to be fine-tuned on ArXiv data for accurate predictions.")
    
    def _initialize_tokenizer(self):
        """Initialize the tokenizer with proper error handling"""
        try:
            # Resolve the model configuration first so bad model names fail early
            AutoConfig.from_pretrained(self.model_name)
            
            # Load an explicitly named tokenizer class if the config specifies one
            if self.model_config.get('tokenizer_class'):
                import transformers
                tokenizer_cls = getattr(transformers, self.model_config['tokenizer_class'])
                self.tokenizer = tokenizer_cls.from_pretrained(
                    self.model_name,
                    model_max_length=self.model_config['max_length']
                )
            else:
                # Otherwise let AutoTokenizer resolve the right tokenizer
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    model_max_length=self.model_config['max_length'],
                    use_fast=not self.model_config['force_slow'],
                    trust_remote_code=True
                )
            
            print(f"Successfully initialized tokenizer for {self.model_type}")
            
        except Exception as e:
            print(f"Error initializing tokenizer: {str(e)}")
            print("Falling back to basic tokenizer...")
            
            # Try one more time with minimal settings
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    use_fast=False,
                    trust_remote_code=True
                )
            except Exception as e:
                # If all else fails, try using BERT tokenizer as last resort
                print("Falling back to BERT tokenizer...")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    'bert-base-uncased',
                    model_max_length=self.model_config['max_length']
                )
    
    def _initialize_model(self):
        """Initialize the model with proper error handling"""
        try:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=len(self.categories),
                id2label={i: label for i, label in enumerate(self.categories)},
                label2id={label: i for i, label in enumerate(self.categories)},
                trust_remote_code=True  # Allow custom code from hub
            )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize model: {str(e)}")
    
    @classmethod
    def list_available_models(cls):
        """List all available models with their descriptions"""
        print("Available models:")
        for model_type, config in cls.AVAILABLE_MODELS.items():
            print(f"\n{model_type}:")
            print(f"  Model: {config['name']}")
            print(f"  Description: {config['description']}")
    
    def preprocess_text(self, title, abstract=None):
        """
        Preprocess title and abstract
        
        Args:
            title (str): Paper title
            abstract (str, optional): Paper abstract
        """
        if abstract:
            text = f"Title: {title}\nAbstract: {abstract}"
        else:
            text = f"Title: {title}"
        
        # T5 models expect a task prefix
        if self.model_type == 't5':
            text = "classify: " + text
        
        # max_length counts tokens, not characters, so truncation is left to the
        # tokenizer (see classify_paper and train_on_arxiv) instead of slicing here
        return text
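    
    # Illustrative example of the formatting above (hypothetical input, not from the source):
    #   preprocess_text("Attention Is All You Need", "We propose a new architecture ...")
    #   -> "Title: Attention Is All You Need\nAbstract: We propose a new architecture ..."
    #   For the 't5' model the same string is additionally prefixed with "classify: ".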
    
    def get_top_categories(self, probabilities, threshold=0.95):
        """
        Get top categories that sum up to the threshold
        
        Args:
            probabilities (torch.Tensor): Model predictions
            threshold (float): Probability threshold (default: 0.95)
        
        Returns:
            list: List of (category, probability) tuples
        """
        # Convert to numpy for easier manipulation (detach and move to CPU first)
        probs = probabilities.detach().cpu().numpy()
        
        # Sort indices by probability
        sorted_indices = np.argsort(probs)[::-1]
        
        # Calculate cumulative sum
        cumsum = np.cumsum(probs[sorted_indices])
        
        # Keep the smallest set of top categories whose cumulative probability
        # reaches the threshold (always keep at least the top category)
        cutoff = int(np.searchsorted(cumsum, threshold)) + 1
        
        # Get the selected indices
        selected_indices = sorted_indices[:cutoff]
        
        # Return categories and their probabilities
        return [
            {
                'category': self.category_names.get(self.categories[idx], self.categories[idx]),
                'arxiv_category': self.categories[idx],
                'probability': float(probs[idx])
            }
            for idx in selected_indices
        ]
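    
    # Worked example of the selection above (illustrative numbers only, threshold=0.95):
    # probabilities cs=0.50, math=0.30, physics=0.16 and ~0.04 spread over the rest give
    # cumulative sums 0.50, 0.80, 0.96, ..., so the first three categories are returned
    # and together cover at least 95% of the probability mass.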
    
    def classify_paper(self, title, abstract=None):
        """
        Classify a paper based on its title and optional abstract
        
        Args:
            title (str): Paper title
            abstract (str, optional): Paper abstract
        """
        # Preprocess the text
        processed_text = self.preprocess_text(title, abstract)
        
        # Tokenize
        inputs = self.tokenizer(
            processed_text,
            return_tensors="pt",
            truncation=True,
            max_length=self.model_config['max_length'],
            padding=True
        )
        
        # Get model predictions
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = torch.softmax(outputs.logits, dim=1)[0]
        
        # Get top categories that sum to 95% probability
        top_categories = self.get_top_categories(predictions)
        
        # Return predictions
        return {
            'top_categories': top_categories,
            'model_used': self.model_type,
            'input_type': 'title_and_abstract' if abstract else 'title_only'
        }
    
    def train_on_arxiv(self, train_texts, train_labels, validation_texts=None, validation_labels=None, 
                       epochs=3, batch_size=16, learning_rate=2e-5):
        """
        Function to fine-tune the model on ArXiv data
        
        Args:
            train_texts (list): List of paper texts (title + abstract)
            train_labels (list): List of corresponding ArXiv categories
            validation_texts (list, optional): Validation texts
            validation_labels (list, optional): Validation labels
            epochs (int): Number of training epochs
            batch_size (int): Training batch size
            learning_rate (float): Learning rate for training
        """
        from transformers import TrainingArguments, Trainer
        import datasets
        
        # Prepare datasets
        train_encodings = self.tokenizer(
            train_texts,
            truncation=True,
            padding=True,
            max_length=self.model_config['max_length']
        )
        
        # Convert labels to ids
        train_label_ids = [self.categories.index(label) for label in train_labels]
        
        # Create training dataset
        train_dataset = datasets.Dataset.from_dict({
            'input_ids': train_encodings['input_ids'],
            'attention_mask': train_encodings['attention_mask'],
            'labels': train_label_ids
        })
        
        # Create validation dataset if provided
        if validation_texts and validation_labels:
            val_encodings = self.tokenizer(
                validation_texts,
                truncation=True,
                padding=True,
                max_length=self.model_config['max_length']
            )
            val_label_ids = [self.categories.index(label) for label in validation_labels]
            validation_dataset = datasets.Dataset.from_dict({
                'input_ids': val_encodings['input_ids'],
                'attention_mask': val_encodings['attention_mask'],
                'labels': val_label_ids
            })
        else:
            validation_dataset = None
        
        # Training arguments
        training_args = TrainingArguments(
            output_dir=f"./results_{self.model_type}",
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir=f"./logs_{self.model_type}",
            logging_steps=10,
            learning_rate=learning_rate,
            evaluation_strategy="epoch" if validation_dataset else "no",
            save_strategy="epoch",
            load_best_model_at_end=True if validation_dataset else False,
        )
        
        # Initialize trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=validation_dataset,
        )
        
        # Train the model
        trainer.train()
        
        # Save the fine-tuned model
        save_dir = f"./fine_tuned_{self.model_type}"
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)
        print(f"Model saved to {save_dir}")
    
    @classmethod
    def load_fine_tuned(cls, model_type, model_path):
        """
        Load a fine-tuned model from disk
        
        Args:
            model_type (str): The type of model that was fine-tuned
            model_path (str): Path to the saved model
        """
        classifier = cls(model_type)
        classifier.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        classifier.tokenizer = AutoTokenizer.from_pretrained(model_path)
        return classifier
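

# Minimal usage sketch (assumptions: the model weights above are downloadable from the
# Hugging Face Hub, and the classification head is untrained until train_on_arxiv is run,
# so the printed probabilities only illustrate the API rather than real predictions).
if __name__ == "__main__":
    PaperClassifier.list_available_models()

    # Instantiate the lightweight default model and classify a hypothetical paper
    classifier = PaperClassifier(model_type='distilbert')
    result = classifier.classify_paper(
        title="A Study of Attention Mechanisms",
        abstract="We analyze attention patterns in transformer language models."
    )
    print(f"Model: {result['model_used']}, input: {result['input_type']}")
    for entry in result['top_categories']:
        print(f"  {entry['category']} ({entry['arxiv_category']}): {entry['probability']:.3f}")

    # Fine-tuning sketch (hypothetical data; real ArXiv texts and labels drawn from
    # classifier.categories would be supplied here):
    # classifier.train_on_arxiv(
    #     train_texts=["Title: ...\nAbstract: ..."],
    #     train_labels=["cs"],
    #     epochs=1,
    #     batch_size=8,
    # )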