Spaces:
Paused
Paused
| import os | |
| import torch | |
| import torch.nn as nn | |
| import logging | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
| from diffsynth import ModelManager, WanVideoReCamMasterPipeline | |
# Module-level logger; handlers and levels are left to the application.
logger = logging.getLogger(__name__)

# Root directory for every model file. Override it with the
# RECAMMASTER_MODELS_DIR environment variable.
MODELS_ROOT_DIR = os.environ.get("RECAMMASTER_MODELS_DIR", "/data/models")
logger.info(f"Using models root directory: {MODELS_ROOT_DIR}")

# Wan2.1 base model: Hub repository id, local directory (which mirrors the
# repository id under the models root) and the weight files fetched from it.
WAN21_REPO_ID = "Wan-AI/Wan2.1-T2V-1.3B"
WAN21_LOCAL_DIR = f"{MODELS_ROOT_DIR}/{WAN21_REPO_ID}"
WAN21_FILES = [
    "diffusion_pytorch_model.safetensors",
    "models_t5_umt5-xxl-enc-bf16.pth",
    "Wan2.1_VAE.pth",
]

# UMT5-XXL tokenizer files, given as paths relative to the Wan2.1 repository.
_UMT5_XXL_TOKENIZER_PREFIX = "google/umt5-xxl"
UMT5_XXL_TOKENIZER_FILES = [
    f"{_UMT5_XXL_TOKENIZER_PREFIX}/{name}"
    for name in (
        "special_tokens_map.json",
        "spiece.model",
        "tokenizer.json",
        "tokenizer_config.json",
    )
]

# ReCamMaster fine-tuned checkpoint: Hub repository, file name, local directory.
RECAMMASTER_REPO_ID = "KwaiVGI/ReCamMaster-Wan2.1"
RECAMMASTER_CHECKPOINT_FILE = "step20000.ckpt"
RECAMMASTER_LOCAL_DIR = f"{MODELS_ROOT_DIR}/ReCamMaster/checkpoints"
class ModelLoader:
    """Download and load the Wan2.1 base models and the ReCamMaster checkpoint.

    State set by :meth:`load_models`:

    * ``model_manager`` -- ``ModelManager`` that loads the Wan2.1 weight files.
    * ``pipe`` -- ``WanVideoReCamMasterPipeline`` built from the manager.
    * ``is_loaded`` -- ``True`` only after loading completed successfully.

    Every ``progress_callback`` argument is optional. When given, it is called
    as ``progress_callback(fraction, desc=message)``.
    """

    def __init__(self):
        self.model_manager = None
        self.pipe = None
        self.is_loaded = False

    @staticmethod
    def _report(progress_callback, fraction, desc):
        """Forward one progress update if a callback was supplied."""
        # Compare against None rather than truth-testing: evaluating the
        # truthiness of an arbitrary callable object may invoke its
        # ``__bool__``/``__len__``.
        if progress_callback is not None:
            progress_callback(fraction, desc=desc)

    @staticmethod
    def _scale_progress(progress_callback, start, end):
        """Wrap a callback so that a 0..1 fraction lands in ``[start, end]``.

        The download helpers report their own 0..1 progress. Wrapping the
        callback lets :meth:`load_models` give each step a slice of the overall
        bar instead of letting it jump to 100% and back.

        Returns ``None`` when ``progress_callback`` is ``None``.
        """
        if progress_callback is None:
            return None

        span = end - start

        def scaled(fraction, *args, **kwargs):
            return progress_callback(start + span * fraction, *args, **kwargs)

        return scaled

    def download_umt5_xxl_tokenizer(self, progress_callback=None):
        """Download UMT5-XXL tokenizer files from HuggingFace.

        Files already present under ``WAN21_LOCAL_DIR`` are not downloaded
        again.

        Returns:
            List of local paths, one per entry of ``UMT5_XXL_TOKENIZER_FILES``.

        Raises:
            Exception: whatever ``hf_hub_download`` raised, after logging it.
        """
        total_files = len(UMT5_XXL_TOKENIZER_FILES)
        downloaded_paths = []

        for i, file_path in enumerate(UMT5_XXL_TOKENIZER_FILES):
            local_dir = f"{WAN21_LOCAL_DIR}/{os.path.dirname(file_path)}"
            filename = os.path.basename(file_path)
            full_local_path = f"{WAN21_LOCAL_DIR}/{file_path}"

            self._report(
                progress_callback,
                i / total_files,
                f"Checking tokenizer file {i+1}/{total_files}: {filename}",
            )

            if os.path.exists(full_local_path):
                logger.info(f"✓ Tokenizer file {filename} already exists at {full_local_path}")
                downloaded_paths.append(full_local_path)
                continue

            os.makedirs(local_dir, exist_ok=True)

            logger.info(f"Downloading tokenizer file {filename} from {WAN21_REPO_ID}/{file_path}...")
            self._report(
                progress_callback,
                i / total_files,
                f"Downloading tokenizer file {i+1}/{total_files}: {filename}",
            )

            try:
                downloaded_path = hf_hub_download(
                    repo_id=WAN21_REPO_ID,
                    filename=file_path,
                    local_dir=WAN21_LOCAL_DIR,
                    local_dir_use_symlinks=False,
                )
            except Exception as e:
                logger.error(f"✗ Error downloading tokenizer file {filename}: {e}")
                raise

            logger.info(f"✓ Successfully downloaded tokenizer file {filename} to {downloaded_path}!")
            downloaded_paths.append(downloaded_path)

        self._report(progress_callback, 1.0, "All tokenizer files downloaded successfully!")
        return downloaded_paths

    def download_wan21_models(self, progress_callback=None):
        """Download Wan2.1 model files from HuggingFace.

        Files already present under ``WAN21_LOCAL_DIR`` are not downloaded
        again.

        Returns:
            List of local paths (``str``), one per entry of ``WAN21_FILES``.

        Raises:
            Exception: whatever ``hf_hub_download`` raised, after logging it.
        """
        total_files = len(WAN21_FILES)
        downloaded_paths = []

        Path(WAN21_LOCAL_DIR).mkdir(parents=True, exist_ok=True)

        for i, filename in enumerate(WAN21_FILES):
            local_path = Path(WAN21_LOCAL_DIR) / filename

            self._report(
                progress_callback,
                i / total_files,
                f"Checking Wan2.1 file {i+1}/{total_files}: {filename}",
            )

            if local_path.exists():
                logger.info(f"✓ {filename} already exists at {local_path}")
                downloaded_paths.append(str(local_path))
                continue

            logger.info(f"Downloading {filename} from {WAN21_REPO_ID}...")
            self._report(
                progress_callback,
                i / total_files,
                f"Downloading Wan2.1 file {i+1}/{total_files}: {filename}",
            )

            try:
                downloaded_path = hf_hub_download(
                    repo_id=WAN21_REPO_ID,
                    filename=filename,
                    local_dir=WAN21_LOCAL_DIR,
                    local_dir_use_symlinks=False,
                )
            except Exception as e:
                logger.error(f"✗ Error downloading {filename}: {e}")
                raise

            logger.info(f"✓ Successfully downloaded {filename} to {downloaded_path}!")
            downloaded_paths.append(downloaded_path)

        self._report(progress_callback, 1.0, "All Wan2.1 models downloaded successfully!")
        return downloaded_paths

    def download_recammaster_checkpoint(self, progress_callback=None):
        """Download ReCamMaster checkpoint from HuggingFace using huggingface_hub.

        Returns:
            Local path of the checkpoint as a ``str``, both when the file was
            already present and when it has just been downloaded.

        Raises:
            Exception: whatever ``hf_hub_download`` raised, after logging it.
        """
        checkpoint_path = Path(RECAMMASTER_LOCAL_DIR) / RECAMMASTER_CHECKPOINT_FILE

        if checkpoint_path.exists():
            logger.info(f"✓ ReCamMaster checkpoint already exists at {checkpoint_path}")
            # Return a str so both code paths hand back the same type.
            return str(checkpoint_path)

        Path(RECAMMASTER_LOCAL_DIR).mkdir(parents=True, exist_ok=True)

        logger.info("Downloading ReCamMaster checkpoint from HuggingFace...")
        logger.info(f"Repository: {RECAMMASTER_REPO_ID}")
        logger.info(f"File: {RECAMMASTER_CHECKPOINT_FILE}")
        logger.info(f"Destination: {checkpoint_path}")

        self._report(progress_callback, 0.0, "Downloading ReCamMaster checkpoint...")

        try:
            downloaded_path = hf_hub_download(
                repo_id=RECAMMASTER_REPO_ID,
                filename=RECAMMASTER_CHECKPOINT_FILE,
                local_dir=RECAMMASTER_LOCAL_DIR,
                local_dir_use_symlinks=False,
            )
            logger.info(f"✓ Successfully downloaded ReCamMaster checkpoint to {downloaded_path}!")
            self._report(progress_callback, 1.0, "ReCamMaster checkpoint downloaded successfully!")
            return downloaded_path
        except Exception as e:
            logger.error(f"✗ Error downloading checkpoint: {e}")
            raise

    def create_symlink_for_tokenizer(self):
        """Create symlink for google/umt5-xxl to handle potential path issues.

        Links ``<MODELS_ROOT_DIR>/google/umt5-xxl`` to the tokenizer directory
        inside ``WAN21_LOCAL_DIR``. This is best-effort: any failure is logged
        as a warning and never raised, so loading can still be attempted.
        """
        try:
            google_dir = f"{MODELS_ROOT_DIR}/google"
            os.makedirs(google_dir, exist_ok=True)

            umt5_xxl_symlink = f"{google_dir}/umt5-xxl"
            umt5_xxl_source = f"{WAN21_LOCAL_DIR}/google/umt5-xxl"

            if not os.path.exists(umt5_xxl_symlink) and os.path.exists(umt5_xxl_source):
                # os.symlink raises on failure on every platform, so the
                # success message below is only logged when the link exists.
                # target_is_directory only has an effect on Windows.
                os.symlink(umt5_xxl_source, umt5_xxl_symlink, target_is_directory=True)
                logger.info(f"Created symlink from {umt5_xxl_source} to {umt5_xxl_symlink}")
        except Exception as e:
            # This is a warning, not an error, as we'll try to proceed anyway.
            logger.warning(f"Could not create symlink for google/umt5-xxl: {str(e)}")

    def load_models(self, progress_callback=None):
        """Load the ReCamMaster models.

        Steps, in order: create the test data structure, download the
        ReCamMaster checkpoint, the Wan2.1 weights and the tokenizer files if
        they are missing, load the Wan2.1 models, build the pipeline, attach
        the extra ReCamMaster modules to every DiT block and load the
        checkpoint into the DiT.

        Returns:
            A status message. Errors are reported through the returned string
            (it starts with ``"Error"``) and are not raised.
        """
        if self.is_loaded:
            return "Models already loaded!"

        try:
            logger.info("Starting model loading...")

            # Imported here, as in the original, so that a missing test_data
            # module is reported through the returned error message.
            from test_data import create_test_data_structure

            self._report(progress_callback, 0.05, "Setting up test data structure...")
            try:
                create_test_data_structure(progress_callback)
            except Exception as e:
                error_msg = f"Error creating test data structure: {str(e)}"
                logger.error(error_msg)
                return error_msg

            # Each download helper reports 0..1 on its own; scale that into the
            # slice of the overall progress reserved for the step.
            self._report(progress_callback, 0.1, "Checking for ReCamMaster checkpoint...")
            try:
                ckpt_path = self.download_recammaster_checkpoint(
                    self._scale_progress(progress_callback, 0.1, 0.2)
                )
                logger.info(f"Using checkpoint at {ckpt_path}")
            except Exception as e:
                error_msg = f"Error downloading ReCamMaster checkpoint: {str(e)}"
                logger.error(error_msg)
                return error_msg

            self._report(progress_callback, 0.2, "Checking for Wan2.1 models...")
            try:
                wan21_paths = self.download_wan21_models(
                    self._scale_progress(progress_callback, 0.2, 0.3)
                )
                logger.info(f"Using Wan2.1 models: {wan21_paths}")
            except Exception as e:
                error_msg = f"Error downloading Wan2.1 models: {str(e)}"
                logger.error(error_msg)
                return error_msg

            self._report(progress_callback, 0.3, "Checking for UMT5-XXL tokenizer files...")
            try:
                tokenizer_paths = self.download_umt5_xxl_tokenizer(
                    self._scale_progress(progress_callback, 0.3, 0.4)
                )
                logger.info(f"Using UMT5-XXL tokenizer files: {tokenizer_paths}")
            except Exception as e:
                error_msg = f"Error downloading UMT5-XXL tokenizer files: {str(e)}"
                logger.error(error_msg)
                return error_msg

            self._report(progress_callback, 0.4, "Loading model manager...")

            self.create_symlink_for_tokenizer()

            # Weights are first loaded on the CPU in bfloat16.
            self.model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")

            self._report(progress_callback, 0.5, "Loading Wan2.1 models...")

            model_files = [f"{WAN21_LOCAL_DIR}/{filename}" for filename in WAN21_FILES]
            for model_file in model_files:
                logger.info(f"Loading model from: {model_file}")
                if not os.path.exists(model_file):
                    error_msg = f"Error: Model file not found: {model_file}"
                    logger.error(error_msg)
                    return error_msg

            # Set environment variable for transformers to find the tokenizer.
            # NOTE(review): newer transformers releases appear to prefer
            # HF_HOME over TRANSFORMERS_CACHE -- confirm for the pinned version.
            os.environ["TRANSFORMERS_CACHE"] = MODELS_ROOT_DIR
            # Disable tokenizers parallelism warning.
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

            self.model_manager.load_models(model_files)

            self._report(progress_callback, 0.7, "Creating pipeline...")
            self.pipe = WanVideoReCamMasterPipeline.from_model_manager(
                self.model_manager, device="cuda"
            )

            self._report(progress_callback, 0.8, "Initializing ReCamMaster modules...")

            # Initialize additional modules introduced in ReCamMaster. The
            # camera encoder starts at zero and the projector as identity;
            # the checkpoint loaded below overwrites both.
            dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
            for block in self.pipe.dit.blocks:
                block.cam_encoder = nn.Linear(12, dim)
                block.projector = nn.Linear(dim, dim)
                block.cam_encoder.weight.data.zero_()
                block.cam_encoder.bias.data.zero_()
                block.projector.weight = nn.Parameter(torch.eye(dim))
                block.projector.bias = nn.Parameter(torch.zeros(dim))

            self._report(progress_callback, 0.9, "Loading ReCamMaster checkpoint...")

            if not os.path.exists(ckpt_path):
                error_msg = f"Error: ReCamMaster checkpoint not found at {ckpt_path} even after download attempt."
                logger.error(error_msg)
                return error_msg

            # NOTE(security): torch.load unpickles the file. Depending on the
            # installed PyTorch version this can execute code embedded in the
            # checkpoint, so only load checkpoints from a trusted repository.
            # Consider weights_only=True once the checkpoint is confirmed to
            # contain tensors only.
            state_dict = torch.load(ckpt_path, map_location="cpu")
            # strict=True: the extra modules attached above must be covered by
            # the checkpoint keys, otherwise loading fails.
            self.pipe.dit.load_state_dict(state_dict, strict=True)

            self.pipe.to("cuda")
            self.pipe.to(dtype=torch.bfloat16)

            self.is_loaded = True

            self._report(progress_callback, 1.0, "Models loaded successfully!")
            logger.info("Models loaded successfully!")
            return "Models loaded successfully!"

        except Exception as e:
            # logger.exception keeps the traceback; the message is unchanged.
            logger.exception(f"Error loading models: {str(e)}")
            return f"Error loading models: {str(e)}"