import os
import base64
import io
from typing import List, Optional, Any, Dict

import gradio as gr
import numpy as np
import requests
import torch
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel
from PIL import Image
from starlette.staticfiles import StaticFiles
import threading
import json

from inference import InferenceService
from utils.data_fetch import ensure_dataset_ready

# Global state
BOOT_STATUS = "starting"
DATASET_ROOT: Optional[str] = None


def get_artifact_overview():
    """Get comprehensive artifact overview."""
    try:
        from utils.artifact_manager import create_artifact_manager
        manager = create_artifact_manager()
        return manager.get_artifact_summary()
    except Exception as e:
        return {"error": str(e)}


def export_artifact_summary():
    """Export artifact summary as JSON file."""
    try:
        from utils.artifact_manager import create_artifact_manager
        manager = create_artifact_manager()
        summary = manager.get_artifact_summary()

        # Save to exports directory
        export_dir = os.getenv("EXPORT_DIR", "models/exports")
        os.makedirs(export_dir, exist_ok=True)
        summary_path = os.path.join(export_dir, "system_summary.json")
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)
        return summary_path
    except Exception as e:
        return None


def create_download_package(package_type: str):
    """Create a downloadable package."""
    try:
        from utils.artifact_manager import create_artifact_manager
        manager = create_artifact_manager()

        # Extract package type from the dropdown choice
        if "complete" in package_type:
            pkg_type = "complete"
        elif "splits_only" in package_type:
            pkg_type = "splits_only"
        elif "models_only" in package_type:
            pkg_type = "models_only"
        else:
            return f"❌ Invalid package type: {package_type}", get_available_packages()

        package_path = manager.create_download_package(pkg_type)
        package_name = os.path.basename(package_path)
        return f"✅ Package created: {package_name}", get_available_packages()
    except Exception as e:
        return f"❌ Failed to create package: {e}", get_available_packages()


def get_available_packages():
    """Get list of available packages."""
    try:
        export_dir = os.getenv("EXPORT_DIR", "models/exports")
        packages = []
        if os.path.exists(export_dir):
            for file in os.listdir(export_dir):
                if file.endswith((".tar.gz", ".zip")):
                    file_path = os.path.join(export_dir, file)
                    packages.append({
                        "name": file,
                        "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
                        "path": file_path,
                        "url": f"/files/{file}"
                    })
        return {"packages": packages}
    except Exception as e:
        return {"error": str(e)}


def get_individual_files():
    """Get list of individual downloadable files."""
    try:
        from utils.artifact_manager import create_artifact_manager
        manager = create_artifact_manager()
        files = manager.get_downloadable_files()

        # Group by category
        categorized = {}
        for file in files:
            category = file["category"]
            if category not in categorized:
                categorized[category] = []
            categorized[category].append(file)
        return categorized
    except Exception as e:
        return {"error": str(e)}


def download_all_files():
    """Download all files as a ZIP archive."""
    try:
        from utils.artifact_manager import create_artifact_manager
        manager = create_artifact_manager()
        files = manager.get_downloadable_files()

        # Create ZIP with all files
        export_dir = os.getenv("EXPORT_DIR", "models/exports")
        os.makedirs(export_dir, exist_ok=True)
        zip_path = os.path.join(export_dir, "all_artifacts.zip")

        import zipfile
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for file in files:
                if os.path.exists(file["path"]):
                    zipf.write(file["path"], file["name"])
        return zip_path
    except Exception as e:
        return None
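# The "/files/<name>" URLs emitted by get_available_packages() (and by the /artifacts
# endpoint below) resolve because the export directory is mounted as a StaticFiles
# route near the bottom of this module, so downloads work from the UI and from plain
# HTTP clients alike.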
def get_training_status():
    """Get current training status from the monitor."""
    try:
        from training_monitor import create_monitor
        monitor = create_monitor()
        status = monitor.get_status()
        return status if status else {"status": "no-training"}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def push_splits_to_hf(token, username):
    """Push splits to HF Hub."""
    if not token or not username:
        return "❌ Please provide HF token and username"
    try:
        from utils.hf_utils import HFModelManager
        hf = HFModelManager(token=token, username=username)
        result = hf.upload_model("splits", "Dressify-Helper")
        if result.get("success"):
            return f"✅ Successfully uploaded splits to {username}/Dressify-Helper"
        else:
            return f"❌ Failed to upload splits: {result.get('error', 'Unknown error')}"
    except Exception as e:
        return f"❌ Upload failed: {e}"


def push_models_to_hf(token, username):
    """Push models to HF Hub."""
    if not token or not username:
        return "❌ Please provide HF token and username"
    try:
        from utils.hf_utils import HFModelManager
        hf = HFModelManager(token=token, username=username)
        result = hf.upload_model("models", "dressify-models")
        if result.get("success"):
            return f"✅ Successfully uploaded models to {username}/dressify-models"
        else:
            return f"❌ Failed to upload models: {result.get('error', 'Unknown error')}"
    except Exception as e:
        return f"❌ Upload failed: {e}"


def push_everything_to_hf(token, username):
    """Push everything to HF Hub."""
    if not token or not username:
        return "❌ Please provide HF token and username"
    try:
        from utils.hf_utils import HFModelManager
        hf = HFModelManager(token=token, username=username)
        result = hf.upload_model("everything", "dressify-complete")
        if result.get("success"):
            return f"✅ Successfully uploaded everything to HF Hub"
        else:
            return f"❌ Failed to upload everything: {result.get('error', 'Unknown error')}"
    except Exception as e:
        return f"❌ Upload failed: {e}"


AI_API_KEY = os.getenv("AI_API_KEY")


def require_api_key(x_api_key: Optional[str]):
    if AI_API_KEY and x_api_key != AI_API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")


class EmbedRequest(BaseModel):
    image_urls: Optional[List[str]] = None
    images_base64: Optional[List[str]] = None


class Item(BaseModel):
    id: str
    embedding: Optional[List[float]] = None
    category: Optional[str] = None
    image_url: Optional[str] = None


class ComposeRequest(BaseModel):
    items: List[Item]
    context: Optional[Dict[str, Any]] = None


app = FastAPI(title="Dressify Recommendation Service")
service = InferenceService()

# Non-blocking bootstrap: fetch data, prepare splits, and train if needed in background
BOOT_STATUS = "idle"
DATASET_ROOT: Optional[str] = None


def _background_bootstrap():
    global BOOT_STATUS
    global DATASET_ROOT
    try:
        BOOT_STATUS = "preparing-dataset"
        ds_root = ensure_dataset_ready()
        DATASET_ROOT = ds_root
        if not ds_root:
            BOOT_STATUS = "dataset-not-prepared"
            return

        # Prepare splits from official data if missing
        splits_dir = os.path.join(ds_root, "splits")
        need_prepare = not (
            os.path.isfile(os.path.join(splits_dir, "train.json"))
            or os.path.isfile(os.path.join(splits_dir, "outfit_triplets_train.json"))
        )
        if need_prepare:
            BOOT_STATUS = "creating-splits"
            os.makedirs(splits_dir, exist_ok=True)
            from scripts.prepare_polyvore import main as prepare_main
            os.environ.setdefault("PYTHONWARNINGS", "ignore")
            import sys
            argv_bak = sys.argv
            try:
                # Use official splits from nondisjoint/ and disjoint/ folders with default size limit (500 samples for faster training)
                sys.argv = ["prepare_polyvore.py", "--root", ds_root, "--max_samples", "500"]
                prepare_main()
            finally:
                sys.argv = argv_bak

        # Train if checkpoints are absent
        export_dir = os.getenv("EXPORT_DIR", "models/exports")
        os.makedirs(export_dir, exist_ok=True)
        resnet_ckpt = os.path.join(export_dir, "resnet_item_embedder_best.pth")
        vit_ckpt = os.path.join(export_dir, "vit_outfit_model_best.pth")
        import subprocess
        if not os.path.exists(resnet_ckpt):
            BOOT_STATUS = "training-resnet"
            subprocess.run([
                "python", "train_resnet.py",
                "--data_root", ds_root,
                "--epochs", "3",
                "--batch_size", "4",
                "--lr", "1e-3",
                "--early_stopping_patience", "3",
                "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
            ], check=False)
        if not os.path.exists(vit_ckpt):
            BOOT_STATUS = "training-vit"
            subprocess.run([
                "python", "train_vit_triplet.py",
                "--data_root", ds_root,
                "--epochs", "10",
                "--batch_size", "4",
                "--lr", "5e-4",
                "--early_stopping_patience", "5",
                "--max_samples", "5000",
                "--triplet_margin", "0.5",
                "--gradient_clip", "1.0",
                "--warmup_epochs", "2",
                "--export", os.path.join(export_dir, "vit_outfit_model.pth")
            ], check=False)

        service.reload_models()
        BOOT_STATUS = "ready"
    except Exception as e:
        BOOT_STATUS = f"error: {e}"


threading.Thread(target=_background_bootstrap, daemon=True).start()
@app.get("/health")
def health() -> dict:
    return {"status": "ok", "device": service.device, "resnet": service.resnet_version, "vit": service.vit_version}


@app.get("/model-status")
def model_status() -> dict:
    """Get detailed model loading status."""
    return service.get_model_status()


@app.post("/reload-models")
def reload_models() -> dict:
    """Force reload models - useful for debugging."""
    try:
        service.force_reload_models()
        return {"status": "success", "message": "Models reloaded successfully"}
    except Exception as e:
        return {"status": "error", "message": str(e)}


@app.post("/test-recommend")
def test_recommend() -> dict:
    """Test recommendation with dummy data to debug the issue."""
    try:
        # Create dummy items for testing
        dummy_items = [
            {"id": "test_1", "image": None, "category": "shirt"},
            {"id": "test_2", "image": None, "category": "pants"},
            {"id": "test_3", "image": None, "category": "shoes"}
        ]

        # Try to get recommendations
        result = service.compose_outfits(dummy_items, {"num_outfits": 1})
        return {
            "status": "success",
            "model_status": service.get_model_status(),
            "result": result,
            "result_length": len(result) if result else 0
        }
    except Exception as e:
        return {"status": "error", "message": str(e), "model_status": service.get_model_status()}


@app.post("/embed")
def embed(req: EmbedRequest, x_api_key: Optional[str] = Header(None)) -> dict:
    require_api_key(x_api_key)
    images: List[Image.Image] = []
    if req.image_urls:
        for url in req.image_urls:
            resp = requests.get(url, timeout=20)
            resp.raise_for_status()
            images.append(Image.open(io.BytesIO(resp.content)).convert("RGB"))
    if req.images_base64:
        for b64 in req.images_base64:
            images.append(Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB"))
    if not images:
        raise HTTPException(status_code=400, detail="No images provided")
    embs = service.embed_images(images)
    return {"embeddings": [e.tolist() for e in embs], "model_version": service.resnet_version}


@app.post("/compose")
def compose(req: ComposeRequest, x_api_key: Optional[str] = Header(None)) -> dict:
    require_api_key(x_api_key)
    items = [
        {
            "id": it.id,
            "embedding": np.array(it.embedding, dtype=np.float32) if it.embedding is not None else None,
            "category": it.category,
            "image_url": it.image_url,
        }
        for it in req.items
    ]
    outfits = service.compose_outfits(items, context=req.context or {})
    return {"outfits": outfits, "version": service.vit_version}


@app.get("/artifacts")
def artifacts() -> dict:
    # list exported model artifacts for download
    export_dir = os.getenv("EXPORT_DIR", "models/exports")
    files = []
    if os.path.isdir(export_dir):
        for fn in os.listdir(export_dir):
            if fn.endswith((".pth", ".pt", ".onnx", ".ts", ".json")):
                files.append({
                    "name": fn,
                    "path": f"{export_dir}/{fn}",
                    "url": f"/files/{fn}",
                })
    return {"artifacts": files}
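# Minimal client sketch for the JSON endpoints above (assumptions: the service is
# reachable at http://localhost:7860 and AI_API_KEY, when set, is sent as the
# "x-api-key" header; adjust both for your deployment):
#
#   import requests
#   base, headers = "http://localhost:7860", {"x-api-key": "<AI_API_KEY>"}
#   emb = requests.post(f"{base}/embed",
#                       json={"image_urls": ["https://example.com/shirt.jpg"]},
#                       headers=headers).json()
#   outfits = requests.post(f"{base}/compose",
#                           json={"items": [{"id": "item_0",
#                                            "embedding": emb["embeddings"][0]}],
#                                 "context": {"occasion": "casual", "num_outfits": 3}},
#                           headers=headers).json()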
# --------- Gradio UI ---------
def _load_images_from_files(files: List[str]) -> List[Image.Image]:
    images: List[Image.Image] = []
    for fp in files:
        try:
            with Image.open(fp) as im:
                images.append(im.convert("RGB"))
        except Exception:
            continue
    return images


def gradio_embed(files: List[str]):
    if not files:
        return "[]"
    images = _load_images_from_files(files)
    if not images:
        return "[]"
    embs = service.embed_images(images)
    return str([e.tolist() for e in embs])


def _stitch_strip(imgs: List[Image.Image], height: int = 256, pad: int = 6, bg=(245, 245, 245)) -> Image.Image:
    if not imgs:
        return Image.new("RGB", (1, height), color=bg)
    resized = []
    for im in imgs:
        if im.mode != "RGB":
            im = im.convert("RGB")
        w, h = im.size
        scale = height / float(h)
        nw = max(1, int(w * scale))
        resized.append(im.resize((nw, height)))
    total_w = sum(im.size[0] for im in resized) + pad * (len(resized) + 1)
    out = Image.new("RGB", (total_w, height + 2 * pad), color=bg)
    x = pad
    for im in resized:
        out.paste(im, (x, pad))
        x += im.size[0] + pad
    return out


def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits: int, outfit_style: str = "casual"):
    # Check model status first
    model_status = service.get_model_status()
    if not model_status["can_recommend"]:
        error_msg = "❌ Models not ready for recommendations!\n\n"
        error_msg += "**Model Status:**\n"
        error_msg += f"- ResNet: {'✅ Loaded' if model_status['resnet_loaded'] else '❌ Not loaded'}\n"
        error_msg += f"- ViT: {'✅ Loaded' if model_status['vit_loaded'] else '❌ Not loaded'}\n\n"
        error_msg += "**Errors:**\n"
        for error in model_status["errors"]:
            error_msg += f"- {error}\n\n"
        error_msg += "**Solution:**\n"
        error_msg += "Please train the models first using the 'Simple Training' or 'Advanced Training' tabs, or ensure trained checkpoints are available."
        return [], {"error": error_msg, "model_status": model_status}

    # Return stitched outfit images and a JSON with details
    if not files:
        return [], {"error": "No files uploaded"}
    images = _load_images_from_files(files)
    if not images:
        return [], {"error": "Could not load images"}

    # Build items that allow on-the-fly embedding in service
    items = [
        {"id": f"item_{i}", "image": images[i], "category": None}
        for i in range(len(images))
    ]
    res = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits), "outfit_style": outfit_style})

    # Check if compose_outfits returned an error
    if res and isinstance(res[0], dict) and "error" in res[0]:
        return [], res[0]

    # Prepare stitched previews
    strips: List[Image.Image] = []
    for r in res:
        idxs = []
        for iid in r.get("item_ids", []):
            try:
                idxs.append(int(str(iid).split("_")[-1]))
            except Exception:
                continue
        imgs = [images[i] for i in idxs if 0 <= i < len(images)]
        strips.append(_stitch_strip(imgs))
    return strips, {"outfits": res}
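# gradio_recommend() builds the same kind of item/context payload that the REST
# /compose endpoint receives: the context dict carries "occasion", "weather",
# "num_outfits" and "outfit_style", and item ids of the form "item_<i>" are mapped
# back to the uploaded images when stitching the preview strips.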
def start_training_advanced(
    # Dataset size
    dataset_size: str,
    # ResNet parameters
    resnet_epochs: int,
    resnet_batch_size: int,
    resnet_lr: float,
    resnet_optimizer: str,
    resnet_weight_decay: float,
    resnet_triplet_margin: float,
    resnet_embedding_dim: int,
    resnet_backbone: str,
    resnet_use_pretrained: bool,
    resnet_dropout: float,
    # ViT parameters
    vit_epochs: int,
    vit_batch_size: int,
    vit_max_samples: int,
    vit_lr: float,
    vit_optimizer: str,
    vit_weight_decay: float,
    vit_triplet_margin: float,
    vit_embedding_dim: int,
    vit_num_layers: int,
    vit_num_heads: int,
    vit_ff_multiplier: int,
    vit_dropout: float,
    # Advanced parameters
    use_mixed_precision: bool,
    channels_last: bool,
    gradient_clip: float,
    warmup_epochs: int,
    scheduler_type: str,
    early_stopping_patience: int,
    mining_strategy: str,
    augmentation_level: str,
    seed: int
):
    """Start advanced training with custom parameters."""
    # Use global dataset size if not specified
    if not dataset_size or dataset_size == "full":
        dataset_size = os.getenv("DATASET_SIZE_LIMIT", "2000")

    if not DATASET_ROOT:
        return "❌ Dataset not ready. Please wait for bootstrap to complete."

    log_message = "🚀 Advanced training started with custom parameters! Check the log below for progress."

    def _runner():
        nonlocal log_message
        try:
            import subprocess
            import json

            export_dir = os.getenv("EXPORT_DIR", "models/exports")
            os.makedirs(export_dir, exist_ok=True)

            # Create custom config files
            resnet_config = {
                "model": {
                    "backbone": resnet_backbone,
                    "embedding_dim": resnet_embedding_dim,
                    "pretrained": resnet_use_pretrained,
                    "dropout": resnet_dropout
                },
                "training": {
                    "batch_size": resnet_batch_size,
                    "epochs": resnet_epochs,
                    "lr": resnet_lr,
                    "weight_decay": resnet_weight_decay,
                    "triplet_margin": resnet_triplet_margin,
                    "optimizer": resnet_optimizer,
                    "scheduler": scheduler_type,
                    "warmup_epochs": warmup_epochs,
                    "early_stopping_patience": early_stopping_patience,
                    "use_amp": use_mixed_precision,
                    "channels_last": channels_last,
                    "gradient_clip": gradient_clip
                },
                "data": {
                    "image_size": 224,
                    "augmentation_level": augmentation_level
                },
                "advanced": {
                    "mining_strategy": mining_strategy,
                    "seed": seed
                }
            }

            vit_config = {
                "model": {
                    "embedding_dim": vit_embedding_dim,
                    "num_layers": vit_num_layers,
                    "num_heads": vit_num_heads,
                    "ff_multiplier": vit_ff_multiplier,
                    "dropout": vit_dropout
                },
                "training": {
                    "batch_size": vit_batch_size,
                    "epochs": vit_epochs,
                    "lr": vit_lr,
                    "weight_decay": vit_weight_decay,
                    "triplet_margin": vit_triplet_margin,
                    "optimizer": vit_optimizer,
                    "scheduler": scheduler_type,
                    "warmup_epochs": warmup_epochs,
                    "early_stopping_patience": early_stopping_patience,
                    "use_amp": use_mixed_precision
                },
                "advanced": {
                    "mining_strategy": mining_strategy,
                    "seed": seed
                }
            }

            # Save configs
            with open(os.path.join(export_dir, "resnet_config_custom.json"), "w") as f:
                json.dump(resnet_config, f, indent=2)
            with open(os.path.join(export_dir, "vit_config_custom.json"), "w") as f:
                json.dump(vit_config, f, indent=2)

            # Train ResNet with custom parameters
            log_message = f"🚀 Starting ResNet training with custom parameters...\n"
            log_message += f"Dataset Size: {dataset_size} samples\n"
            log_message += f"Backbone: {resnet_backbone}, Embedding Dim: {resnet_embedding_dim}\n"
            log_message += f"Epochs: {resnet_epochs}, Batch Size: {resnet_batch_size}, LR: {resnet_lr}\n"
            log_message += f"Optimizer: {resnet_optimizer}, Triplet Margin: {resnet_triplet_margin}\n"

            # Add dataset size limit if not full
            dataset_args = []
            if dataset_size != "full":
                dataset_args = ["--max_samples", dataset_size]

            resnet_cmd = [
                "python", "train_resnet.py",
                "--data_root", DATASET_ROOT,
                "--epochs", str(resnet_epochs),
                "--batch_size", str(resnet_batch_size),
                "--lr", str(resnet_lr),
                "--weight_decay", str(resnet_weight_decay),
                "--triplet_margin", str(resnet_triplet_margin),
                "--embedding_dim", str(resnet_embedding_dim),
                "--out", os.path.join(export_dir, "resnet_item_embedder_custom.pth")
            ] + dataset_args

            if resnet_backbone != "resnet50":
                resnet_cmd.extend(["--backbone", resnet_backbone])

            result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)

            if result.returncode == 0:
                log_message += "✅ ResNet training completed successfully!\n"
                log_message += f"📊 ResNet Output:\n{result.stdout}\n\n"
            else:
                log_message += f"❌ ResNet training failed: {result.stderr}\n\n"
                return log_message

            # Wait a moment for file system sync and ensure ResNet is fully saved
            import time
            time.sleep(3)
            log_message += "⏳ Waiting for ResNet checkpoint to be fully saved...\n"

            # Verify ResNet checkpoint exists before proceeding
            resnet_checkpoint = os.path.join(export_dir, "resnet_item_embedder_custom.pth")
            if not os.path.exists(resnet_checkpoint):
                log_message += f"❌ ResNet checkpoint not found at {resnet_checkpoint}\n"
                log_message += "Cannot proceed with ViT training without ResNet embeddings.\n"
                return log_message

            log_message += f"✅ ResNet checkpoint verified: {resnet_checkpoint}\n"

            # Train ViT with custom parameters
            log_message += f"🚀 Starting ViT training with custom parameters...\n"
            log_message += f"Dataset Size: {dataset_size} samples\n"
            log_message += f"Layers: {vit_num_layers}, Heads: {vit_num_heads}, FF Multiplier: {vit_ff_multiplier}\n"
            log_message += f"Epochs: {vit_epochs}, Batch Size: {vit_batch_size}, LR: {vit_lr}\n"
            log_message += f"Optimizer: {vit_optimizer}, Triplet Margin: {vit_triplet_margin}\n"

            # Note: dataset_args may append a second --max_samples after vit_max_samples;
            # argparse keeps the last occurrence, so the dataset-size dropdown wins here.
            vit_cmd = [
                "python", "train_vit_triplet.py",
                "--data_root", DATASET_ROOT,
                "--epochs", str(vit_epochs),
                "--batch_size", str(vit_batch_size),
                "--max_samples", str(vit_max_samples),
                "--lr", str(vit_lr),
                "--weight_decay", str(vit_weight_decay),
                "--triplet_margin", str(vit_triplet_margin),
                "--embedding_dim", str(vit_embedding_dim),
                "--export", os.path.join(export_dir, "vit_outfit_model_custom.pth")
            ] + dataset_args

            result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)

            if result.returncode == 0:
                log_message += "✅ ViT training completed successfully!\n"
                log_message += f"📊 ViT Output:\n{result.stdout}\n\n"
                log_message += "🎉 All training completed! Models saved to models/exports/\n"
                log_message += "🔄 Reloading models for inference...\n"
                service.reload_models()

                # Check if models loaded successfully
                model_status = service.get_model_status()
                if model_status["can_recommend"]:
                    log_message += "✅ Models reloaded and ready for inference!\n"
                    log_message += "🎉 You can now generate outfit recommendations!\n"
                else:
                    log_message += "⚠️ Models reloaded but validation failed!\n"
                    log_message += "**Model Status:**\n"
                    log_message += f"- ResNet: {'✅ Loaded' if model_status['resnet_loaded'] else '❌ Failed'}\n"
                    log_message += f"- ViT: {'✅ Loaded' if model_status['vit_loaded'] else '❌ Failed'}\n"
                    for error in model_status["errors"]:
                        log_message += f"- {error}\n"

                # Auto-upload to HF Hub if token is available
                hf_token = os.getenv("HF_TOKEN")
                if hf_token:
                    log_message += "📤 Auto-uploading artifacts to Hugging Face Hub...\n"
                    try:
                        from utils.hf_utils import HFModelManager
                        hf = HFModelManager(token=hf_token, username="Stylique")
                        result = hf.upload_model("everything", "dressify-complete")
                        if result.get("success"):
                            log_message += "✅ Successfully uploaded to HF Hub!\n"
                            log_message += "🔗 Models: https://huggingface.co/Stylique/dressify-models\n"
                            log_message += "🔗 Data: https://huggingface.co/datasets/Stylique/Dressify-Helper\n"
                        else:
                            log_message += f"⚠️ Upload failed: {result.get('error', 'Unknown error')}\n"
                    except Exception as e:
                        log_message += f"⚠️ Auto-upload failed: {str(e)}\n"
                else:
                    log_message += "💡 Set HF_TOKEN env var for automatic uploads\n"
            else:
                log_message += f"❌ ViT training failed: {result.stderr}\n"
        except Exception as e:
            log_message += f"\n❌ Training error: {str(e)}"

    threading.Thread(target=_runner, daemon=True).start()
    return log_message
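# start_training_advanced() also writes resnet_config_custom.json / vit_config_custom.json
# snapshots of the requested hyperparameters into EXPORT_DIR; the training subprocesses
# themselves are driven by the explicit CLI flags assembled above.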
def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
    """Start simple training with basic parameters."""
    # Use global dataset size if not specified
    if not dataset_size or dataset_size == "full":
        dataset_size = os.getenv("DATASET_SIZE_LIMIT", "2000")

    log_message = f"Starting training on {dataset_size} samples..."

    def _runner():
        nonlocal log_message
        try:
            import subprocess
            if not DATASET_ROOT:
                log_message = "Dataset not ready."
                return
            export_dir = os.getenv("EXPORT_DIR", "models/exports")
            os.makedirs(export_dir, exist_ok=True)
            log_message = f"Training ResNet on {dataset_size} samples...\n"

            # Add dataset size limit if not full
            dataset_args = []
            if dataset_size != "full":
                dataset_args = ["--max_samples", dataset_size]

            # Train ResNet first and wait for completion
            log_message += f"\n🚀 Starting ResNet training on {dataset_size} samples...\n"
            resnet_result = subprocess.run([
                "python", "train_resnet.py",
                "--data_root", DATASET_ROOT,
                "--epochs", str(res_epochs),
                "--batch_size", "4",
                "--lr", "1e-3",
                "--early_stopping_patience", "3",
                "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
            ] + dataset_args, capture_output=True, text=True, check=False)

            if resnet_result.returncode == 0:
                log_message += "✅ ResNet training completed successfully!\n"
                log_message += f"📊 ResNet Output:\n{resnet_result.stdout}\n"
            else:
                log_message += f"❌ ResNet training failed: {resnet_result.stderr}\n"
                return log_message

            # Wait a moment for file system sync
            import time
            time.sleep(2)

            # Verify ResNet checkpoint exists before proceeding
            resnet_checkpoint = os.path.join(export_dir, "resnet_item_embedder.pth")
            if not os.path.exists(resnet_checkpoint):
                log_message += f"❌ ResNet checkpoint not found at {resnet_checkpoint}\n"
                log_message += "Cannot proceed with ViT training without ResNet embeddings.\n"
                return log_message

            log_message += f"✅ ResNet checkpoint verified: {resnet_checkpoint}\n"

            log_message += f"\n🚀 Starting ViT training on {dataset_size} samples...\n"
            # Note: dataset_args may append a second --max_samples after the fixed "5000";
            # argparse keeps the last occurrence, so the dataset-size dropdown wins here.
            vit_result = subprocess.run([
                "python", "train_vit_triplet.py",
                "--data_root", DATASET_ROOT,
                "--epochs", str(vit_epochs),
                "--batch_size", "4",
                "--lr", "5e-4",
                "--early_stopping_patience", "5",
                "--max_samples", "5000",
                "--triplet_margin", "0.5",
                "--gradient_clip", "1.0",
                "--warmup_epochs", "2",
                "--export", os.path.join(export_dir, "vit_outfit_model.pth")
            ] + dataset_args, capture_output=True, text=True, check=False)

            if vit_result.returncode == 0:
                log_message += "✅ ViT training completed successfully!\n"
                log_message += f"📊 ViT Output:\n{vit_result.stdout}\n"
            else:
                log_message += f"❌ ViT training failed: {vit_result.stderr}\n"
                return log_message

            service.reload_models()

            # Check if models loaded successfully
            model_status = service.get_model_status()
            if model_status["can_recommend"]:
                log_message += "\n✅ Training completed! Models reloaded and ready for inference.\n"
                log_message += "🎉 You can now generate outfit recommendations!\n"
            else:
                log_message += "\n⚠️ Training completed but models failed to load properly!\n"
                log_message += "**Model Status:**\n"
                log_message += f"- ResNet: {'✅ Loaded' if model_status['resnet_loaded'] else '❌ Failed'}\n"
                log_message += f"- ViT: {'✅ Loaded' if model_status['vit_loaded'] else '❌ Failed'}\n"
                for error in model_status["errors"]:
                    log_message += f"- {error}\n"

            log_message += "\nArtifacts saved to models/exports/"

            # Auto-upload to HF Hub if token is available
            hf_token = os.getenv("HF_TOKEN")
            if hf_token:
                log_message += "\n📤 Auto-uploading artifacts to Hugging Face Hub...\n"
                try:
                    from utils.hf_utils import HFModelManager
                    hf = HFModelManager(token=hf_token, username="Stylique")
                    result = hf.upload_model("everything", "dressify-complete")
                    if result.get("success"):
                        log_message += "✅ Successfully uploaded to HF Hub!\n"
                        log_message += "🔗 Models: https://huggingface.co/Stylique/dressify-models\n"
                        log_message += "🔗 Data: https://huggingface.co/datasets/Stylique/Dressify-Helper\n"
                    else:
                        log_message += f"⚠️ Upload failed: {result.get('error', 'Unknown error')}\n"
                except Exception as e:
                    log_message += f"⚠️ Auto-upload failed: {str(e)}\n"
            else:
                log_message += "\n💡 Set HF_TOKEN env var for automatic uploads\n"
        except Exception as e:
            log_message += f"\nError: {e}"

    threading.Thread(target=_runner, daemon=True).start()
    return log_message
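# Both training entry points replace an empty or "full" dropdown value with the
# DATASET_SIZE_LIMIT environment variable (default "2000"); unless the resulting value
# is "full", it is forwarded to the training scripts as --max_samples via dataset_args.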
with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendation") as demo:
    gr.Markdown("## 🏆 Dressify – Advanced Outfit Recommendation System\n*Research-grade, self-contained outfit recommendation with comprehensive training controls*")
    gr.Markdown("💡 **Pro Tip**: Start with 2000 samples for quick testing, then increase to 50000+ for production training!")

    with gr.Tab("🎨 Recommend"):
        inp2 = gr.Files(label="Upload wardrobe images", file_types=["image"], file_count="multiple")
        with gr.Row():
            occasion = gr.Dropdown(choices=["casual", "business", "formal", "sport"], value="casual", label="Occasion")
            weather = gr.Dropdown(choices=["any", "hot", "mild", "cold", "rain"], value="any", label="Weather")
            outfit_style = gr.Dropdown(choices=["casual", "smart_casual", "formal", "sporty", "traditional"], value="casual", label="Outfit Style")
        with gr.Row():
            num_outfits = gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Number of outfits")
        out_gallery = gr.Gallery(label="Recommended Outfits", columns=1, height=320)
        out_json = gr.JSON(label="Outfit Details")
        btn2 = gr.Button("Generate Outfits", variant="primary")
        btn2.click(fn=gradio_recommend, inputs=[inp2, occasion, weather, num_outfits, outfit_style], outputs=[out_gallery, out_json])

    with gr.Tab("🔬 Advanced Training"):
        gr.Markdown("### 🎯 Comprehensive Training Parameter Control\nCustomize every aspect of model training for research and experimentation.")

        # Global Dataset Size Control
        with gr.Row():
            gr.Markdown("#### 🎯 **Global Dataset Size Control**")
            gr.Markdown("**Note**: Initial bootstrap downloads full dataset (required). Use 'Apply' button to limit splits for testing.")
        with gr.Row():
            gr.Markdown("#### 📊 **Current Behavior**")
            gr.Markdown("• **Bootstrap**: Downloads full dataset (53K outfits) + generates splits with **500 samples by default**\n• **Training**: Uses 500 samples (ultra-fast training!)\n• **Apply Button**: Regenerates splits with your selected size limit")
        with gr.Row():
            global_dataset_size = gr.Dropdown(
                choices=["160", "500", "2000", "5000", "10000", "25000", "50000", "full"],
                value="500",
                label="Global Dataset Size (Affects Prep + Training)"
            )
            gr.Markdown("**160**: Ultra-fast testing (~30 sec prep, ~1-2 min training)\n**2000**: Fast testing (~1-2 min prep, ~2-5 min training)\n**5000**: Fast testing (~2-3 min prep, ~5-10 min training)\n**10000**: Good testing (~3-5 min prep, ~10-20 min training)\n**full**: Production (~5-10 min prep, ~1-4 hours training)")
        with gr.Row():
            # Apply dataset size button
            apply_size_btn = gr.Button("🔄 Apply Dataset Size & Regenerate Splits", variant="primary")
            size_status = gr.Textbox(label="Dataset Size Status", value="Dataset size: 500 samples (click Apply to regenerate splits)", interactive=False)

        # Current dataset info
        gr.Markdown("#### 📊 **Current Dataset Status**")
        gr.Markdown("• **Full dataset downloaded**: 53,306 outfits (required for system)\n• **Splits generated**: **500 samples by default** (ultra-fast training!)\n• **Training will use**: 500 samples (ultra-fast training!)\n• **Scale up**: Use Apply button to increase to larger sizes")

        def apply_dataset_size(size: str):
            """Apply global dataset size and regenerate splits."""
            try:
                if size == "full":
                    return f"✅ Using full dataset ({size}) - no size limit applied"

                # Call the dataset preparation with size limit
                import subprocess
                import os

                # Set environment variable for dataset size
                os.environ["DATASET_SIZE_LIMIT"] = size

                # Check if script exists
                script_path = "scripts/prepare_polyvore.py"
                if not os.path.exists(script_path):
                    return f"❌ Script not found: {script_path}"

                # Regenerate splits with size limit using subprocess
                cmd = [
                    "python", script_path,
                    "--root", "/home/user/app/data/Polyvore",
                    "--out", "/home/user/app/data/Polyvore/splits",
                    "--max_samples", size
                ]
                print(f"Running command: {' '.join(cmd)}")
                print(f"Current working directory: {os.getcwd()}")

                # Run from the correct directory
                result = subprocess.run(cmd, capture_output=True, text=True, check=False, cwd="/home/user/app")
                if result.returncode == 0:
                    return f"✅ Successfully regenerated splits with {size} samples limit"
                else:
                    error_msg = f"❌ Failed to regenerate splits:\n"
                    error_msg += f"Return code: {result.returncode}\n"
                    error_msg += f"STDOUT: {result.stdout}\n"
                    error_msg += f"STDERR: {result.stderr}"
                    return error_msg
            except Exception as e:
                return f"❌ Failed to apply dataset size: {str(e)}"

        apply_size_btn.click(fn=apply_dataset_size, inputs=[global_dataset_size], outputs=[size_status])

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### 📊 Dataset Size Control")
                gr.Markdown("Start small for testing, increase for production training")
                dataset_size = gr.Dropdown(
                    choices=["160", "500", "2000", "5000", "10000", "25000", "50000", "full"],
                    value="500",
                    label="Training Dataset Size"
                )
                gr.Markdown("**2000**: Quick testing (~2-5 min)\n**5000**: Fast validation (~5-10 min)\n**10000**: Good validation (~10-20 min)\n**25000+**: Production training")

            with gr.Column(scale=1):
                gr.Markdown("#### 🖼️ ResNet Item Embedder")

                # Model architecture
                resnet_backbone = gr.Dropdown(
                    choices=["resnet50", "resnet101"],
                    value="resnet50",
                    label="Backbone Architecture"
                )
                resnet_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension")
                resnet_use_pretrained = gr.Checkbox(value=True, label="Use ImageNet Pretrained")
                resnet_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate")

                # Training parameters
                resnet_epochs = gr.Slider(1, 100, value=20, step=1, label="Epochs")
                resnet_batch_size = gr.Slider(4, 128, value=4, step=4, label="Batch Size")
                resnet_lr = gr.Slider(1e-5, 1e-2, value=1e-3, step=1e-5, label="Learning Rate")
                resnet_optimizer = gr.Dropdown(
                    choices=["adamw", "adam", "sgd", "rmsprop"],
                    value="adamw",
                    label="Optimizer"
                )
                resnet_weight_decay = gr.Slider(1e-6, 1e-2, value=1e-4, step=1e-6, label="Weight Decay")
                resnet_triplet_margin = gr.Slider(0.1, 1.0, value=0.2, step=0.05, label="Triplet Margin")

            with gr.Column(scale=1):
                gr.Markdown("#### 🧠 ViT Outfit Encoder")

                # Model architecture
                vit_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension")
                vit_num_layers = gr.Slider(2, 12, value=6, step=1, label="Transformer Layers")
                vit_num_heads = gr.Slider(4, 16, value=8, step=2, label="Attention Heads")
                vit_ff_multiplier = gr.Slider(2, 8, value=4, step=1, label="Feed-Forward Multiplier")
                vit_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate")

                # Training parameters
                vit_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs")
                vit_batch_size = gr.Slider(2, 64, value=4, step=2, label="Batch Size")
                vit_max_samples = gr.Slider(100, 5000, value=500, step=100, label="Max Training Samples")
                vit_lr = gr.Slider(1e-5, 1e-2, value=5e-4, step=1e-5, label="Learning Rate")
                vit_optimizer = gr.Dropdown(
                    choices=["adamw", "adam", "sgd", "rmsprop"],
                    value="adamw",
                    label="Optimizer"
                )
                vit_weight_decay = gr.Slider(1e-4, 1e-1, value=5e-2, step=1e-4, label="Weight Decay")
                vit_triplet_margin = gr.Slider(0.1, 1.0, value=0.3, step=0.05, label="Triplet Margin")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### ⚙️ Advanced Training Settings")

                # Hardware optimization
                use_mixed_precision = gr.Checkbox(value=True, label="Mixed Precision (AMP)")
                channels_last = gr.Checkbox(value=True, label="Channels Last Memory Format")
                gradient_clip = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="Gradient Clipping")

                # Learning rate scheduling
                warmup_epochs = gr.Slider(0, 10, value=3, step=1, label="Warmup Epochs")
                scheduler_type = gr.Dropdown(
                    choices=["cosine", "step", "plateau", "linear"],
                    value="cosine",
                    label="Learning Rate Scheduler"
                )
                early_stopping_patience = gr.Slider(5, 20, value=10, step=1, label="Early Stopping Patience")

                # Training strategy
                mining_strategy = gr.Dropdown(
                    choices=["semi_hard", "hardest", "random"],
                    value="semi_hard",
                    label="Triplet Mining Strategy"
                )
                augmentation_level = gr.Dropdown(
                    choices=["minimal", "standard", "aggressive"],
                    value="standard",
                    label="Data Augmentation Level"
                )
                seed = gr.Slider(0, 9999, value=42, step=1, label="Random Seed")

            with gr.Column(scale=1):
                gr.Markdown("#### 🚀 Training Control")

                # Quick training
                gr.Markdown("**Quick Training (Basic Parameters)**")
                epochs_res = gr.Slider(1, 50, value=3, step=1, label="ResNet epochs")
                epochs_vit = gr.Slider(1, 100, value=3, step=1, label="ViT epochs")
                start_btn = gr.Button("🚀 Start Quick Training", variant="secondary")

                # Advanced training
                gr.Markdown("**Advanced Training (Custom Parameters)**")
                start_advanced_btn = gr.Button("🎯 Start Advanced Training", variant="primary")

                # Training log
                train_log = gr.Textbox(label="Training Log", lines=15, max_lines=20)

                # Status
                gr.Markdown("**Training Status**")
                training_status = gr.Textbox(label="Status", value="Ready to train", interactive=False)

        # Event handlers
        start_btn.click(
            fn=start_training_simple,
            inputs=[dataset_size, epochs_res, epochs_vit],
            outputs=train_log
        )
        start_advanced_btn.click(
            fn=start_training_advanced,
            inputs=[
                # Dataset size
                dataset_size,
                # ResNet parameters
                resnet_epochs, resnet_batch_size, resnet_lr, resnet_optimizer,
                resnet_weight_decay, resnet_triplet_margin, resnet_embedding_dim,
                resnet_backbone, resnet_use_pretrained, resnet_dropout,
                # ViT parameters
                vit_epochs, vit_batch_size, vit_max_samples, vit_lr, vit_optimizer,
                vit_weight_decay, vit_triplet_margin, vit_embedding_dim,
                vit_num_layers, vit_num_heads, vit_ff_multiplier, vit_dropout,
                # Advanced parameters
                use_mixed_precision, channels_last, gradient_clip, warmup_epochs,
                scheduler_type, early_stopping_patience, mining_strategy,
                augmentation_level, seed
            ],
            outputs=train_log
        )

    with gr.Tab("📦 Artifact Management"):
        gr.Markdown("### 🎯 Comprehensive Artifact Management\nManage, package, and upload all system artifacts to Hugging Face Hub.")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### 📊 Artifact Overview")
                artifact_overview = gr.JSON(label="System Artifacts", value=get_artifact_overview)
                refresh_overview = gr.Button("🔄 Refresh Overview")
                refresh_overview.click(fn=get_artifact_overview, inputs=[], outputs=artifact_overview)

                gr.Markdown("#### 📦 Create Packages")
                package_type = gr.Dropdown(
                    choices=["complete", "splits_only", "models_only"],
                    value="complete",
                    label="Package Type"
                )
                create_package_btn = gr.Button("📦 Create Package")
                package_result = gr.Textbox(label="Package Result", interactive=False)
                available_packages = gr.JSON(label="Available Packages", value=get_available_packages)
                create_package_btn.click(
                    fn=create_download_package,
                    inputs=[package_type],
                    outputs=[package_result, available_packages]
                )

            with gr.Column(scale=1):
                gr.Markdown("#### 🚀 Hugging Face Hub Integration")
                gr.Markdown("💡 **Pro Tip**: Set `HF_TOKEN` environment variable for automatic uploads after training!")
                hf_token = gr.Textbox(label="HF Token", type="password", placeholder="hf_...")
                hf_username = gr.Textbox(label="Username", placeholder="your-username")
                with gr.Row():
                    push_splits_btn = gr.Button("📤 Push Splits", variant="secondary")
                    push_models_btn = gr.Button("📤 Push Models", variant="secondary")
                    push_everything_btn = gr.Button("📤 Push Everything", variant="primary")
                hf_result = gr.Textbox(label="Upload Result", interactive=False, lines=3)
                push_splits_btn.click(fn=push_splits_to_hf, inputs=[hf_token, hf_username], outputs=hf_result)
                push_models_btn.click(fn=push_models_to_hf, inputs=[hf_token, hf_username], outputs=hf_result)
                push_everything_btn.click(fn=push_everything_to_hf, inputs=[hf_token, hf_username], outputs=hf_result)

                gr.Markdown("#### 📥 Download Management")
                individual_files = gr.JSON(label="Individual Files", value=get_individual_files)
                download_all_btn = gr.Button("📥 Download All as ZIP")
                download_result = gr.Textbox(label="Download Result", interactive=False)
                download_all_btn.click(fn=download_all_files, inputs=[], outputs=download_result)

    with gr.Tab("🔧 Simple Training"):
        gr.Markdown("### 🚀 Quick Training with Default Parameters\nFast training with proven configurations for immediate results.")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### 📊 Dataset Size Control")
                gr.Markdown("Start small for testing, increase for production training")
                dataset_size = gr.Dropdown(
                    choices=["160", "500", "2000", "5000", "10000", "25000", "50000", "full"],
                    value="500",
                    label="Training Dataset Size"
                )
                gr.Markdown("**2000**: Quick testing (~2-5 min)\n**5000**: Fast validation (~5-10 min)\n**10000**: Good validation (~10-20 min)\n**25000+**: Production training")

            with gr.Column(scale=1):
                gr.Markdown("#### ⚙️ Training Parameters")
                epochs_res = gr.Slider(1, 50, value=3, step=1, label="ResNet epochs")
                epochs_vit = gr.Slider(1, 100, value=3, step=1, label="ViT epochs")

        train_log = gr.Textbox(label="Training Log", lines=10)
        start_btn = gr.Button("Start Training")
        start_btn.click(fn=start_training_simple, inputs=[dataset_size, epochs_res, epochs_vit], outputs=train_log)

    with gr.Tab("📊 Embed (Debug)"):
        inp = gr.Files(label="Upload Items (multiple images)")
        out = gr.Textbox(label="Embeddings (JSON)")
        btn = gr.Button("Compute Embeddings")
        btn.click(fn=gradio_embed, inputs=inp, outputs=out)

    with gr.Tab("📈 Status"):
        gr.Markdown("### 🚦 System Status and Monitoring\nReal-time status of dataset preparation, training, and system health.")
        status = gr.Textbox(label="Bootstrap Status", value=lambda: BOOT_STATUS)
        refresh_status = gr.Button("🔄 Refresh Status")
        refresh_status.click(fn=lambda: BOOT_STATUS, inputs=[], outputs=status)

        # Model Status
        gr.Markdown("#### 🤖 Model Status")
        model_status = gr.JSON(label="Model Loading Status", value=lambda: service.get_model_status())
        refresh_models = gr.Button("🔄 Refresh Model Status")
        refresh_models.click(fn=lambda: service.get_model_status(), inputs=[], outputs=model_status)

        # System info
        gr.Markdown("#### 💻 System Information")
        device_info = gr.Textbox(label="Device", value=lambda: f"Device: {service.device}")
        resnet_version = gr.Textbox(label="ResNet Version", value=lambda: f"ResNet: {service.resnet_version}")
        vit_version = gr.Textbox(label="ViT Version", value=lambda: f"ViT: {service.vit_version}")

        # Health check
        gr.Markdown("#### 🏥 Health Check")
        health_btn = gr.Button("🔍 Check Health")
        health_status = gr.Textbox(label="Health Status", value="Click to check")

        def check_health():
            try:
                # Call the module-level /health handler directly; app.get("/health")
                # only returns a route decorator, not a response.
                info = health()
                return f"✅ System Healthy - {info}"
            except Exception as e:
                return f"❌ Health Check Failed: {str(e)}"

        health_btn.click(fn=check_health, inputs=[], outputs=health_status)
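# Each tab wires its buttons to the module-level functions defined above; the Simple
# Training tab defines its own dataset_size / epochs widgets, so its Start button and
# the Advanced tab's Quick Training button are wired to separate controls even though
# both call start_training_simple().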
try:
    # Mount Gradio onto FastAPI root path (disable SSR to avoid stray port fetches)
    demo.queue()
    app = gr.mount_gradio_app(app, demo, path="/")
except Exception:
    # In case mounting fails in certain runners, we still want FastAPI to be available
    pass

# Mount static files for direct artifact download
export_dir = os.getenv("EXPORT_DIR", "models/exports")
os.makedirs(export_dir, exist_ok=True)
try:
    app.mount("/files", StaticFiles(directory=export_dir), name="files")
except Exception:
    pass

if __name__ == "__main__":
    # Local/Space run
    demo.queue().launch(ssr_mode=False)
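# Running outside Gradio's own launcher (a sketch; assumes this module is saved as
# app.py and the conventional Spaces port): `uvicorn app:app --host 0.0.0.0 --port 7860`
# serves both the FastAPI endpoints and the mounted Gradio UI at "/".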