Garment3dKabeer / post_install.py
Stylique's picture
Upload 4 files
12a663d verified
raw
history blame
7.36 kB
#!/usr/bin/env python3
"""
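Post-install script for Hugging Face Spaces.
This script first pins a CUDA 11.7 build of PyTorch 2.0.1, then installs the
dependencies that must be installed or built against it (torch-sparse,
torch-scatter, nvdiffrast, PyTorch3D), plus the requirements of Fashion-CLIP.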
"""
import os
import sys
import subprocess
import shutil
from pathlib import Path
def run_command(command, cwd=None, env=None):
    """Run a shell command and report whether it succeeded.

    The command's output is captured. On failure both stderr and stdout are
    printed, because pip and git frequently write the useful part of an
    error message to stdout.

    Args:
        command: Command line string, executed through the shell
            (``shell=True``); only pass trusted, hard-coded strings.
        cwd: Working directory for the command, or None for the current one.
        env: Environment mapping for the child process, or None to inherit.

    Returns:
        True if the command exited with status 0, False otherwise -- including
        when the process could not be started at all (e.g. ``cwd`` missing).
    """
    print(f"Running: {command}")
    try:
        result = subprocess.run(
            command, shell=True, cwd=cwd, capture_output=True, text=True, env=env
        )
    except OSError as exc:
        # e.g. FileNotFoundError / NotADirectoryError for a bad cwd; report it
        # as an ordinary failure so callers' `if not run_command(...)` works.
        print(f"Error running command: {command}")
        print(f"Error output: {exc}")
        return False
    if result.returncode != 0:
        print(f"Error running command: {command}")
        print(f"Error output: {result.stderr}")
        if result.stdout:
            print(f"Command output: {result.stdout}")
        return False
    print(f"Success: {command}")
    return True
def check_pytorch_cuda():
    """Print PyTorch / CUDA diagnostics for the running interpreter.

    No particular version is enforced; the function only reports the PyTorch
    version, whether CUDA is available and, if it is, the CUDA and cuDNN
    versions.

    Returns:
        True if ``torch`` was imported and queried without error, False if any
        exception occurred (the error is printed).
    """
    try:
        import torch
        print(f"PyTorch version: {torch.__version__}")
        has_cuda = torch.cuda.is_available()
        print(f"CUDA available: {has_cuda}")
        if has_cuda:
            print(f"CUDA version: {torch.version.cuda}")
            print(f"cuDNN version: {torch.backends.cudnn.version()}")
    except Exception as err:
        print(f"Error checking PyTorch: {err}")
        return False
    return True
def install_torch_sparse():
    """Pin PyTorch 2.0.1 (CUDA 11.7 wheels), then install torch-sparse.

    Side effect: torch, torchvision and torchaudio are (re)installed at the
    pinned versions before torch-sparse is fetched from the PyG wheel index
    for torch 2.0.1 + cu117.

    Returns:
        True if both pip invocations succeeded, False otherwise.
    """
    print("Installing torch-sparse...")

    # CUDA 11.7 build of PyTorch, matching the cu117 PyG wheels used below.
    print("Installing compatible PyTorch version...")
    torch_cmd = (
        "pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 "
        "--index-url https://download.pytorch.org/whl/cu117"
    )
    if not run_command(torch_cmd):
        return False

    # Diagnostics only; the return value is deliberately not acted upon.
    print("Checking PyTorch installation...")
    check_pytorch_cuda()

    print("Installing torch-sparse with PyTorch 2.0.1...")
    sparse_ok = run_command(
        "pip install torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu117.html"
    )
    if sparse_ok:
        print("Successfully installed torch-sparse")
    return sparse_ok
def install_torch_scatter():
    """Install torch-scatter from the PyG wheel index for torch 2.0.1 + cu117.

    Expects the pinned PyTorch 2.0.1 to be installed already (main() runs
    install_torch_sparse() first).

    Returns:
        True on success, False if pip failed.
    """
    print("Installing torch-scatter...")
    print("Installing torch-scatter with PyTorch 2.0.1...")
    scatter_ok = run_command(
        "pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu117.html"
    )
    if scatter_ok:
        print("Successfully installed torch-scatter")
    return scatter_ok
def install_nvdiffrast():
    """Clone NVlabs/nvdiffrast into ./packages (if absent) and pip-install it.

    Creates the ``packages`` directory relative to the current working
    directory when it does not exist yet.

    Returns:
        True on success, False if the clone or the install failed.
    """
    print("Installing nvdiffrast...")
    packages_dir = Path("packages")
    packages_dir.mkdir(exist_ok=True)

    repo_dir = packages_dir / "nvdiffrast"
    # Reuse an existing checkout instead of cloning again.
    needs_clone = not repo_dir.exists()
    if needs_clone and not run_command(
        "git clone https://github.com/NVlabs/nvdiffrast.git", cwd=packages_dir
    ):
        return False

    return run_command("pip install .", cwd=repo_dir)
def install_pytorch3d():
    """Clone facebookresearch/pytorch3d into ./packages and build/install it.

    Installation attempts, in order, all with ``FORCE_CUDA=1``:
      1. ``CUDA_HOME=/usr/local/cuda-11.7``
      2. ``CUDA_HOME=/usr/local/cuda``
      3. the inherited environment
    Attempts whose ``CUDA_HOME`` directory does not exist are skipped: a
    forced-CUDA build pointed at a missing toolkit cannot succeed, and each
    PyTorch3D build is slow. If every CUDA attempt fails, a final install
    without ``FORCE_CUDA`` is tried.

    Returns:
        True as soon as one attempt succeeds, False if the clone or every
        attempt failed.
    """
    print("Installing PyTorch3D...")
    packages_dir = Path("packages")
    # Don't depend on an earlier step having created the directory;
    # `git clone` with a missing cwd would otherwise raise.
    packages_dir.mkdir(exist_ok=True)

    pytorch3d_dir = packages_dir / "pytorch3d"
    if not pytorch3d_dir.exists():
        if not run_command(
            "git clone https://github.com/facebookresearch/pytorch3d.git", cwd=packages_dir
        ):
            return False

    # Informational only: report the nvcc found on PATH, if any.
    print("Checking available CUDA version...")
    cuda_check = subprocess.run("nvcc --version", shell=True, capture_output=True, text=True)
    if cuda_check.returncode == 0:
        print(f"CUDA version info: {cuda_check.stdout}")

    cuda_configs = [
        # CUDA 11.7 first, matching the cu117 PyTorch wheels.
        {'CUDA_HOME': '/usr/local/cuda-11.7', 'CUDA_VERSION': '11.7'},
        # System CUDA. NOTE(review): '12.3' is a hard-coded label, not a
        # version detected from nvcc -- confirm it matches the image.
        {'CUDA_HOME': '/usr/local/cuda', 'CUDA_VERSION': '12.3'},
        # No CUDA_HOME override: use whatever the environment provides.
        {},
    ]
    for i, config in enumerate(cuda_configs):
        env = os.environ.copy()
        env['FORCE_CUDA'] = '1'
        if config:
            if not Path(config['CUDA_HOME']).is_dir():
                print(f"Skipping {config}: {config['CUDA_HOME']} not found")
                continue
            env['CUDA_HOME'] = config['CUDA_HOME']
            env['CUDA_VERSION'] = config['CUDA_VERSION']
            print(f"Trying PyTorch3D installation with {config}...")
        else:
            print("Trying PyTorch3D installation with default CUDA...")
        if run_command("pip install .", cwd=pytorch3d_dir, env=env):
            print(f"Successfully installed PyTorch3D with config {i+1}")
            return True

    # Last resort: let the build decide on CUDA support itself.
    print("Trying PyTorch3D installation without FORCE_CUDA...")
    if run_command("pip install .", cwd=pytorch3d_dir):
        print("Successfully installed PyTorch3D without CUDA forcing")
        return True
    return False
def install_fashion_clip():
    """Clone patrickjohncyh/fashion-clip into ./packages and install its deps.

    Only the listed third-party dependencies are pip-installed; the cloned
    repository itself is not installed. A dependency that fails to install
    only produces a warning (best effort) and does not fail this step.

    Returns:
        False if cloning failed, True otherwise.
    """
    print("Setting up Fashion-CLIP...")
    packages_dir = Path("packages")
    # Don't depend on an earlier step having created the directory;
    # `git clone` with a missing cwd would otherwise raise.
    packages_dir.mkdir(exist_ok=True)

    fashion_clip_dir = packages_dir / "fashion-clip"
    if not fashion_clip_dir.exists():
        if not run_command(
            "git clone https://github.com/patrickjohncyh/fashion-clip.git", cwd=packages_dir
        ):
            return False

    dependencies = ["appdirs", "boto3", "annoy", "validators", "transformers", "datasets"]
    for dep in dependencies:
        if not run_command(f"pip install {dep}", cwd=fashion_clip_dir):
            print(f"Warning: Failed to install {dep}")
    return True
def main():
    """Run every installation step in order, then verify the imports.

    Exits the process with status 1 as soon as a step reports failure or a
    verification import (or other check) raises.
    """
    print("Starting post-installation for Garment3DGen...")

    # Order matters: install_torch_sparse() also pins PyTorch 2.0.1 (cu117),
    # which the remaining steps are installed on top of.
    steps = [
        (install_torch_sparse, "torch-sparse"),
        (install_torch_scatter, "torch-scatter"),
        (install_nvdiffrast, "nvdiffrast"),
        (install_pytorch3d, "PyTorch3D"),
        (install_fashion_clip, "Fashion-CLIP"),
    ]
    for install, label in steps:
        if not install():
            print(f"Failed to install {label}")
            sys.exit(1)

    # Final verification: import each built package in this interpreter.
    print("\n=== Final Verification ===")
    print("Checking all dependencies...")
    try:
        import torch
        print(f"✓ PyTorch {torch.__version__} - CUDA: {torch.cuda.is_available()}")
        import torch_sparse
        print("✓ torch-sparse")
        import torch_scatter
        print("✓ torch-scatter")
        import nvdiffrast
        print("✓ nvdiffrast")
        import pytorch3d
        print("✓ PyTorch3D")
        print("\nPost-installation completed successfully!")
        print("All dependencies are now available.")
    except ImportError as e:
        print(f"✗ Import error: {e}")
        sys.exit(1)
    except Exception as e:
        print(f"✗ Verification error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()