| """ | |
| PyTorch Lightning module for Multimodal Gemma training | |
| """ | |
| import torch | |
| import lightning as L | |
| from typing import Dict, Any, Optional, List | |
| from transformers import get_linear_schedule_with_warmup | |
| import logging | |
| from .multimodal_gemma import MultimodalGemma | |
| logger = logging.getLogger(__name__) | |
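
# Expected config layout, inferred from the keys this module reads; the values
# shown are illustrative placeholders, not defaults from the source (except
# where the code itself supplies a fallback):
#
#   {
#       "training": {
#           "projector_lr": 1e-4,
#           "lora_lr": 2e-5,
#           "weight_decay": 0.01,
#           "max_epochs": 3,
#           "warmup_ratio": 0.03,
#           "accumulate_grad_batches": 1,   # optional, code defaults to 1
#           "steps_per_epoch": 1000,        # optional fallback when no datamodule
#       },
#       "optimization": {
#           "use_fused_adamw": False,       # optional
#       },
#   }
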
class MultimodalGemmaLightning(L.LightningModule):
    """Lightning module for Multimodal Gemma training"""

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.save_hyperparameters()
        self.config = config

        # Initialize model
        self.model = MultimodalGemma(config)

        # Training metrics tracking
        self.training_step_outputs = []
        self.validation_step_outputs = []

        # Use automatic optimization (the Lightning default)
        self.automatic_optimization = True

        logger.info("MultimodalGemmaLightning initialized")

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Forward pass"""
        return self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            images=batch.get("images"),
            labels=batch["labels"],
        )
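
    # Batch layout consumed by forward() above, inferred from the keys accessed
    # (shapes are assumptions; the source does not assert them):
    #   input_ids:      LongTensor [batch, seq_len]
    #   attention_mask: LongTensor [batch, seq_len]
    #   images:         optional image tensor, e.g. [batch, 3, H, W]
    #   labels:         LongTensor [batch, seq_len]; -100 for ignored positions
    #                   is the usual Hugging Face convention, assumed here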

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Training step"""
        outputs = self(batch)
        loss = outputs["loss"]

        # Log metrics
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("train/learning_rate", self.optimizers().param_groups[0]["lr"], on_step=True)

        # Store outputs for epoch end
        self.training_step_outputs.append(loss.detach())

        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Validation step"""
        outputs = self(batch)
        loss = outputs["loss"]

        # Log metrics
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        # Store outputs for epoch end
        self.validation_step_outputs.append(loss.detach())

        return loss

    def on_train_epoch_end(self) -> None:
        """Called at the end of each training epoch"""
        if self.training_step_outputs:
            avg_loss = torch.stack(self.training_step_outputs).mean()
            self.log("train/epoch_loss", avg_loss, prog_bar=False, sync_dist=True)
            self.training_step_outputs.clear()

    def on_validation_epoch_end(self) -> None:
        """Called at the end of each validation epoch"""
        if self.validation_step_outputs:
            avg_loss = torch.stack(self.validation_step_outputs).mean()
            self.log("val/epoch_loss", avg_loss, prog_bar=False, sync_dist=True)
            self.validation_step_outputs.clear()
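
    # Note: self.log(..., on_epoch=True) in the step methods already produces
    # epoch-level aggregates; the hooks above additionally log an unweighted
    # mean of the per-step losses under train/epoch_loss and val/epoch_loss.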

    def configure_optimizers(self):
        """Configure optimizer and scheduler"""
        # Collect trainable parameters with different learning rates
        param_groups = []

        # Ensure learning rates are floats
        projector_lr = float(self.config["training"]["projector_lr"])
        lora_lr = float(self.config["training"]["lora_lr"])

        # Vision projector parameters
        vision_proj_params = list(self.model.vision_projector.parameters())
        if vision_proj_params:
            param_groups.append({
                "params": vision_proj_params,
                "lr": projector_lr,
                "name": "vision_projector",
            })

        # Audio projector parameters (if enabled)
        if hasattr(self.model, "audio_projector"):
            audio_proj_params = list(self.model.audio_projector.parameters())
            if audio_proj_params:
                param_groups.append({
                    "params": audio_proj_params,
                    "lr": projector_lr,
                    "name": "audio_projector",
                })

        # LoRA parameters from the language model
        lora_params = [
            param
            for _name, param in self.model.language_model.named_parameters()
            if param.requires_grad
        ]
        if lora_params:
            param_groups.append({
                "params": lora_params,
                "lr": lora_lr,
                "name": "lora_adapters",
            })

        if not param_groups:
            raise ValueError("No trainable parameters found!")

        # Log parameter counts
        for group in param_groups:
            param_count = sum(p.numel() for p in group["params"])
            logger.info(f"{group['name']}: {param_count:,} parameters, lr={group['lr']}")

        # Create optimizer; the fused CUDA AdamW kernel is opt-in via the
        # "fused" kwarg (it is not the default in current PyTorch)
        optimizer_kwargs: Dict[str, Any] = {
            "weight_decay": self.config["training"]["weight_decay"],
            "eps": 1e-8,
            "betas": (0.9, 0.999),
        }
        if self.config.get("optimization", {}).get("use_fused_adamw", False):
            if torch.cuda.is_available():
                optimizer_kwargs["fused"] = True
            else:
                logger.warning("Fused AdamW requires CUDA, using regular AdamW")

        optimizer = torch.optim.AdamW(param_groups, **optimizer_kwargs)

        # Calculate total steps for scheduler
        if self.trainer.datamodule is not None:
            steps_per_epoch = len(self.trainer.datamodule.train_dataloader())
        else:
            # Fallback estimation
            steps_per_epoch = self.config["training"].get("steps_per_epoch", 1000)

        max_epochs = self.config["training"]["max_epochs"]
        accumulate_grad_batches = self.config["training"].get("accumulate_grad_batches", 1)
        total_steps = (steps_per_epoch // accumulate_grad_batches) * max_epochs
        warmup_steps = int(total_steps * self.config["training"]["warmup_ratio"])

        logger.info(f"Scheduler setup: {total_steps} total steps, {warmup_steps} warmup steps")

        # Create scheduler
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
                "name": "learning_rate",
            },
        }
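
    # With interval="step" and frequency=1 above, Lightning calls
    # scheduler.step() after every optimizer step, and "name" labels the
    # schedule for callbacks such as LearningRateMonitor; the override below
    # therefore mirrors the default behavior.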

    def lr_scheduler_step(self, scheduler, metric):
        """Step the LR scheduler every optimizer step"""
        scheduler.step()

    def on_before_optimizer_step(self, optimizer):
        """Called before each optimizer step"""
        # Periodically log the global gradient L2 norm (square root of the
        # sum of squared per-parameter gradient norms)
        if self.global_step % 100 == 0:
            total_sq_norm = 0.0
            param_count = 0
            for param_group in optimizer.param_groups:
                for param in param_group["params"]:
                    if param.grad is not None:
                        total_sq_norm += param.grad.data.norm(2).item() ** 2
                        param_count += 1

            if param_count > 0:
                self.log("train/grad_norm", total_sq_norm ** 0.5, on_step=True, prog_bar=False)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Called when saving a checkpoint"""
        # Save additional model metadata alongside the Lightning state
        checkpoint["model_config"] = self.config
        checkpoint["tokenizer_vocab_size"] = len(self.model.tokenizer)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Called when loading a checkpoint"""
        # The stored config is detected and logged here, not re-applied
        if "model_config" in checkpoint:
            logger.info("Loaded model configuration from checkpoint")

    def predict_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
        """Prediction step for inference"""
        outputs = self.model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            images=batch.get("images"),
            max_new_tokens=150,
            temperature=0.7,
            do_sample=True,
        )

        # Decode generated text
        generated_text = []
        for i, output in enumerate(outputs):
            # Remove input tokens from the output before decoding
            input_length = batch["input_ids"][i].shape[0]
            generated_tokens = output[input_length:]
            text = self.model.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            generated_text.append(text)

        return {
            "generated_text": generated_text,
            "input_ids": batch["input_ids"],
        }
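

if __name__ == "__main__":
    # Minimal training sketch, not from the source repo. Run it as
    # `python -m <package>.<this_module>` so the relative import resolves.
    # It assumes a config.yaml shaped like the layout documented at the top
    # of this file; MultimodalDataModule is a hypothetical name used only
    # for illustration.
    import yaml

    with open("config.yaml") as f:
        config = yaml.safe_load(f)

    module = MultimodalGemmaLightning(config)
    trainer = L.Trainer(
        max_epochs=config["training"]["max_epochs"],
        accumulate_grad_batches=config["training"].get("accumulate_grad_batches", 1),
    )
    # trainer.fit(module, datamodule=MultimodalDataModule(config))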