import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from typing import List, Dict, Tuple
import json
import os
from datetime import datetime


class GRPOTrainer:
    def __init__(self):
        self.model = None
        self.ref_model = None
        self.tokenizer = None
        self.optimizer = None
        self.training_history = []

    def load_model(self, model_name: str) -> str:
        """Load the policy model, a frozen reference copy, and the tokenizer."""
        try:
            # Use fp16 only when a GPU is available; fp16 on CPU is unreliable for many ops.
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            device = "cuda" if torch.cuda.is_available() else "cpu"

            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
            self.ref_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)

            # Set padding token and left padding (needed for batched generation
            # with decoder-only models)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "left"

            # Freeze reference model
            for param in self.ref_model.parameters():
                param.requires_grad = False
            self.ref_model.eval()

            return f"✅ Successfully loaded model: {model_name}"
        except Exception as e:
            return f"❌ Error loading model: {str(e)}"

    def compute_rewards(self, prompts: List[str], responses: List[str]) -> torch.Tensor:
        """Compute rewards for responses (simplified reward function)."""
        rewards = []
        for response in responses:
            # Simple reward based on response length and lexical diversity
            length_reward = min(len(response.split()) / 50, 1.0)
            unique_words = len(set(response.lower().split()))
            diversity_reward = min(unique_words / 20, 1.0)
            rewards.append((length_reward + diversity_reward) / 2)
        return torch.tensor(rewards)

    def compute_kl_penalty(self, logits: torch.Tensor, ref_logits: torch.Tensor) -> torch.Tensor:
        """Compute the mean KL divergence between policy and reference token distributions."""
        log_probs = F.log_softmax(logits, dim=-1)
        ref_log_probs = F.log_softmax(ref_logits, dim=-1)
        # KL(policy || reference), computed in log space for numerical stability
        kl = (log_probs.exp() * (log_probs - ref_log_probs)).sum(-1)
        return kl.mean()

    def grpo_step(self, prompts: List[str], beta: float = 0.1) -> Dict:
        """Perform one (simplified) GRPO training step."""
        if not self.model or not self.tokenizer:
            return {"error": "Model not loaded"}

        # Tokenize prompts
        inputs = self.tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True
        ).to(self.model.device)

        # Generate responses
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=50,
                do_sample=True,
                temperature=0.8,
                pad_token_id=self.tokenizer.pad_token_id
            )

        # Decode only the newly generated tokens
        responses = []
        for output in outputs:
            response = self.tokenizer.decode(output[inputs.input_ids.shape[1]:], skip_special_tokens=True)
            responses.append(response)

        # Compute rewards
        rewards = self.compute_rewards(prompts, responses)

        # Forward pass through both models (the frozen reference needs no gradients)
        self.model.train()
        model_outputs = self.model(inputs.input_ids, attention_mask=inputs.attention_mask)
        with torch.no_grad():
            ref_outputs = self.ref_model(inputs.input_ids, attention_mask=inputs.attention_mask)

        # Compute KL penalty
        kl_penalty = self.compute_kl_penalty(model_outputs.logits, ref_outputs.logits)

        # Compute loss (simplified GRPO loss). Note: the rewards are detached from the
        # graph, so in this simplified demo only the KL term contributes gradients;
        # a more faithful group-relative objective is sketched below the class for reference.
        loss = -rewards.mean() + beta * kl_penalty

        # Backward pass
        if self.optimizer:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return {
            "loss": loss.item(),
            "reward": rewards.mean().item(),
            "kl_penalty": kl_penalty.item(),
            "responses": responses
        }

    def train(self, prompts: List[str], num_steps: int, lr: float, beta: float) -> str:
        """Run GRPO training for a fixed number of steps."""
        if not self.model:
            return "❌ Please load a model first"

        # Initialize optimizer
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)

        results = []
        for step in range(num_steps):
            step_result = self.grpo_step(prompts, beta)
            if "error" in step_result:
                return f"❌ Error: {step_result['error']}"

            result_str = (
                f"Step {step + 1}/{num_steps} - "
                f"Loss: {step_result['loss']:.4f}, "
                f"Reward: {step_result['reward']:.4f}, "
                f"KL: {step_result['kl_penalty']:.4f}"
            )
            results.append(result_str)

            # Store training history
            self.training_history.append({
                "step": step + 1,
                "loss": step_result['loss'],
                "reward": step_result['reward'],
                "kl_penalty": step_result['kl_penalty']
            })

        return "\n".join(results)

    def generate_response(self, prompt: str, max_length: int = 100, temperature: float = 0.8) -> str:
        """Generate a response using the (possibly fine-tuned) model."""
        if not self.model or not self.tokenizer:
            return "❌ Please load a model first"

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id
            )

        return self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

    def save_model(self, save_path: str) -> str:
        """Save the trained model, tokenizer, and training history."""
        if not self.model:
            return "❌ No model to save"
        try:
            self.model.save_pretrained(save_path)
            self.tokenizer.save_pretrained(save_path)

            # Save training history alongside the weights
            with open(os.path.join(save_path, "training_history.json"), "w") as f:
                json.dump(self.training_history, f)

            return f"✅ Model saved to {save_path}"
        except Exception as e:
            return f"❌ Error saving model: {str(e)}"
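

# --- Illustrative sketch (not wired into the app): a more faithful GRPO objective ---
# GRPO samples a *group* of G completions per prompt, normalizes each completion's reward
# against its group's mean/std to get an advantage, and scales the completion's token
# log-probabilities by that advantage. The helper below shows that loss for per-completion
# log-probs that have already been computed; the function name, shapes, and the omission of
# the clipped-ratio term are simplifying assumptions, not the exact algorithm from the
# DeepSeekMath paper. It assumes group_size > 1 so the group std is defined.
def grpo_group_loss(completion_logprobs: torch.Tensor,
                    rewards: torch.Tensor,
                    kl_penalty: torch.Tensor,
                    beta: float = 0.1,
                    eps: float = 1e-4) -> torch.Tensor:
    """completion_logprobs, rewards: shape (num_prompts, group_size)."""
    # Group-relative advantage: normalize rewards within each prompt's group
    advantages = (rewards - rewards.mean(dim=1, keepdim=True)) / (rewards.std(dim=1, keepdim=True) + eps)
    # Policy-gradient term: raise log-probs of above-average completions and lower
    # below-average ones; advantages are treated as constants (no gradient through them)
    pg_loss = -(advantages.detach() * completion_logprobs).mean()
    # KL penalty keeps the policy close to the frozen reference model
    return pg_loss + beta * kl_penalty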
# Initialize trainer
trainer = GRPOTrainer()


# Gradio interface
def load_model_interface(model_name):
    return trainer.load_model(model_name)


def train_interface(prompts_text, num_steps, learning_rate, beta):
    prompts = [p.strip() for p in prompts_text.split("\n") if p.strip()]
    if not prompts:
        return "❌ Please provide at least one prompt"
    return trainer.train(prompts, int(num_steps), float(learning_rate), float(beta))


def generate_interface(prompt, max_length, temperature):
    return trainer.generate_response(prompt, int(max_length), float(temperature))


def save_model_interface(save_path):
    return trainer.save_model(save_path)


def get_training_history():
    if not trainer.training_history:
        return "No training history available"
    history_str = "Training History:\n"
    history_str += "-" * 50 + "\n"
    for entry in trainer.training_history[-10:]:  # Show last 10 entries
        history_str += (
            f"Step {entry['step']}: Loss={entry['loss']:.4f}, "
            f"Reward={entry['reward']:.4f}, KL={entry['kl_penalty']:.4f}\n"
        )
    return history_str
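

# --- Illustrative sketch (not wired into the UI): a task-specific reward function ---
# The built-in compute_rewards only scores length and lexical diversity. Below is a
# keyword-based alternative; the keyword list is a made-up example, and in practice the
# reward would come from a verifier, a reward model, or unit tests. To try it, one could
# swap it in with: trainer.compute_rewards = lambda prompts, responses: keyword_reward(responses)
def keyword_reward(responses: List[str], keywords=("because", "therefore", "example")) -> torch.Tensor:
    rewards = []
    for response in responses:
        text = response.lower()
        # Reward hitting the target keywords, capped at 1.0, with a small penalty
        # for empty or extremely short responses
        hits = sum(1 for kw in keywords if kw in text)
        reward = min(hits / len(keywords), 1.0)
        if len(text.split()) < 3:
            reward -= 0.5
        rewards.append(reward)
    return torch.tensor(rewards)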
# Create Gradio interface
with gr.Blocks(title="GRPO Model Training") as app:
    gr.Markdown("# 🚀 GRPO (Group Relative Policy Optimization) Training App")
    gr.Markdown("Train language models with the GRPO technique using this simple interface")

    with gr.Tab("🔧 Model Setup"):
        with gr.Row():
            model_input = gr.Textbox(
                label="Model Name",
                value="Qwen/Qwen2.5-0.5B-Instruct",
                placeholder="Enter HuggingFace model name (e.g., Palmyra, Qwen, Llama)"
            )
            load_btn = gr.Button("Load Model", variant="primary")
        model_status = gr.Textbox(label="Status", lines=2)
        load_btn.click(load_model_interface, inputs=model_input, outputs=model_status)

    with gr.Tab("🎯 Training"):
        with gr.Row():
            with gr.Column():
                prompts_input = gr.Textbox(
                    label="Training Prompts (one per line)",
                    lines=5,
                    value="Tell me about artificial intelligence\nExplain quantum computing\nWhat is machine learning?",
                    placeholder="Enter your prompts here..."
                )
            with gr.Column():
                num_steps_input = gr.Slider(
                    label="Number of Training Steps",
                    minimum=1,
                    maximum=100,
                    value=10,
                    step=1
                )
                lr_input = gr.Number(
                    label="Learning Rate",
                    value=1e-5,
                    step=1e-6
                )
                beta_input = gr.Number(
                    label="KL Penalty Weight (β)",
                    value=0.1,
                    step=0.01
                )
        train_btn = gr.Button("Start Training", variant="primary")
        training_output = gr.Textbox(label="Training Progress", lines=10)
        train_btn.click(
            train_interface,
            inputs=[prompts_input, num_steps_input, lr_input, beta_input],
            outputs=training_output
        )

    with gr.Tab("💬 Generation"):
        with gr.Row():
            with gr.Column():
                gen_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    value="Tell me about"
                )
                max_length = gr.Slider(
                    label="Max Length",
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10
                )
                temp_slider = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=2.0,
                    value=0.8,
                    step=0.1
                )
            with gr.Column():
                gen_btn = gr.Button("Generate", variant="primary")
                gen_output = gr.Textbox(label="Generated Response", lines=10)
        gen_btn.click(
            generate_interface,
            inputs=[gen_prompt, max_length, temp_slider],
            outputs=gen_output
        )

    with gr.Tab("💾 Save Model"):
        save_path_input = gr.Textbox(
            label="Save Path",
            value="./grpo_trained_model",
            placeholder="Enter path to save the model"
        )
        save_btn = gr.Button("Save Model", variant="primary")
        save_status = gr.Textbox(label="Save Status")
        save_btn.click(save_model_interface, inputs=save_path_input, outputs=save_status)

    with gr.Tab("📊 Training History"):
        history_btn = gr.Button("Refresh History", variant="secondary")
        history_output = gr.Textbox(label="Training History", lines=15)
        history_btn.click(get_training_history, outputs=history_output)

    gr.Markdown("""
    ## 📝 Instructions:
    1. **Load Model**: Start by loading a pre-trained model from HuggingFace
    2. **Training**: Add your prompts and configure training parameters
    3. **Generation**: Test your trained model with custom prompts
    4. **Save**: Save your fine-tuned model for later use

    ## ⚠️ Note:
    - This is a simplified GRPO implementation for demonstration
    - For production use, consider more sophisticated reward functions
    - GPU recommended for larger models
    """)


# Launch the app
if __name__ == "__main__":
    app.launch(share=True)
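

# --- Optional headless usage (illustrative, left commented out) ---
# The lines below sketch how GRPOTrainer could be exercised without the Gradio UI,
# e.g. from a notebook or a quick smoke test. The model name is an assumption (any
# small causal LM on the Hub works). They stay commented out because app.launch()
# above blocks, and loading a model at import time would be an unwanted side effect.
#
# trainer.load_model("sshleifer/tiny-gpt2")
# print(trainer.train(["Tell me about artificial intelligence"], num_steps=2, lr=1e-5, beta=0.1))
# print(trainer.generate_response("Tell me about artificial intelligence", max_length=40))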