"""Download sharded .safetensors checkpoints from the Hugging Face Hub.

For every entry in MODELS_TO_DOWNLOAD the script fetches the files listed in
COMMON_FILES plus ``num_shards`` weight shards named
``model-XXXXX-of-YYYYY.safetensors`` into the entry's ``local_base_path``.
"""

import os

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError  # More specific import path
from tqdm import tqdm  # For progress bars

# --- Configuration ---
MODELS_TO_DOWNLOAD = [
    {
        "repo_id": "Skywork/SkyReels-V2-DF-14B-540P",
        "local_base_path": "F:/Models/SkyReels-V2-DF-14B-540P",  # Base path for this model
        "num_shards": 12,
    },
    {
        "repo_id": "Skywork/SkyReels-V2-T2V-14B-540P",
        "local_base_path": "F:/Models/SkyReels-V2-T2V-14B-540P",  # Base path for this model
        "num_shards": 12,
    },
]

# Common files to download in addition to the shards.
# Only the shard index is fetched for now. If other files are needed to load
# the model later (for example config.json, generation_config.json,
# special_tokens_map.json, tokenizer.json, tokenizer_config.json, vocab.json),
# add their names to this list.
COMMON_FILES = [
    "model.safetensors.index.json",
]


def _http_status(error):
    """Return the HTTP status code attached to *error*, or None if there is none."""
    # The error may have been raised without a response object, so never
    # dereference ``error.response`` unconditionally.
    return getattr(getattr(error, "response", None), "status_code", None)


def _fetch_file(repo_id, filename, local_base_path):
    """Download one file from *repo_id* into *local_base_path*."""
    # NOTE(review): recent huggingface_hub releases appear to deprecate
    # ``local_dir_use_symlinks`` and ``resume_download`` -- confirm against the
    # installed version before removing them.
    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=local_base_path,
        local_dir_use_symlinks=False,  # Download the actual file
        resume_download=True,  # Good for large files
    )


def download_model_files(repo_id: str, local_base_path: str, num_shards: int) -> bool:
    """Download the common files and all weight shards of one repository.

    Args:
        repo_id: Hugging Face repository identifier.
        local_base_path: Directory the files are written to; created if it
            does not exist.
        num_shards: Total number of shards; shard ``i`` is requested as
            ``model-{i:05d}-of-{num_shards:05d}.safetensors``.

    Returns:
        True when every shard was downloaded and no common file failed for a
        reason other than being absent from the repository (HTTP 404).
        False otherwise. Shard downloading stops at the first failing shard.
    """
    print(f"\nDownloading files for repository: {repo_id}")
    print(f"Target local directory: {local_base_path}")

    # Create the local directory if it doesn't exist
    os.makedirs(local_base_path, exist_ok=True)

    # --- Download common files ---
    # A common file missing from the repository is only a warning; any other
    # failure is remembered so the caller is not told everything succeeded.
    common_files_ok = True
    for common_file in COMMON_FILES:
        print(f"Attempting to download: {common_file}...")
        try:
            _fetch_file(repo_id, common_file, local_base_path)
        except HfHubHTTPError as e:
            if _http_status(e) == 404:
                print(f"Warning: {common_file} not found in repository {repo_id}. Skipping.")
            else:
                print(f"Error downloading {common_file}: {e}")
                common_files_ok = False
        except Exception as e:
            print(f"An unexpected error occurred while downloading {common_file}: {e}")
            common_files_ok = False
        else:
            print(f"Successfully downloaded {common_file}")

    # --- Download sharded model files ---
    # Filename format: model-00001-of-00012.safetensors
    shard_filenames = [
        f"model-{i:05d}-of-{num_shards:05d}.safetensors"
        for i in range(1, num_shards + 1)
    ]

    print(f"\nAttempting to download {num_shards} model shards...")
    for shard_filename in tqdm(shard_filenames, desc=f"Downloading shards for {repo_id}"):
        try:
            _fetch_file(repo_id, shard_filename, local_base_path)
        except HfHubHTTPError as e:
            print(f"Error downloading {shard_filename}: {e}")
            if _http_status(e) == 404:
                print(f"  {shard_filename} not found. Please check repository and shard count.")
            return False  # Stop if a shard download fails
        except Exception as e:
            print(f"An unexpected error occurred while downloading {shard_filename}: {e}")
            return False

    # Reaching this point means no shard failed; missing shards abort above
    # and are never skipped.
    print(f"All {num_shards} shards for {repo_id} downloaded successfully.")
    return common_files_ok


def main():
    """Download every configured model and print an overall summary."""
    print("Starting model download process...")
    all_successful = True
    for model_config in MODELS_TO_DOWNLOAD:
        success = download_model_files(
            repo_id=model_config["repo_id"],
            local_base_path=model_config["local_base_path"],
            num_shards=model_config["num_shards"],
        )
        if not success:
            all_successful = False
            print(f"Failed to download all files for {model_config['repo_id']}.")

    if all_successful:
        print("\nAll specified model files downloaded successfully.")
    else:
        print("\nSome model files failed to download. Please check the logs.")


if __name__ == "__main__":
    main()