saathwik committed
Commit 18c076f · verified · parent 12f07e7

Upload folder using huggingface_hub

Files changed (3):
  1. config.json +11 -0
  2. model.py +378 -0
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "architectures": ["YourCustomModel"],
+   "model_type": "encoder-decoder",
+   "hidden_size": 512,
+   "num_attention_heads": 8,
+   "num_hidden_layers": 6,
+   "vocab_size": 32128,
+   "pad_token_id": 0,
+   "eos_token_id": 1,
+   "bos_token_id": 0
+ }
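
Nothing in this commit actually reads config.json, so the sketch below is only a plausible wiring of its fields onto the CustomTransformer hyperparameters defined in model.py; the field-to-argument mapping (hidden_size to d_model, the single num_hidden_layers field feeding both encoder and decoder depth) is an assumption, not part of the upload.

import json
from model import CustomTransformer  # defined in model.py in this commit;
                                     # note: importing model.py also runs its
                                     # Colab setup (drive.mount, login) at import time

with open("config.json") as f:
    cfg = json.load(f)

model = CustomTransformer(
    vocab_size=cfg["vocab_size"],         # 32128, the t5-small vocabulary
    d_model=cfg["hidden_size"],           # 512
    nhead=cfg["num_attention_heads"],     # 8
    enc_layers=cfg["num_hidden_layers"],  # 6
    dec_layers=cfg["num_hidden_layers"],  # 6 (assumed: same depth for both stacks)
)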
model.py ADDED
@@ -0,0 +1,378 @@
+ from google.colab import drive
+ import os
+ import glob
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import pdfplumber
+ import random
+ import math
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+ from torch.utils.data import DataLoader, Dataset, random_split
+ from torch.cuda.amp import autocast, GradScaler  # Fixed import
+ from huggingface_hub import login
+ from torch.utils.tensorboard import SummaryWriter
+ import logging
+ from typing import Tuple, List, Dict
+
+ # Configuration
+ class Config:
+     # Model
+     D_MODEL = 512
+     NHEAD = 8
+     ENC_LAYERS = 6
+     DEC_LAYERS = 6
+     DIM_FEEDFORWARD = 2048
+     DROPOUT = 0.1
+
+     # Training
+     BATCH_SIZE = 4
+     GRAD_ACCUM_STEPS = 2
+     LR = 1e-4
+     EPOCHS = 20
+     MAX_GRAD_NORM = 1.0
+
+     # Data
+     INPUT_MAX_LEN = 512
+     SUMMARY_MAX_LEN = 128
+     CHUNK_SIZE = 512
+
+     # Paths
+     CHECKPOINT_DIR = "/content/drive/MyDrive/legal_summarization_checkpoints_6"
+     LOG_DIR = os.path.join(CHECKPOINT_DIR, "logs")
+
+     @classmethod
+     def setup_paths(cls):
+         os.makedirs(cls.CHECKPOINT_DIR, exist_ok=True)
+         os.makedirs(cls.LOG_DIR, exist_ok=True)
+
+ # Initialize config
+ Config.setup_paths()
+
+ # Setup logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.FileHandler(os.path.join(Config.LOG_DIR, 'training.log')),
+         logging.StreamHandler()
+     ]
+ )
+ logger = logging.getLogger(__name__)
+
+ # Authenticate Hugging Face (read the access token from the environment
+ # rather than hard-coding a secret in the script)
+ login(token=os.environ["HF_TOKEN"])
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Mount Google Drive
+ drive.mount('/content/drive', force_remount=True)
+
+ # Tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("t5-small")
+ vocab_size = tokenizer.vocab_size
+
+ # TensorBoard
+ writer = SummaryWriter(Config.LOG_DIR)
+
+ def clean_text(text: str) -> str:
+     """Basic text cleaning"""
+     text = ' '.join(text.split())  # Remove extra whitespace
+     return text.strip()
+
+ def extract_text_from_pdf(pdf_path: str, chunk_size: int = Config.CHUNK_SIZE) -> List[str]:
+     """Extract and chunk text from PDF with error handling"""
+     text = ''
+     try:
+         with pdfplumber.open(pdf_path) as pdf:
+             for page in pdf.pages:
+                 page_text = page.extract_text() or ''
+                 text += page_text + ' '
+     except Exception as e:
+         logger.warning(f"Error processing {pdf_path}: {str(e)}")
+         return []
+
+     text = clean_text(text)
+     words = text.split()
+     return [' '.join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)] if words else []
+
+ def load_texts_from_folder(folder_path: str, chunk_size: int = Config.CHUNK_SIZE) -> List[str]:
+     """Load and chunk texts from folder with multiple file types"""
+     texts = []
+     for fname in sorted(os.listdir(folder_path)):
+         path = os.path.join(folder_path, fname)
+         try:
+             if path.endswith('.pdf'):
+                 chunks = extract_text_from_pdf(path, chunk_size)
+                 if chunks:
+                     texts.extend(chunks)
+             else:
+                 with open(path, 'r', encoding='utf-8', errors='ignore') as f:
+                     content = clean_text(f.read())
+                     if content:
+                         texts.extend([content[i:i+chunk_size] for i in range(0, len(content), chunk_size)])
+         except Exception as e:
+             logger.warning(f"Error loading {path}: {str(e)}")
+             continue
+     return texts
+
+ class LegalDataset(Dataset):
+     def __init__(self, texts: List[str], summaries: List[str], tokenizer: AutoTokenizer,
+                  input_max_len: int = Config.INPUT_MAX_LEN,
+                  summary_max_len: int = Config.SUMMARY_MAX_LEN):
+         assert len(texts) == len(summaries), "Texts and summaries must be same length"
+         self.texts = texts
+         self.summaries = summaries
+         self.tokenizer = tokenizer
+         self.input_max_len = input_max_len
+         self.summary_max_len = summary_max_len
+
+     def __len__(self):
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         src = self.texts[idx]
+         tgt = self.summaries[idx]
+
+         enc = self.tokenizer(
+             src,
+             padding='max_length',
+             truncation=True,
+             max_length=self.input_max_len,
+             return_tensors='pt'
+         )
+
+         dec = self.tokenizer(
+             tgt,
+             padding='max_length',
+             truncation=True,
+             max_length=self.summary_max_len,
+             return_tensors='pt'
+         )
+
+         return {
+             'input_ids': enc.input_ids.squeeze(),
+             'attention_mask': enc.attention_mask.squeeze(),
+             'labels': dec.input_ids.squeeze()
+         }
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024):
+         super().__init__()
+         self.dropout = nn.Dropout(dropout)
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len).unsqueeze(1).float()
+         div = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div)
+         pe[:, 1::2] = torch.cos(position * div)
+         self.register_buffer('pe', pe.unsqueeze(0))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.pe[:, :x.size(1)]
+         return self.dropout(x)
+
+ class CustomTransformer(nn.Module):
+     def __init__(self, vocab_size: int, d_model: int = Config.D_MODEL, nhead: int = Config.NHEAD,
+                  enc_layers: int = Config.ENC_LAYERS, dec_layers: int = Config.DEC_LAYERS,
+                  dim_feedforward: int = Config.DIM_FEEDFORWARD, dropout: float = Config.DROPOUT):
+         super().__init__()
+         self.embed = nn.Embedding(vocab_size, d_model)
+         self.pos_enc = PositionalEncoding(d_model, dropout)
+         self.transformer = nn.Transformer(
+             d_model=d_model,
+             nhead=nhead,
+             num_encoder_layers=enc_layers,
+             num_decoder_layers=dec_layers,
+             dim_feedforward=dim_feedforward,
+             dropout=dropout,
+             batch_first=True
+         )
+         self.fc = nn.Linear(d_model, vocab_size)
+
+         # Initialize weights
+         self._init_weights()
+
+     def _init_weights(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def forward(self, src_ids: torch.Tensor, tgt_ids: torch.Tensor,
+                 src_key_padding_mask: torch.Tensor = None,
+                 tgt_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
+
+         # Create causal mask for decoder
+         tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_ids.size(1)).to(tgt_ids.device)
+
+         src = self.pos_enc(self.embed(src_ids))
+         tgt = self.pos_enc(self.embed(tgt_ids))
+
+         out = self.transformer(
+             src, tgt,
+             tgt_mask=tgt_mask,
+             src_key_padding_mask=src_key_padding_mask,
+             tgt_key_padding_mask=tgt_key_padding_mask,
+             memory_key_padding_mask=src_key_padding_mask
+         )
+         return self.fc(out)
+
+ def create_masks(input_ids: torch.Tensor, decoder_input: torch.Tensor, pad_token_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Create padding masks for transformer"""
+     src_pad_mask = (input_ids == pad_token_id)
+     tgt_pad_mask = (decoder_input == pad_token_id)
+     return src_pad_mask, tgt_pad_mask
+
+ def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader,
+                 optimizer: optim.Optimizer, criterion: nn.Module, device: torch.device,
+                 epochs: int = Config.EPOCHS, grad_accum_steps: int = Config.GRAD_ACCUM_STEPS):
+
+     model.to(device)
+     scaler = GradScaler()
+     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)
+     best_val_loss = float('inf')
+     early_stop_counter = 0
+
+     for epoch in range(1, epochs + 1):
+         model.train()
+         train_loss = 0
+         progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}")
+
+         for step, batch in enumerate(progress_bar, 1):
+             input_ids = batch['input_ids'].to(device)
+             attn_mask = batch['attention_mask'].to(device)
+             labels = batch['labels'].to(device)
+
+             # Prepare decoder input with <pad> as start
+             decoder_input = torch.cat([
+                 torch.full((labels.size(0), 1), tokenizer.pad_token_id, dtype=torch.long, device=device),
+                 labels[:, :-1]
+             ], dim=1)
+
+             # Create masks
+             src_pad_mask, tgt_pad_mask = create_masks(input_ids, decoder_input, tokenizer.pad_token_id)
+
+             with autocast():
+                 outputs = model(
+                     input_ids,
+                     decoder_input,
+                     src_key_padding_mask=src_pad_mask,
+                     tgt_key_padding_mask=tgt_pad_mask
+                 )
+                 loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
+                 loss = loss / grad_accum_steps
+
+             scaler.scale(loss).backward()
+
+             if step % grad_accum_steps == 0:
+                 scaler.unscale_(optimizer)
+                 nn.utils.clip_grad_norm_(model.parameters(), Config.MAX_GRAD_NORM)
+                 scaler.step(optimizer)
+                 scaler.update()
+                 optimizer.zero_grad()
+
+             train_loss += loss.item() * grad_accum_steps
+             progress_bar.set_postfix({'train_loss': f"{loss.item():.4f}"})
+
+         avg_train_loss = train_loss / len(train_loader)
+         writer.add_scalar('Loss/train', avg_train_loss, epoch)
+         logger.info(f"Epoch {epoch} Train Loss: {avg_train_loss:.4f}")
+
+         # Validation
+         model.eval()
+         val_loss = 0
+         with torch.no_grad():
+             for batch in tqdm(val_loader, desc="Validating"):
+                 input_ids = batch['input_ids'].to(device)
+                 labels = batch['labels'].to(device)
+                 decoder_input = torch.cat([
+                     torch.full((labels.size(0), 1), tokenizer.pad_token_id, dtype=torch.long, device=device),
+                     labels[:, :-1]
+                 ], dim=1)
+
+                 src_pad_mask, tgt_pad_mask = create_masks(input_ids, decoder_input, tokenizer.pad_token_id)
+
+                 with autocast():
+                     outputs = model(
+                         input_ids,
+                         decoder_input,
+                         src_key_padding_mask=src_pad_mask,
+                         tgt_key_padding_mask=tgt_pad_mask
+                     )
+                     loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
+                 val_loss += loss.item()
+
+         avg_val_loss = val_loss / len(val_loader)
+         writer.add_scalar('Loss/val', avg_val_loss, epoch)
+         logger.info(f"Epoch {epoch} Val Loss: {avg_val_loss:.4f}")
+
+         # Learning rate scheduling
+         scheduler.step(avg_val_loss)
+
+         # Early stopping & checkpointing
+         if avg_val_loss < best_val_loss:
+             best_val_loss = avg_val_loss
+             early_stop_counter = 0
+             # Save best model
+             ckpt_path = os.path.join(Config.CHECKPOINT_DIR, "transformer_best.pt")
+             torch.save(model.state_dict(), ckpt_path)
+             logger.info(f"New best model saved with val loss: {best_val_loss:.4f}")
+         else:
+             early_stop_counter += 1
+             if early_stop_counter >= 3:
+                 logger.info("Early stopping triggered")
+                 break
+
+         # Save regular checkpoint
+         ckpt_path = os.path.join(Config.CHECKPOINT_DIR, f"transformer_epoch_{epoch}.pt")
+         torch.save(model.state_dict(), ckpt_path)
+
+         # Keep only latest 2 checkpoints
+         manage_checkpoints()
+
+ def manage_checkpoints():
+     """Keep only the 2 most recent checkpoints"""
+     files = sorted(glob.glob(os.path.join(Config.CHECKPOINT_DIR, "transformer_epoch_*.pt")), key=os.path.getctime)
+     if len(files) > 2:
+         for old in files[:-2]:
+             os.remove(old)
+             logger.info(f"Removed old checkpoint: {old}")
+
+ if __name__ == "__main__":
+     try:
+         logger.info("Starting training process")
+
+         # Load data
+         logger.info("Loading texts and summaries")
+         texts = load_texts_from_folder("/content/drive/MyDrive/dataset/IN-Abs/train-data/judgement")
+         sums = load_texts_from_folder("/content/drive/MyDrive/dataset/IN-Abs/train-data/summary")
+
+         if not texts or not sums:
+             raise ValueError("No data loaded - check your input paths and files")
+
+         logger.info(f"Loaded {len(texts)} text chunks and {len(sums)} summary chunks")
+
+         # Create dataset
+         full_ds = LegalDataset(texts, sums, tokenizer)
+
+         # Train/val split
+         val_size = int(0.1 * len(full_ds))
+         train_size = len(full_ds) - val_size
+         train_ds, val_ds = random_split(full_ds, [train_size, val_size])
+
+         train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True)
+         val_loader = DataLoader(val_ds, batch_size=Config.BATCH_SIZE)
+
+         # Initialize model
+         model = CustomTransformer(vocab_size)
+         optimizer = optim.Adam(model.parameters(), lr=Config.LR)
+         criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
+
+         # Train
+         train_model(model, train_loader, val_loader, optimizer, criterion, device)
+
+         logger.info("Training completed successfully")
+
+     except Exception as e:
+         logger.error(f"Training failed: {str(e)}", exc_info=True)
+         raise
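
model.py trains and checkpoints the model but ships no inference path. Below is a minimal greedy-decoding sketch that could sit alongside it; it assumes model.py's tokenizer, device, and Config globals are in scope, and greedy argmax decoding plus the name greedy_summarize are additions for illustration, not part of the commit. It starts decoding from the pad token, mirroring the decoder input built in train_model.

@torch.no_grad()
def greedy_summarize(model: nn.Module, text: str, max_len: int = Config.SUMMARY_MAX_LEN) -> str:
    # Hypothetical helper: greedy (argmax) decoding with the trained CustomTransformer
    model.eval()
    enc = tokenizer(text, truncation=True, max_length=Config.INPUT_MAX_LEN,
                    return_tensors='pt').to(device)
    src_ids = enc.input_ids
    src_pad_mask = (src_ids == tokenizer.pad_token_id)
    # T5-style decoder start: the pad token, as in train_model's decoder_input
    dec_ids = torch.full((1, 1), tokenizer.pad_token_id, dtype=torch.long, device=device)
    for _ in range(max_len):
        logits = model(src_ids, dec_ids, src_key_padding_mask=src_pad_mask)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # most likely next token
        dec_ids = torch.cat([dec_ids, next_id], dim=1)
        if next_id.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(dec_ids[0, 1:], skip_special_tokens=True)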
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f872ee941794f211166d52650ade64ab8848d5051392f6441a6c12699fea5529
+ size 242089493
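
The weights themselves live in Git LFS; this file is just the pointer. Assuming the .bin holds the raw state_dict that train_model saves via torch.save(model.state_dict(), ...), a minimal loading sketch:

import torch
from model import CustomTransformer  # this commit's model.py

# map_location='cpu' lets the ~242 MB checkpoint load without a GPU
state = torch.load("pytorch_model.bin", map_location="cpu")
model = CustomTransformer(vocab_size=32128)  # matches vocab_size in config.json
model.load_state_dict(state)
model.eval()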