from google.colab import drive

import glob
import logging
import math
import os
import random
from typing import Dict, List, Tuple

import pdfplumber
import torch
import torch.nn as nn
import torch.optim as optim
from huggingface_hub import login
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoTokenizer


class Config:
    # Model architecture
    D_MODEL = 512
    NHEAD = 8
    ENC_LAYERS = 6
    DEC_LAYERS = 6
    DIM_FEEDFORWARD = 2048
    DROPOUT = 0.1

    # Optimization (effective batch size = BATCH_SIZE * GRAD_ACCUM_STEPS = 8)
    BATCH_SIZE = 4
    GRAD_ACCUM_STEPS = 2
    LR = 1e-4
    EPOCHS = 20
    MAX_GRAD_NORM = 1.0

    # Max token lengths for inputs/summaries and raw-text chunk size
    INPUT_MAX_LEN = 512
    SUMMARY_MAX_LEN = 128
    CHUNK_SIZE = 512

    # Output locations on Google Drive
    CHECKPOINT_DIR = "/content/drive/MyDrive/legal_summarization_checkpoints_6"
    LOG_DIR = os.path.join(CHECKPOINT_DIR, "logs")

    @classmethod
    def setup_paths(cls):
        os.makedirs(cls.CHECKPOINT_DIR, exist_ok=True)
        os.makedirs(cls.LOG_DIR, exist_ok=True)


# Mount Google Drive before creating the checkpoint/log directories that live on it;
# calling os.makedirs on /content/drive/... first would leave a non-empty mountpoint
# and cause drive.mount() to fail.
drive.mount('/content/drive', force_remount=True)

Config.setup_paths()

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(Config.LOG_DIR, 'training.log')),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Hugging Face Hub login; read the token from the environment instead of
# hard-coding a secret in the source.
login(token=os.environ["HF_TOKEN"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("t5-small")
vocab_size = tokenizer.vocab_size

writer = SummaryWriter(Config.LOG_DIR)


def clean_text(text: str) -> str:
    """Basic text cleaning"""
    text = ' '.join(text.split())
    return text.strip()


def extract_text_from_pdf(pdf_path: str, chunk_size: int = Config.CHUNK_SIZE) -> List[str]:
    """Extract and chunk text from PDF with error handling"""
    text = ''
    try:
        with pdfplumber.open(pdf_path) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text() or ''
                text += page_text + ' '
    except Exception as e:
        logger.warning(f"Error processing {pdf_path}: {str(e)}")
        return []

    text = clean_text(text)
    words = text.split()
    return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)] if words else []

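# Illustration (not executed): with the default CHUNK_SIZE of 512 words, a
# 1,200-word judgement is split into three chunks of 512, 512 and 176 words.
# chunks = extract_text_from_pdf("/content/sample.pdf")   # hypothetical path
# print(len(chunks), [len(c.split()) for c in chunks])

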
def load_texts_from_folder(folder_path: str, chunk_size: int = Config.CHUNK_SIZE) -> List[str]:
    """Load and chunk texts from folder with multiple file types"""
    # Note: non-PDF files are chunked by characters here, whereas PDFs are
    # chunked by words in extract_text_from_pdf.
    texts = []
    for fname in sorted(os.listdir(folder_path)):
        path = os.path.join(folder_path, fname)
        try:
            if path.endswith('.pdf'):
                chunks = extract_text_from_pdf(path, chunk_size)
                if chunks:
                    texts.extend(chunks)
            else:
                with open(path, 'r', encoding='utf-8', errors='ignore') as f:
                    content = clean_text(f.read())
                    if content:
                        texts.extend([content[i:i + chunk_size] for i in range(0, len(content), chunk_size)])
        except Exception as e:
            logger.warning(f"Error loading {path}: {str(e)}")
            continue
    return texts


class LegalDataset(Dataset):
    def __init__(self, texts: List[str], summaries: List[str], tokenizer: AutoTokenizer,
                 input_max_len: int = Config.INPUT_MAX_LEN,
                 summary_max_len: int = Config.SUMMARY_MAX_LEN):
        assert len(texts) == len(summaries), "Texts and summaries must be same length"
        self.texts = texts
        self.summaries = summaries
        self.tokenizer = tokenizer
        self.input_max_len = input_max_len
        self.summary_max_len = summary_max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        src = self.texts[idx]
        tgt = self.summaries[idx]

        enc = self.tokenizer(
            src,
            padding='max_length',
            truncation=True,
            max_length=self.input_max_len,
            return_tensors='pt'
        )

        dec = self.tokenizer(
            tgt,
            padding='max_length',
            truncation=True,
            max_length=self.summary_max_len,
            return_tensors='pt'
        )

        return {
            'input_ids': enc.input_ids.squeeze(),
            'attention_mask': enc.attention_mask.squeeze(),
            'labels': dec.input_ids.squeeze()
        }


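# Each dataset item is a dict of 1-D tensors: input_ids and attention_mask of
# shape (INPUT_MAX_LEN,) = (512,) and labels of shape (SUMMARY_MAX_LEN,) = (128,);
# the DataLoader stacks these into (batch, seq_len) tensors.

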
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div)
        pe[:, 1::2] = torch.cos(position * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


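# The 'pe' buffer above implements the standard sinusoidal encoding from
# "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# Each position receives a fixed, unique vector that is added to the token
# embeddings so the attention layers can use order information.

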
class CustomTransformer(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = Config.D_MODEL, nhead: int = Config.NHEAD,
                 enc_layers: int = Config.ENC_LAYERS, dec_layers: int = Config.DEC_LAYERS,
                 dim_feedforward: int = Config.DIM_FEEDFORWARD, dropout: float = Config.DROPOUT):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.fc = nn.Linear(d_model, vocab_size)

        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src_ids: torch.Tensor, tgt_ids: torch.Tensor,
                src_key_padding_mask: torch.Tensor = None,
                tgt_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        # Causal mask so each decoder position only attends to earlier positions.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_ids.size(1)).to(tgt_ids.device)

        src = self.pos_enc(self.embed(src_ids))
        tgt = self.pos_enc(self.embed(tgt_ids))

        out = self.transformer(
            src, tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )
        return self.fc(out)


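# forward() returns unnormalized logits of shape (batch, tgt_len, vocab_size);
# the CrossEntropyLoss used below applies log-softmax internally, so no softmax
# is needed here.

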
def create_masks(input_ids: torch.Tensor, decoder_input: torch.Tensor, pad_token_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create padding masks for transformer"""
    src_pad_mask = (input_ids == pad_token_id)
    tgt_pad_mask = (decoder_input == pad_token_id)
    return src_pad_mask, tgt_pad_mask

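# Illustration (not executed): True marks padding positions that nn.Transformer
# should ignore. For one sequence ending in a pad token:
#   ids = torch.tensor([[5, 9, tokenizer.pad_token_id]])
#   create_masks(ids, ids, tokenizer.pad_token_id)[0]  # tensor([[False, False, True]])

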
def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader,
                optimizer: optim.Optimizer, criterion: nn.Module, device: torch.device,
                epochs: int = Config.EPOCHS, grad_accum_steps: int = Config.GRAD_ACCUM_STEPS):

    model.to(device)
    scaler = GradScaler()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)
    best_val_loss = float('inf')
    early_stop_counter = 0

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}")

        for step, batch in enumerate(progress_bar, 1):
            input_ids = batch['input_ids'].to(device)
            # attention_mask is loaded for completeness; padding masks are
            # re-derived from pad token ids in create_masks below.
            attn_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Teacher forcing: shift labels right and prepend T5's pad token,
            # which also serves as its decoder start token.
            decoder_input = torch.cat([
                torch.full((labels.size(0), 1), tokenizer.pad_token_id, dtype=torch.long, device=device),
                labels[:, :-1]
            ], dim=1)

            src_pad_mask, tgt_pad_mask = create_masks(input_ids, decoder_input, tokenizer.pad_token_id)

            # Mixed-precision forward pass
            with autocast():
                outputs = model(
                    input_ids,
                    decoder_input,
                    src_key_padding_mask=src_pad_mask,
                    tgt_key_padding_mask=tgt_pad_mask
                )
                loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
                # Scale the loss so gradients average over grad_accum_steps micro-batches.
                loss = loss / grad_accum_steps

            scaler.scale(loss).backward()

            if step % grad_accum_steps == 0:
                # Unscale before clipping so the norm is computed on the true gradients.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), Config.MAX_GRAD_NORM)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            train_loss += loss.item() * grad_accum_steps
            progress_bar.set_postfix({'train_loss': f"{loss.item():.4f}"})

        avg_train_loss = train_loss / len(train_loader)
        writer.add_scalar('Loss/train', avg_train_loss, epoch)
        logger.info(f"Epoch {epoch} Train Loss: {avg_train_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating"):
                input_ids = batch['input_ids'].to(device)
                labels = batch['labels'].to(device)
                decoder_input = torch.cat([
                    torch.full((labels.size(0), 1), tokenizer.pad_token_id, dtype=torch.long, device=device),
                    labels[:, :-1]
                ], dim=1)

                src_pad_mask, tgt_pad_mask = create_masks(input_ids, decoder_input, tokenizer.pad_token_id)

                with autocast():
                    outputs = model(
                        input_ids,
                        decoder_input,
                        src_key_padding_mask=src_pad_mask,
                        tgt_key_padding_mask=tgt_pad_mask
                    )
                    loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        writer.add_scalar('Loss/val', avg_val_loss, epoch)
        logger.info(f"Epoch {epoch} Val Loss: {avg_val_loss:.4f}")

        scheduler.step(avg_val_loss)

        # Checkpointing and early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            early_stop_counter = 0

            ckpt_path = os.path.join(Config.CHECKPOINT_DIR, "transformer_best.pt")
            torch.save(model.state_dict(), ckpt_path)
            logger.info(f"New best model saved with val loss: {best_val_loss:.4f}")
        else:
            early_stop_counter += 1
            if early_stop_counter >= 3:
                logger.info("Early stopping triggered")
                break

        ckpt_path = os.path.join(Config.CHECKPOINT_DIR, f"transformer_epoch_{epoch}.pt")
        torch.save(model.state_dict(), ckpt_path)

        manage_checkpoints()


def manage_checkpoints():
    """Keep only the 2 most recent checkpoints"""
    files = sorted(glob.glob(os.path.join(Config.CHECKPOINT_DIR, "transformer_epoch_*.pt")), key=os.path.getctime)
    if len(files) > 2:
        for old in files[:-2]:
            os.remove(old)
            logger.info(f"Removed old checkpoint: {old}")


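# The pipeline above only trains the model. The sketch below is a minimal greedy
# decoder for producing a summary from a trained CustomTransformer; it is
# illustrative and not part of the original training script (the function name
# and defaults are assumptions).
def greedy_summarize(model: nn.Module, text: str, max_len: int = Config.SUMMARY_MAX_LEN) -> str:
    model.eval()
    enc = tokenizer(text, truncation=True, max_length=Config.INPUT_MAX_LEN, return_tensors='pt')
    input_ids = enc.input_ids.to(device)
    src_pad_mask = (input_ids == tokenizer.pad_token_id)
    # As in train_model, T5's pad token doubles as the decoder start token.
    generated = torch.full((1, 1), tokenizer.pad_token_id, dtype=torch.long, device=device)
    with torch.no_grad():
        for _ in range(max_len):
            logits = model(input_ids, generated, src_key_padding_mask=src_pad_mask)
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_id], dim=1)
            if next_id.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(generated.squeeze(0), skip_special_tokens=True)

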
if __name__ == "__main__":
    try:
        logger.info("Starting training process")

        logger.info("Loading texts and summaries")
        texts = load_texts_from_folder("/content/drive/MyDrive/dataset/IN-Abs/train-data/judgement")
        sums = load_texts_from_folder("/content/drive/MyDrive/dataset/IN-Abs/train-data/summary")

        if not texts or not sums:
            raise ValueError("No data loaded - check your input paths and files")

        logger.info(f"Loaded {len(texts)} text chunks and {len(sums)} summary chunks")

        # Note: judgements and summaries are chunked independently, so this pairing
        # assumes both folders yield the same number of aligned chunks; otherwise
        # the assertion in LegalDataset will fail.
        full_ds = LegalDataset(texts, sums, tokenizer)

        # 90/10 train/validation split
        val_size = int(0.1 * len(full_ds))
        train_size = len(full_ds) - val_size
        train_ds, val_ds = random_split(full_ds, [train_size, val_size])

        train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=Config.BATCH_SIZE)

        model = CustomTransformer(vocab_size)
        optimizer = optim.Adam(model.parameters(), lr=Config.LR)
        # Padding positions in the labels are excluded from the loss.
        criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

        train_model(model, train_loader, val_loader, optimizer, criterion, device)

        logger.info("Training completed successfully")

    except Exception as e:
        logger.error(f"Training failed: {str(e)}", exc_info=True)
        raise