#!/usr/bin/env python3
"""
Simple test script to verify training components work.

Run this to test if the system is ready for training. Each check prints its
progress and returns True/False; the process exits with status 0 only when
every check passes.
"""

import os
import sys

import torch


def _device():
    """Return "cuda" when torch reports an available CUDA device, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def _count_params(model):
    """Return the total number of elements across ``model.parameters()``."""
    return sum(p.numel() for p in model.parameters())


def test_imports():
    """Test if all required project modules can be imported.

    Returns:
        True if every module imports cleanly; False on the first failure
        (the error is printed).
    """
    print("🔍 Testing imports...")

    try:
        from models.resnet_embedder import ResNetItemEmbedder  # noqa: F401
        print("✅ ResNet embedder imported successfully")
    except Exception as e:
        print(f"❌ Failed to import ResNet embedder: {e}")
        return False

    try:
        from models.vit_outfit import OutfitCompatibilityModel  # noqa: F401
        print("✅ ViT outfit model imported successfully")
    except Exception as e:
        print(f"❌ Failed to import ViT outfit model: {e}")
        return False

    try:
        from data.polyvore import PolyvoreTripletDataset  # noqa: F401
        print("✅ Polyvore dataset imported successfully")
    except Exception as e:
        print(f"❌ Failed to import Polyvore dataset: {e}")
        return False

    try:
        from utils.transforms import build_train_transforms  # noqa: F401
        print("✅ Transforms imported successfully")
    except Exception as e:
        print(f"❌ Failed to import transforms: {e}")
        return False

    return True


def test_models():
    """Test if models can be created and run a forward pass.

    Builds the ResNet item embedder and the outfit compatibility model on the
    selected device and feeds each a small random batch under ``no_grad``.

    Returns:
        True on success; False if anything raised (the error is printed).
    """
    print("\n🏗️ Testing model creation...")

    try:
        # Import here: the imports in test_imports() bind names only in that
        # function's local scope, so without these the calls below would
        # always fail with NameError.
        from models.resnet_embedder import ResNetItemEmbedder
        from models.vit_outfit import OutfitCompatibilityModel

        device = _device()
        print(f"Using device: {device}")

        # Test ResNet embedder
        resnet = ResNetItemEmbedder(embedding_dim=512).to(device)
        print(f"✅ ResNet created with {_count_params(resnet):,} parameters")

        # Dummy batch of 2 RGB 224x224 images (assumed input layout -- TODO confirm).
        dummy_input = torch.randn(2, 3, 224, 224).to(device)
        with torch.no_grad():
            output = resnet(dummy_input)
        print(f"✅ ResNet forward pass: input {dummy_input.shape} -> output {output.shape}")

        # Test ViT outfit model
        vit = OutfitCompatibilityModel(embedding_dim=512).to(device)
        print(f"✅ ViT created with {_count_params(vit):,} parameters")

        # Dummy (batch=2, items=4, dim=512) token tensor (assumed layout -- TODO confirm).
        dummy_tokens = torch.randn(2, 4, 512).to(device)
        with torch.no_grad():
            output = vit(dummy_tokens)
        print(f"✅ ViT forward pass: input {dummy_tokens.shape} -> output {output.shape}")

        return True

    except Exception as e:
        print(f"❌ Model test failed: {e}")
        return False


def test_dataset():
    """Test if the dataset can be loaded (if available).

    The dataset root is read from the POLYVORE_ROOT environment variable
    (default ``/home/user/app/data/Polyvore``). A missing
    ``splits/train.json`` is reported as a warning and counts as a pass,
    because the data may simply not be prepared yet.

    Returns:
        True if the data is absent or loads cleanly; False if loading raised.
    """
    print("\n📊 Testing dataset loading...")

    data_root = os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore")
    train_file = os.path.join(data_root, "splits", "train.json")

    if not os.path.exists(train_file):
        print(f"⚠️ Training data not found at {train_file}")
        print("💡 Dataset preparation may be needed")
        return True  # Not a failure, just not ready

    try:
        # Local import for the same scoping reason as in test_models().
        from data.polyvore import PolyvoreTripletDataset

        dataset = PolyvoreTripletDataset(data_root, split="train")
        print(f"✅ Dataset loaded successfully: {len(dataset)} samples")

        # Test getting one sample
        if len(dataset) > 0:
            sample = dataset[0]
            # Fall back to the type name for items without a .shape attribute
            # instead of crashing the whole check.
            shapes = [getattr(s, "shape", type(s).__name__) for s in sample]
            print(f"✅ Sample loaded: {len(sample)} tensors with shapes {shapes}")

        return True

    except Exception as e:
        print(f"❌ Dataset test failed: {e}")
        return False


def test_training_components():
    """Test if training components can be created.

    Instantiates the ResNet embedder, an AdamW optimizer over its parameters,
    and a triplet margin loss. No training step is run.

    Returns:
        True on success; False if anything raised (the error is printed).
    """
    print("\n🚀 Testing training components...")

    try:
        from torch.nn import TripletMarginLoss
        from torch.optim import AdamW

        # Local import for the same scoping reason as in test_models().
        from models.resnet_embedder import ResNetItemEmbedder

        model = ResNetItemEmbedder(embedding_dim=512).to(_device())

        # Test optimizer creation
        AdamW(model.parameters(), lr=1e-3)
        print("✅ Optimizer created successfully")

        # Test loss function
        TripletMarginLoss(margin=0.2)
        print("✅ Loss function created successfully")

        return True

    except Exception as e:
        print(f"❌ Training components test failed: {e}")
        return False


def main():
    """Run all tests, print a summary table, and return True if all passed."""
    print("🧪 Starting Dressify Training System Tests\n")

    tests = [
        ("Imports", test_imports),
        ("Models", test_models),
        ("Dataset", test_dataset),
        ("Training Components", test_training_components),
    ]

    results = []
    for test_name, test_func in tests:
        try:
            result = test_func()
        except Exception as e:
            # Each test catches its own errors; this guards against one that
            # slips through so the remaining tests still run.
            print(f"❌ {test_name} test crashed: {e}")
            result = False
        results.append((test_name, result))

    # Summary
    print("\n" + "=" * 50)
    print("📊 TEST RESULTS SUMMARY")
    print("=" * 50)

    for test_name, result in results:
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"{test_name:20} {status}")

    passed = sum(1 for _, result in results if result)
    total = len(results)

    print("=" * 50)
    print(f"Overall: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All tests passed! System is ready for training.")
        return True

    print("⚠️ Some tests failed. Please check the errors above.")
    return False


if __name__ == "__main__":
    sys.exit(0 if main() else 1)