"""
Advanced Training UI Components for Dressify

Provides comprehensive parameter controls for both ResNet and ViT training.

Training runs in a background daemon thread. Gradio components cannot be
updated by assigning ``.value`` from another thread, so progress is written to
a thread-safe module-level log instead; the UI pulls it via
``get_training_log`` (wired to the "Refresh Log" button by
``wire_training_events``).
"""

import gradio as gr
import os
import subprocess
import sys
import threading
import json
from typing import Dict, Any, List, Optional, Tuple

DEFAULT_DATASET_ROOT = "data/Polyvore"
DEFAULT_EXPORT_DIR = "models/exports"

# Component keys, in the exact positional order of start_advanced_training's
# parameters. Used to build the `inputs=` list so the two cannot drift apart.
ADVANCED_INPUT_KEYS = (
    "resnet_epochs", "resnet_batch_size", "resnet_lr", "resnet_optimizer",
    "resnet_weight_decay", "resnet_triplet_margin", "resnet_embedding_dim",
    "resnet_backbone", "resnet_use_pretrained", "resnet_dropout",
    "vit_epochs", "vit_batch_size", "vit_lr", "vit_optimizer",
    "vit_weight_decay", "vit_triplet_margin", "vit_embedding_dim",
    "vit_num_layers", "vit_num_heads", "vit_ff_multiplier", "vit_dropout",
    "use_mixed_precision", "channels_last", "gradient_clip", "warmup_epochs",
    "scheduler_type", "early_stopping_patience", "mining_strategy",
    "augmentation_level", "seed",
)


class _TrainingLog:
    """Text buffer shared between background training threads and UI callbacks."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._text = ""

    def reset(self, text: str = "") -> None:
        """Replace the whole log with *text*."""
        with self._lock:
            self._text = text

    def append(self, text: str) -> None:
        """Append *text* to the log."""
        with self._lock:
            self._text += text

    def get(self) -> str:
        """Return a snapshot of the log."""
        with self._lock:
            return self._text


_TRAIN_LOG = _TrainingLog()


def get_training_log() -> str:
    """Return the current training log text (intended as a UI refresh callback)."""
    return _TRAIN_LOG.get()


def _resolve_dataset_root(dataset_root: Optional[str]) -> str:
    """Return *dataset_root*, falling back to $POLYVORE_ROOT or the default path."""
    return dataset_root or os.getenv("POLYVORE_ROOT", DEFAULT_DATASET_ROOT)


def _prepare_export_dir() -> str:
    """Return the export directory ($EXPORT_DIR or default), creating it if missing."""
    export_dir = os.getenv("EXPORT_DIR", DEFAULT_EXPORT_DIR)
    os.makedirs(export_dir, exist_ok=True)
    return export_dir


def _python() -> str:
    """Interpreter used to launch training scripts: the one running this app."""
    # sys.executable can be empty in some embedded interpreters.
    return sys.executable or "python"


def _run_training_script(cmd: List[str], name: str) -> bool:
    """Run a training command to completion, log the outcome, return True on success."""
    result = subprocess.run(cmd, capture_output=True, text=True, check=False)
    if result.returncode == 0:
        _TRAIN_LOG.append(f"✅ {name} training completed successfully!\n\n")
        return True
    _TRAIN_LOG.append(f"❌ {name} training failed: {result.stderr}\n\n")
    return False


def create_advanced_training_interface() -> Tuple[gr.Blocks, Dict[str, Any]]:
    """Create the advanced training interface with all parameter controls.

    Returns:
        ``(training_interface, components)`` where ``components`` maps a
        name to each control/button/textbox. Event handlers are NOT attached
        here; call ``wire_training_events`` (or wire them yourself inside a
        Blocks context).
    """
    with gr.Blocks(title="Advanced Training Control") as training_interface:
        gr.Markdown(
            "## 🎯 Comprehensive Training Parameter Control\n"
            "Customize every aspect of model training for research and experimentation."
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### 🖼️ ResNet Item Embedder")

                # Model architecture
                resnet_backbone = gr.Dropdown(
                    choices=["resnet50", "resnet101"],
                    value="resnet50",
                    label="Backbone Architecture",
                )
                resnet_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension")
                resnet_use_pretrained = gr.Checkbox(value=True, label="Use ImageNet Pretrained")
                resnet_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate")

                # Training parameters
                resnet_epochs = gr.Slider(1, 100, value=20, step=1, label="Epochs")
                resnet_batch_size = gr.Slider(8, 128, value=64, step=8, label="Batch Size")
                resnet_lr = gr.Slider(1e-5, 1e-2, value=1e-3, step=1e-5, label="Learning Rate")
                resnet_optimizer = gr.Dropdown(
                    choices=["adamw", "adam", "sgd", "rmsprop"],
                    value="adamw",
                    label="Optimizer",
                )
                resnet_weight_decay = gr.Slider(1e-6, 1e-2, value=1e-4, step=1e-6, label="Weight Decay")
                resnet_triplet_margin = gr.Slider(0.1, 1.0, value=0.2, step=0.05, label="Triplet Margin")

            with gr.Column(scale=1):
                gr.Markdown("#### 🧠 ViT Outfit Encoder")

                # Model architecture
                vit_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension")
                vit_num_layers = gr.Slider(2, 12, value=6, step=1, label="Transformer Layers")
                vit_num_heads = gr.Slider(4, 16, value=8, step=2, label="Attention Heads")
                vit_ff_multiplier = gr.Slider(2, 8, value=4, step=1, label="Feed-Forward Multiplier")
                vit_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate")

                # Training parameters
                vit_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs")
                vit_batch_size = gr.Slider(4, 64, value=32, step=4, label="Batch Size")
                vit_lr = gr.Slider(1e-5, 1e-2, value=5e-4, step=1e-5, label="Learning Rate")
                vit_optimizer = gr.Dropdown(
                    choices=["adamw", "adam", "sgd", "rmsprop"],
                    value="adamw",
                    label="Optimizer",
                )
                vit_weight_decay = gr.Slider(1e-4, 1e-1, value=5e-2, step=1e-4, label="Weight Decay")
                vit_triplet_margin = gr.Slider(0.1, 1.0, value=0.3, step=0.05, label="Triplet Margin")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### ⚙️ Advanced Training Settings")

                # Hardware optimization
                use_mixed_precision = gr.Checkbox(value=True, label="Mixed Precision (AMP)")
                channels_last = gr.Checkbox(value=True, label="Channels Last Memory Format")
                gradient_clip = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="Gradient Clipping")

                # Learning rate scheduling
                warmup_epochs = gr.Slider(0, 10, value=3, step=1, label="Warmup Epochs")
                scheduler_type = gr.Dropdown(
                    choices=["cosine", "step", "plateau", "linear"],
                    value="cosine",
                    label="Learning Rate Scheduler",
                )
                early_stopping_patience = gr.Slider(5, 20, value=10, step=1, label="Early Stopping Patience")

                # Training strategy
                mining_strategy = gr.Dropdown(
                    choices=["semi_hard", "hardest", "random"],
                    value="semi_hard",
                    label="Triplet Mining Strategy",
                )
                augmentation_level = gr.Dropdown(
                    choices=["minimal", "standard", "aggressive"],
                    value="standard",
                    label="Data Augmentation Level",
                )
                seed = gr.Slider(0, 9999, value=42, step=1, label="Random Seed")

            with gr.Column(scale=1):
                gr.Markdown("#### 🚀 Training Control")

                # Quick training: only the two epoch counts below are used.
                gr.Markdown("**Quick Training (Basic Parameters)**")
                epochs_res = gr.Slider(1, 50, value=10, step=1, label="ResNet epochs")
                epochs_vit = gr.Slider(1, 100, value=20, step=1, label="ViT epochs")
                start_btn = gr.Button("🚀 Start Quick Training", variant="secondary")

                # Advanced training: uses every control above.
                gr.Markdown("**Advanced Training (Custom Parameters)**")
                start_advanced_btn = gr.Button("🎯 Start Advanced Training", variant="primary")

                # Training log: filled by the background thread, pulled on refresh.
                train_log = gr.Textbox(label="Training Log", lines=15, max_lines=20)
                refresh_log_btn = gr.Button("🔄 Refresh Log", variant="secondary")

                # Status
                gr.Markdown("**Training Status**")
                training_status = gr.Textbox(label="Status", value="Ready to train", interactive=False)

    return training_interface, {
        'resnet_backbone': resnet_backbone,
        'resnet_embedding_dim': resnet_embedding_dim,
        'resnet_use_pretrained': resnet_use_pretrained,
        'resnet_dropout': resnet_dropout,
        'resnet_epochs': resnet_epochs,
        'resnet_batch_size': resnet_batch_size,
        'resnet_lr': resnet_lr,
        'resnet_optimizer': resnet_optimizer,
        'resnet_weight_decay': resnet_weight_decay,
        'resnet_triplet_margin': resnet_triplet_margin,
        'vit_embedding_dim': vit_embedding_dim,
        'vit_num_layers': vit_num_layers,
        'vit_num_heads': vit_num_heads,
        'vit_ff_multiplier': vit_ff_multiplier,
        'vit_dropout': vit_dropout,
        'vit_epochs': vit_epochs,
        'vit_batch_size': vit_batch_size,
        'vit_lr': vit_lr,
        'vit_optimizer': vit_optimizer,
        'vit_weight_decay': vit_weight_decay,
        'vit_triplet_margin': vit_triplet_margin,
        'use_mixed_precision': use_mixed_precision,
        'channels_last': channels_last,
        'gradient_clip': gradient_clip,
        'warmup_epochs': warmup_epochs,
        'scheduler_type': scheduler_type,
        'early_stopping_patience': early_stopping_patience,
        'mining_strategy': mining_strategy,
        'augmentation_level': augmentation_level,
        'seed': seed,
        'epochs_res': epochs_res,
        'epochs_vit': epochs_vit,
        'start_btn': start_btn,
        'start_advanced_btn': start_advanced_btn,
        'train_log': train_log,
        'refresh_log_btn': refresh_log_btn,
        'training_status': training_status,
    }


def start_advanced_training(
    # ResNet parameters
    resnet_epochs: int,
    resnet_batch_size: int,
    resnet_lr: float,
    resnet_optimizer: str,
    resnet_weight_decay: float,
    resnet_triplet_margin: float,
    resnet_embedding_dim: int,
    resnet_backbone: str,
    resnet_use_pretrained: bool,
    resnet_dropout: float,
    # ViT parameters
    vit_epochs: int,
    vit_batch_size: int,
    vit_lr: float,
    vit_optimizer: str,
    vit_weight_decay: float,
    vit_triplet_margin: float,
    vit_embedding_dim: int,
    vit_num_layers: int,
    vit_num_heads: int,
    vit_ff_multiplier: int,
    vit_dropout: float,
    # Advanced parameters
    use_mixed_precision: bool,
    channels_last: bool,
    gradient_clip: float,
    warmup_epochs: int,
    scheduler_type: str,
    early_stopping_patience: int,
    mining_strategy: str,
    augmentation_level: str,
    seed: int,
    dataset_root: Optional[str] = None,
) -> str:
    """Start advanced training with custom parameters in a background thread.

    Writes ``resnet_config_custom.json`` and ``vit_config_custom.json`` to the
    export directory, then runs ``train_resnet.py`` followed by
    ``train_vit_triplet.py``. ViT training is skipped if ResNet training
    fails. Only epochs, batch size, learning rate, weight decay, triplet
    margin, embedding dim (and a non-default ResNet backbone) are passed on
    the command line; all other settings are recorded in the config files only.

    Args:
        One argument per UI control, in ``ADVANCED_INPUT_KEYS`` order.
        dataset_root: Dataset directory; defaults to $POLYVORE_ROOT or
            ``data/Polyvore``.

    Returns:
        A status message for the UI. Progress goes to the shared log
        (see ``get_training_log``).
    """
    dataset_root = _resolve_dataset_root(dataset_root)
    if not os.path.exists(dataset_root):
        return "❌ Dataset not ready. Please wait for bootstrap to complete."

    # Gradio sliders deliver numbers as floats (e.g. 20.0); integer CLI flags
    # would reject "20.0", so normalise integer-valued settings up front.
    resnet_epochs = int(resnet_epochs)
    resnet_batch_size = int(resnet_batch_size)
    resnet_embedding_dim = int(resnet_embedding_dim)
    vit_epochs = int(vit_epochs)
    vit_batch_size = int(vit_batch_size)
    vit_embedding_dim = int(vit_embedding_dim)
    vit_num_layers = int(vit_num_layers)
    vit_num_heads = int(vit_num_heads)
    vit_ff_multiplier = int(vit_ff_multiplier)
    warmup_epochs = int(warmup_epochs)
    early_stopping_patience = int(early_stopping_patience)
    seed = int(seed)

    resnet_config = {
        "model": {
            "backbone": resnet_backbone,
            "embedding_dim": resnet_embedding_dim,
            "pretrained": resnet_use_pretrained,
            "dropout": resnet_dropout,
        },
        "training": {
            "batch_size": resnet_batch_size,
            "epochs": resnet_epochs,
            "lr": resnet_lr,
            "weight_decay": resnet_weight_decay,
            "triplet_margin": resnet_triplet_margin,
            "optimizer": resnet_optimizer,
            "scheduler": scheduler_type,
            "warmup_epochs": warmup_epochs,
            "early_stopping_patience": early_stopping_patience,
            "use_amp": use_mixed_precision,
            "channels_last": channels_last,
            "gradient_clip": gradient_clip,
        },
        "data": {
            "image_size": 224,
            "augmentation_level": augmentation_level,
        },
        "advanced": {
            "mining_strategy": mining_strategy,
            "seed": seed,
        },
    }

    vit_config = {
        "model": {
            "embedding_dim": vit_embedding_dim,
            "num_layers": vit_num_layers,
            "num_heads": vit_num_heads,
            "ff_multiplier": vit_ff_multiplier,
            "dropout": vit_dropout,
        },
        "training": {
            "batch_size": vit_batch_size,
            "epochs": vit_epochs,
            "lr": vit_lr,
            "weight_decay": vit_weight_decay,
            "triplet_margin": vit_triplet_margin,
            "optimizer": vit_optimizer,
            "scheduler": scheduler_type,
            "warmup_epochs": warmup_epochs,
            "early_stopping_patience": early_stopping_patience,
            "use_amp": use_mixed_precision,
        },
        "advanced": {
            "mining_strategy": mining_strategy,
            "seed": seed,
        },
    }

    def _runner() -> None:
        try:
            export_dir = _prepare_export_dir()

            # Save configs
            with open(os.path.join(export_dir, "resnet_config_custom.json"), "w", encoding="utf-8") as f:
                json.dump(resnet_config, f, indent=2)
            with open(os.path.join(export_dir, "vit_config_custom.json"), "w", encoding="utf-8") as f:
                json.dump(vit_config, f, indent=2)

            # Train ResNet with custom parameters
            _TRAIN_LOG.append(
                "🚀 Starting ResNet training with custom parameters...\n"
                f"Backbone: {resnet_backbone}, Embedding Dim: {resnet_embedding_dim}\n"
                f"Epochs: {resnet_epochs}, Batch Size: {resnet_batch_size}, LR: {resnet_lr}\n"
                f"Optimizer: {resnet_optimizer}, Triplet Margin: {resnet_triplet_margin}\n"
            )
            resnet_cmd = [
                _python(), "train_resnet.py",
                "--data_root", dataset_root,
                "--epochs", str(resnet_epochs),
                "--batch_size", str(resnet_batch_size),
                "--lr", str(resnet_lr),
                "--weight_decay", str(resnet_weight_decay),
                "--triplet_margin", str(resnet_triplet_margin),
                "--embedding_dim", str(resnet_embedding_dim),
                "--out", os.path.join(export_dir, "resnet_item_embedder_custom.pth"),
            ]
            if resnet_backbone != "resnet50":
                resnet_cmd.extend(["--backbone", resnet_backbone])
            if not _run_training_script(resnet_cmd, "ResNet"):
                return

            # Train ViT with custom parameters
            _TRAIN_LOG.append(
                "🚀 Starting ViT training with custom parameters...\n"
                f"Layers: {vit_num_layers}, Heads: {vit_num_heads}, FF Multiplier: {vit_ff_multiplier}\n"
                f"Epochs: {vit_epochs}, Batch Size: {vit_batch_size}, LR: {vit_lr}\n"
                f"Optimizer: {vit_optimizer}, Triplet Margin: {vit_triplet_margin}\n"
            )
            vit_cmd = [
                _python(), "train_vit_triplet.py",
                "--data_root", dataset_root,
                "--epochs", str(vit_epochs),
                "--batch_size", str(vit_batch_size),
                "--lr", str(vit_lr),
                "--weight_decay", str(vit_weight_decay),
                "--triplet_margin", str(vit_triplet_margin),
                "--embedding_dim", str(vit_embedding_dim),
                "--export", os.path.join(export_dir, "vit_outfit_model_custom.pth"),
            ]
            if not _run_training_script(vit_cmd, "ViT"):
                return

            # Nothing is reloaded here; the main app owns the inference service.
            _TRAIN_LOG.append(
                f"🎉 All training completed! Models saved to {export_dir}/\n"
                "🔄 Reload the models in the main app (e.g. service.reload_models()) "
                "before using them for inference.\n"
            )
        except Exception as e:
            # Top-level boundary of the worker thread: surface the error in the UI log.
            _TRAIN_LOG.append(f"\n❌ Training error: {e}")

    # Clear the previous run's log before the new thread starts writing.
    _TRAIN_LOG.reset("")
    threading.Thread(target=_runner, daemon=True).start()
    return "🚀 Advanced training started with custom parameters! Click 🔄 Refresh Log to see progress."


def start_simple_training(res_epochs: int, vit_epochs: int, dataset_root: Optional[str] = None) -> str:
    """Start ResNet then ViT training with default parameters in a background thread.

    Args:
        res_epochs: Number of ResNet epochs.
        vit_epochs: Number of ViT epochs.
        dataset_root: Dataset directory; defaults to $POLYVORE_ROOT or
            ``data/Polyvore``.

    Returns:
        ``"Started"`` when the thread was launched, otherwise an error message.
        Progress goes to the shared log (see ``get_training_log``).
    """
    dataset_root = _resolve_dataset_root(dataset_root)
    if not os.path.exists(dataset_root):
        _TRAIN_LOG.reset("Dataset not ready.")
        return "Dataset not ready."

    # Slider values arrive as floats; the scripts expect integer epochs.
    res_epochs = int(res_epochs)
    vit_epochs = int(vit_epochs)

    def _runner() -> None:
        try:
            export_dir = _prepare_export_dir()

            _TRAIN_LOG.reset("Training ResNet…\n")
            # Output is not captured so the scripts' progress stays visible on the console.
            resnet_result = subprocess.run([
                _python(), "train_resnet.py",
                "--data_root", dataset_root,
                "--epochs", str(res_epochs),
                "--out", os.path.join(export_dir, "resnet_item_embedder.pth"),
            ], check=False)
            if resnet_result.returncode != 0:
                _TRAIN_LOG.append(
                    f"\nResNet training failed (exit code {resnet_result.returncode}); skipping ViT."
                )
                return

            _TRAIN_LOG.append("\nTraining ViT (triplet)…\n")
            vit_result = subprocess.run([
                _python(), "train_vit_triplet.py",
                "--data_root", dataset_root,
                "--epochs", str(vit_epochs),
                "--export", os.path.join(export_dir, "vit_outfit_model.pth"),
            ], check=False)
            if vit_result.returncode != 0:
                _TRAIN_LOG.append(f"\nViT training failed (exit code {vit_result.returncode}).")
                return

            _TRAIN_LOG.append(f"\nDone. Artifacts in {export_dir}.")
        except Exception as e:
            # Top-level boundary of the worker thread: surface the error in the UI log.
            _TRAIN_LOG.append(f"\nError: {e}")

    threading.Thread(target=_runner, daemon=True).start()
    return "Started"


def wire_training_events(interface: gr.Blocks, components: Dict[str, Any]) -> None:
    """Attach the training callbacks to components from ``create_advanced_training_interface``.

    Gradio only accepts event listeners inside a Blocks context, so the
    interface is re-entered here.
    """
    with interface:
        components['start_btn'].click(
            fn=start_simple_training,
            inputs=[components['epochs_res'], components['epochs_vit']],
            outputs=components['training_status'],
        )
        components['start_advanced_btn'].click(
            fn=start_advanced_training,
            inputs=[components[key] for key in ADVANCED_INPUT_KEYS],
            outputs=components['training_status'],
        )
        components['refresh_log_btn'].click(
            fn=get_training_log,
            inputs=[],
            outputs=components['train_log'],
        )


# Example usage
if __name__ == "__main__":
    interface, components = create_advanced_training_interface()
    wire_training_events(interface, components)
    interface.launch()