|
|
|
import os |
|
|
|
import click |
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
|
|
def _download_file(repo_id: str, output_dir: str, filename: str, description: str) -> None:
    """Fetch one weight file from a Hugging Face repo unless it already exists.

    Args:
        repo_id: Hub repository identifier passed to ``snapshot_download``.
        output_dir: Local directory the file is downloaded into.
        filename: Name of the file expected at ``output_dir/filename``.
        description: Human-readable label used in the progress messages.

    Raises:
        FileNotFoundError: If the file is still missing after the download step.
    """
    file_path = os.path.join(output_dir, filename)

    if os.path.exists(file_path):
        print(f"{description} already exists in: {file_path}")
    else:
        print(f"Downloading mochi {description} to: {file_path}")
        snapshot_download(
            repo_id=repo_id,
            # Restrict the snapshot to files whose path contains the filename.
            allow_patterns=[f"*{filename}*"],
            local_dir=output_dir,
            # NOTE(review): newer huggingface_hub releases appear to deprecate
            # this argument -- confirm against the pinned version before
            # removing it.
            local_dir_use_symlinks=False,
        )

    # Explicit check instead of `assert`: assertions are stripped when Python
    # runs with -O, which would let a failed download pass silently.
    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"Expected {description} weights at {file_path}, but the file is missing "
            f"after attempting to download it from {repo_id}"
        )


@click.command()
@click.argument('output_dir', required=True)
def download_weights(output_dir: str) -> None:
    """Download the Mochi model weights into OUTPUT_DIR.

    The decoder, encoder and DiT model files are fetched from the
    genmo/mochi-1-preview repository. Files that already exist in
    OUTPUT_DIR are skipped; the directory is created if it is missing.
    """
    repo_id = "genmo/mochi-1-preview"

    # (filename, description) pairs, processed in this order.
    weight_files = (
        ("decoder.safetensors", "decoder"),
        ("encoder.safetensors", "encoder"),
        ("dit.safetensors", "model"),
    )

    if not os.path.exists(output_dir):
        print(f"Creating output directory: {output_dir}")
        os.makedirs(output_dir, exist_ok=True)

    for filename, description in weight_files:
        _download_file(repo_id, output_dir, filename, description)
|
|
|
|
|
if __name__ == "__main__":
    # Run as a script: the click command parses sys.argv itself, so it is
    # invoked without arguments here.
    download_weights()
|
|