Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
import os
from os.path import join, realpath
import sys

# `typing.override` only exists on Python >= 3.12 (PEP 698); on older
# interpreters fall back to a no-op decorator so the rest of the file
# can use @override unconditionally.
try:
    from typing import override
except ImportError:
    def override(f):
        return f

from loguru import logger as log
import numpy as np

from multi_gpu import MultiGPUInferenceAR
from server_base import ROOT_DIR
from server_cosmos_base import CosmosBaseModel

# Root of the Cosmos-Predict-1 checkout. CosmosModel.__init__ asserts that
# a `cosmos_predict1` package directory exists directly under this path.
COSMOS_PREDICT1_ROOT = ROOT_DIR
# Defaults for the torchrun rendezvous endpoint.
# NOTE(review): not referenced in this file — presumably consumed by the
# multi-GPU launcher (multi_gpu.MultiGPUInferenceAR); confirm before removing.
TORCHRUN_DEFAULT_MASTER_ADDR = 'localhost'
TORCHRUN_DEFAULT_MASTER_PORT = 12355
def add_cosmos_venv_to_path():
    """Make the Cosmos-Predict-1 code importable from this process.

    Appends the repository root and its ``cosmos_predict1`` package
    directory to ``sys.path``, skipping entries that are already present
    so repeated calls are idempotent.
    """
    # Fix: the original computed a `version_string` f-string from
    # sys.version_info that was never used; dead local removed.
    extras = [
        COSMOS_PREDICT1_ROOT,
        join(COSMOS_PREDICT1_ROOT, "cosmos_predict1"),
    ]
    for entry in extras:
        if entry not in sys.path:
            sys.path.append(entry)
class CosmosModel(CosmosBaseModel):
    """
    Serves frames generated on-the-fly by the Cosmos generative model.
    Intended for use with the Cosmos-Predict-1 based Gen3C model.
    """
    def __init__(self, gpu_count: int = 0, **kwargs):
        """Load the Gen3C (Cosmos-Predict-1) inference model.

        Args:
            gpu_count: Number of GPUs to run inference on. 0 (the default)
                means "use every GPU visible on this machine". With exactly
                one GPU the model is loaded in-process; otherwise a
                multi-GPU wrapper is used.
            **kwargs: Forwarded unchanged to ``CosmosBaseModel.__init__``.
        """
        # Must run before the lazy `cosmos_predict1` import below can resolve.
        add_cosmos_venv_to_path()
        # Keep HuggingFace downloads inside the checkout unless the user
        # already configured a cache location.
        if not os.environ.get("HF_HOME"):
            os.environ["HF_HOME"] = join(COSMOS_PREDICT1_ROOT, "huggingface_home")
        super().__init__(**kwargs)

        assert os.path.isdir(join(COSMOS_PREDICT1_ROOT, "cosmos_predict1")), \
            f"Could not find Cosmos (cosmos_predict1) directory at: {COSMOS_PREDICT1_ROOT}"
        # Imported lazily: these pull in torch/CUDA and depend on the
        # sys.path setup performed above.
        from cosmos_predict1.diffusion.inference.gen3c_persistent import Gen3cPersistentModel, create_parser
        import torch

        if gpu_count == 0:
            # Use as many GPUs for inference as are available on this machine.
            gpu_count = torch.cuda.device_count()

        # Note: we use the argparse-based interface so that all defaults are preserved.
        parser = create_parser()
        common_args = [
            "--checkpoint_dir", self.checkpoint_path or join(COSMOS_PREDICT1_ROOT, "checkpoints"),
            "--video_save_name=",  # Empty string
            "--video_save_folder", join(COSMOS_PREDICT1_ROOT, "outputs"),
            "--trajectory", "none",
            "--prompt=",  # Empty string
            "--negative_prompt=",  # Empty string
            "--offload_prompt_upsampler",
            "--disable_prompt_upsampler",
            "--disable_guardrail",
            "--num_gpus", str(gpu_count),
            "--guidance", "1.0",
            "--num_video_frames", "121",
            "--foreground_masking",
        ]
        args = parser.parse_args(common_args)

        if gpu_count == 1:
            self.model = Gen3cPersistentModel(args)
        else:
            log.info(f"Loading Cosmos-Predict1 inference model on {gpu_count} GPUs.")
            self.model = MultiGPUInferenceAR(gpu_count, cosmos_variant="predict1", args=args)

        # Since the model may require overlap of inference batches,
        # we save previous inference poses so that we can provide any number of
        # previous camera poses when starting the next inference batch.
        # TODO: ensure some kind of ordering?
        self.pose_history_w2c: list[np.ndarray] = []
        self.intrinsics_history: list[np.ndarray] = []
        # Fallback camera intrinsics when a request does not supply its own.
        # NOTE(review): focal length appears to be in pixels while the
        # principal point looks normalized to [0, 1] — confirm against the
        # base class's intrinsics convention.
        self.default_focal_length = (338.29, 338.29)
        self.default_principal_point = (0.5, 0.5)
        # Axis-aligned bounding box of the scene, in world units.
        self.aabb_min = np.array([-16, -16, -16])
        self.aabb_max = np.array([16, 16, 16])

    def inference_resolution(self) -> list[tuple[int, int]] | None:
        """The supported inference resolutions, or None if any resolution is supported."""
        return [(1280, 704),]

    def max_frames_per_request(self) -> int:
        """Upper bound on the number of frames a single request may generate."""
        # Not actually tested, but anyway we can expect autoregressive
        # generation to go wrong earlier than this.
        return self.model.frames_per_batch * 100

    def default_framerate(self) -> float:
        """Framerate (frames per second) assumed for generated videos."""
        return 24.0

    def cleanup(self):
        """Release inference resources.

        Only the multi-GPU wrapper exposes an explicit cleanup hook (it owns
        worker processes); the single-GPU model is left to normal garbage
        collection.
        """
        if isinstance(self.model, MultiGPUInferenceAR):
            self.model.cleanup()

    def metadata(self) -> dict:
        """Return the base-class metadata dict, tagged with this model's name."""
        result = super().metadata()
        result["model_name"] = "CosmosModel"
        return result
| if __name__ == "__main__": | |
| model = CosmosModel() | |