# OpenPeerLLM / train.py
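"""Training script for OpenPeerLLM.

Reads one training example per line from a plain-text file, trains the model
with a causal language-modeling objective, and writes checkpoints to
--save_path.

Example invocation (paths and flag values are illustrative):

    python train.py --train_data data/corpus.txt --save_path checkpoints \
        --num_epochs 3 --batch_size 8 --learning_rate 5e-5
"""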
import os
import argparse
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from src.modeling_openpeer import OpenPeerLLM
from src.configuration_openpeer import OpenPeerConfig
from src.tokenization_openpeer import OpenPeerTokenizer
class TextDataset(Dataset):
    """Wraps a list of raw text strings for causal language modeling."""

    def __init__(self, texts, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.texts = texts
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
        )
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        # Labels are the input ids shifted left by one with EOS appended, so
        # position t predicts token t+1 (the model is assumed to compute the
        # loss on these pre-shifted labels without shifting again)
        labels = input_ids[1:] + [self.tokenizer.eos_token_id]
        return {
            "input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attention_mask),
            "labels": torch.tensor(labels),
        }
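# Sketch of the shift above (token ids made up; assumes eos_token_id == 2):
#   input_ids = [5, 9, 4]  ->  labels = [9, 4, 2]
# i.e. the label at position t is the input token at position t + 1.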
def collate_fn(batch):
    """Pads variable-length examples from TextDataset into batched tensors."""
    input_ids = [item["input_ids"] for item in batch]
    attention_mask = [item["attention_mask"] for item in batch]
    labels = [item["labels"] for item in batch]
    # Pad to the longest sequence in the batch; 0 is assumed to be the
    # tokenizer's pad id, and -100 marks label positions the loss should ignore
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
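# Example of the collated shapes for a batch with lengths 3 and 2 (made-up
# ids, assuming pad id 0 and eos id 2):
#   input_ids      [[5, 9, 4], [7, 8, 0]]
#   attention_mask [[1, 1, 1], [1, 1, 0]]
#   labels         [[9, 4, 2], [8, 2, -100]]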
def train(
    model,
    train_dataloader,
    optimizer,
    scheduler,
    num_epochs,
    device,
    save_path,
    start_epoch=0,
):
    model.train()
    best_loss = float('inf')

    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        progress_bar = tqdm(train_dataloader, desc="Training")
        epoch_loss = 0

        for batch in progress_bar:
            # Move batch to device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs["loss"]
            epoch_loss += loss.item()

            # Backward pass with gradient clipping
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            # Update progress bar
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

            # Save best model so far (keyed on per-batch loss, which is noisy)
            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "loss": best_loss,
                }, f"{save_path}/best_model.pt")

        # Save end-of-epoch checkpoint
        avg_epoch_loss = epoch_loss / len(train_dataloader)
        print(f"Epoch {epoch+1} average loss: {avg_epoch_loss:.4f}")
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "loss": avg_epoch_loss,
        }
        torch.save(checkpoint, f"{save_path}/checkpoint_epoch_{epoch+1}.pt")
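# Note: checkpoints store the index of the epoch that just finished; main()
# resumes at checkpoint["epoch"] + 1 by passing it through start_epoch.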
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_data", type=str, required=True, help="Path to training data file")
    parser.add_argument("--save_path", type=str, required=True, help="Directory to save model checkpoints")
    parser.add_argument("--load_checkpoint", type=str, help="Path to model checkpoint to continue training")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=8, help="Training batch size")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--max_length", type=int, default=1024, help="Maximum sequence length")
    args = parser.parse_args()
    # Create save directory if it doesn't exist
    os.makedirs(args.save_path, exist_ok=True)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize model and tokenizer
    config = OpenPeerConfig()
    model = OpenPeerLLM(config).to(device)
    tokenizer = OpenPeerTokenizer()
    # Load model weights if a checkpoint was given; optimizer/scheduler state
    # is restored further down, once those objects exist
    start_epoch = 0
    checkpoint = None
    if args.load_checkpoint and os.path.exists(args.load_checkpoint):
        print(f"Loading checkpoint: {args.load_checkpoint}")
        checkpoint = torch.load(args.load_checkpoint, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        print(f"Resuming from epoch {start_epoch}")
    # Load training data (one example per line)
    print("Loading training data...")
    with open(args.train_data, 'r', encoding='utf-8') as f:
        texts = [line.strip() for line in f if line.strip()]
    # Create dataset and dataloader
    print("Creating dataset...")
    dataset = TextDataset(texts, tokenizer, max_length=args.max_length)
    train_dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4,
    )

    # Initialize optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
    scheduler = CosineAnnealingLR(optimizer, T_max=len(train_dataloader) * args.num_epochs)
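    # CosineAnnealingLR decays the learning rate from its initial value toward
    # 0 along a half-cosine over T_max steps; since scheduler.step() runs once
    # per batch, T_max above spans the whole planned run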
    # Restore optimizer (and, if present, scheduler) state when resuming,
    # reusing the checkpoint already loaded above instead of re-reading it
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if "scheduler_state_dict" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    # Train the model, skipping epochs already completed in the checkpoint
    print("Starting training...")
    train(
        model=model,
        train_dataloader=train_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=args.num_epochs,
        device=device,
        save_path=args.save_path,
        start_epoch=start_epoch,
    )
if __name__ == "__main__":
main()