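"""Train OpenPeerLLM on a plain-text corpus (one example per line).

Example invocation (script name and paths are illustrative):

    python train.py --train_data data/corpus.txt --save_path checkpoints \
        --num_epochs 3 --batch_size 8 --learning_rate 5e-5
"""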
import os
import argparse

import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

from src.modeling_openpeer import OpenPeerLLM
from src.configuration_openpeer import OpenPeerConfig
from src.tokenization_openpeer import OpenPeerTokenizer
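
# OpenPeerLLM, OpenPeerConfig, and OpenPeerTokenizer are this repository's
# local implementations under src/.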


class TextDataset(Dataset):
    """Wraps a list of raw text strings as tokenized training samples."""

    def __init__(self, texts, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.texts = texts
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        encoded = self.tokenizer(text,
                                 truncation=True,
                                 max_length=self.max_length)

        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # Shift labels left by one and append EOS, so position i predicts
        # token i + 1; the model is therefore expected not to shift labels
        # again internally.
        labels = input_ids[1:] + [self.tokenizer.eos_token_id]

        return {
            "input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attention_mask),
            "labels": torch.tensor(labels)
        }
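
# Example: for input_ids [a, b, c], labels become [b, c, eos_token_id], so the
# loss at position i scores the prediction of token i + 1 (standard next-token
# language modeling).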


def collate_fn(batch):
    """Pad variable-length samples into rectangular batch tensors."""
    input_ids = [item["input_ids"] for item in batch]
    attention_mask = [item["attention_mask"] for item in batch]
    labels = [item["labels"] for item in batch]

    # Pad token id 0 is assumed for input_ids; -100 is the conventional
    # ignore_index, so padded label positions contribute nothing to the loss.
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
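
# Example: two samples of lengths 5 and 3 collate to (2, 5) tensors; the second
# row's last two attention_mask entries are 0 and its last two labels are -100
# (assuming the model's loss uses ignore_index=-100).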


def train(
    model,
    train_dataloader,
    optimizer,
    scheduler,
    num_epochs,
    device,
    save_path,
    start_epoch=0,
    log_interval=100
):
    model.train()
    total_steps = 0
    best_loss = float('inf')

    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        progress_bar = tqdm(train_dataloader, desc="Training")
        epoch_loss = 0

        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # The model is expected to return a dict containing a "loss"
            # entry when labels are supplied.
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs["loss"]
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()  # cosine schedule advances once per optimizer step

            total_steps += 1

            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
            if log_interval and total_steps % log_interval == 0:
                tqdm.write(f"step {total_steps}: loss {loss.item():.4f}")

            # "Best" here means the lowest single-batch loss seen so far,
            # which is a noisy criterion.
            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": best_loss,
                }, os.path.join(save_path, "best_model.pt"))

        avg_epoch_loss = epoch_loss / len(train_dataloader)
        print(f"Epoch {epoch+1} average loss: {avg_epoch_loss:.4f}")

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": avg_epoch_loss,
        }
        torch.save(checkpoint, os.path.join(save_path, f"checkpoint_epoch_{epoch+1}.pt"))
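
# Each epoch checkpoint stores the epoch index alongside model and optimizer
# state; main() reads that index to compute start_epoch when resuming. The
# scheduler's state is not saved, so the cosine schedule restarts on resume.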


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_data", type=str, required=True, help="Path to training data file")
    parser.add_argument("--save_path", type=str, required=True, help="Directory to save model checkpoints")
    parser.add_argument("--load_checkpoint", type=str, help="Path to model checkpoint to continue training")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=8, help="Training batch size")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--max_length", type=int, default=1024, help="Maximum sequence length")
    args = parser.parse_args()

    os.makedirs(args.save_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    config = OpenPeerConfig()
    model = OpenPeerLLM(config).to(device)
    tokenizer = OpenPeerTokenizer()
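
    # OpenPeerTokenizer is assumed to expose a Hugging Face-style interface:
    # calling it returns a dict with "input_ids" and "attention_mask" lists,
    # and it provides an eos_token_id attribute (both are relied on by
    # TextDataset above).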

    # Optionally resume: restore model weights now; optimizer state is
    # restored after the optimizer is constructed below.
    checkpoint = None
    start_epoch = 0
    if args.load_checkpoint and os.path.exists(args.load_checkpoint):
        print(f"Loading checkpoint: {args.load_checkpoint}")
        checkpoint = torch.load(args.load_checkpoint, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        print(f"Resuming from epoch {start_epoch}")

    print("Loading training data...")
    with open(args.train_data, 'r', encoding='utf-8') as f:
        texts = [line.strip() for line in f if line.strip()]

    print("Creating dataset...")
    dataset = TextDataset(texts, tokenizer, max_length=args.max_length)
    train_dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4
    )

    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
    # T_max counts total optimizer steps, since scheduler.step() runs per batch.
    scheduler = CosineAnnealingLR(optimizer, T_max=len(train_dataloader) * args.num_epochs)

    # Finish resuming: reuse the checkpoint loaded above rather than reading
    # the file a second time.
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    print("Starting training...")
    train(
        model=model,
        train_dataloader=train_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=args.num_epochs,
        device=device,
        save_path=args.save_path,
        start_epoch=start_epoch,
    )


if __name__ == "__main__":
    main()