Does this model not even fit on an NVIDIA B200?

#14 opened by Fredtt3

Hi everyone, I'm trying to run the model and I keep getting OOM errors. This is the code I'm using (I'm running it on Modal for testing):

import modal
from platformdirs import user_data_dir

image = (
    modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.12")
    .apt_install("git", "curl", "build-essential", "wget", "libgl1", "libglib2.0-0")
    .entrypoint([])
    .run_commands(
        "python -m pip install --upgrade pip",
        "python -m pip install --upgrade setuptools wheel",
    )
    .uv_pip_install(
        "torch==2.8",
        "git+https://github.com/huggingface/diffusers.git",
        "transformers==4.57.3",
        "tokenizers==0.22.1",
        "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.3.14/flash_attn-2.8.2+cu128torch2.8-cp312-cp312-linux_x86_64.whl",
        "bitsandbytes==0.49.0",
        "accelerate==1.12.0",
        "git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-core",
        "git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-pipelines",
        "platformdirs==4.5.1"
    )
    .env({"HF_XET_HIGH_PERFORMANCE": "1"})
)

MINUTES = 60

hf_cache_vol = modal.Volume.from_name("hf-cache", create_if_missing=True)
aquiles_config_vol = modal.Volume.from_name("aquiles-config", create_if_missing=True)
aquiles_video_vol = modal.Volume.from_name("aquiles-video-cache", create_if_missing=True)
data_dir = user_data_dir("aquiles", "Aquiles-Image")

app = modal.App("ltx2")

with image.imports():
    from huggingface_hub import snapshot_download, hf_hub_download
    from ltx_pipelines.ti2vid_one_stage import TI2VidOneStagePipeline
    from ltx_pipelines.utils.media_io import encode_video
    from ltx_pipelines.utils.constants import AUDIO_SAMPLE_RATE
    import torch



@app.cls(
    image=image,
    secrets=[modal.Secret.from_name("huggingface-secret")],
    gpu="B200",
    timeout=15 * MINUTES,
    scaledown_window=15 * MINUTES,
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.local/share": aquiles_config_vol,
    },
)
class LTX2Test:

    @modal.enter()
    def load_pipeline(self):
        #print("download text encoder:\n")
        #snapshot_download("google/gemma-3-12b-it-qat-q4_0-unquantized", local_dir=f"{data_dir}/gemma")
        #print("download model:\n")
        #hf_hub_download("Lightricks/LTX-2", "ltx-2-19b-dev.safetensors", local_dir=f"{data_dir}/ltx")

        self.pipeline = TI2VidOneStagePipeline(
            checkpoint_path=f"{data_dir}/ltx/ltx-2-19b-dev.safetensors",
            gemma_root=f"{data_dir}/gemma",
            loras=[],
        )

        print(f" vram_allocated={torch.cuda.memory_allocated() / 1024**3:.2f}GB vram_reserved={torch.cuda.memory_reserved() / 1024**3:.2f}GB vram_total={torch.cuda.get_device_properties().total_memory / 1024**3:.2f}GB")


    

    @modal.method()
    def generate_video(self, prompt: str):

        video, audio = self.pipeline(
            prompt=prompt,
            negative_prompt="",
            seed=42,
            height=512,
            width=768,
            num_frames=121,
            frame_rate=25.0,
            num_inference_steps=40,
            cfg_guidance_scale=3.0,
            images=""
        )

        print(f" vram_allocated={torch.cuda.memory_allocated() / 1024**3:.2f}GB vram_reserved={torch.cuda.memory_reserved() / 1024**3:.2f}GB vram_total={torch.cuda.get_device_properties().total_memory / 1024**3:.2f}GB")

        output = f"{data_dir}/video/output.mp4"
        
        encode_video(
            video=video,
            fps=25.0,
            audio=audio,
            audio_sample_rate=AUDIO_SAMPLE_RATE,
            output_path=output,
            video_chunks_number=1,
        )

        print(f" vram_allocated={torch.cuda.memory_allocated() / 1024**3:.2f}GB vram_reserved={torch.cuda.memory_reserved() / 1024**3:.2f}GB vram_total={torch.cuda.get_device_properties().total_memory / 1024**3:.2f}GB")

        print(f"Saved video in: {output}")

        return output




@app.local_entrypoint()
def entrypoint():
    print("Ltx-2-test")

    prompt = """Intent: wildlife photography print. Background: blurred tropical foliage with soft green bokeh. Foreground: small branch with dew drops. Hero subject: chameleon in profile with vibrant scales and detailed eye texture. Finishing details: photorealistic, crisp scale detail, no logos or trademarks, no watermark. Camera: 100mm macro, shallow depth of field."""

    ltx2 = LTX2Test()

    ltx2.generate_video.remote(prompt=prompt)

#snapshot_download("google/gemma-3-12b-it-qat-q4_0-unquantized", local_dir=f"{data_dir}/gemma")

That text encoder is a VRAM monster! Try this one instead: https://huggingface.co/unsloth/gemma-3-12b-it-bnb-4bit/tree/main
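A minimal sketch of that swap, assuming TI2VidOneStagePipeline accepts any local Gemma-3 directory via gemma_root (the repo id is the 4-bit one linked above; everything else mirrors the code from the original post):

from huggingface_hub import snapshot_download

# Download the 4-bit text encoder into the same directory the pipeline already reads from.
snapshot_download("unsloth/gemma-3-12b-it-bnb-4bit", local_dir=f"{data_dir}/gemma")

# Assumption: gemma_root only needs a local Gemma-3 directory, so the rest is unchanged.
pipeline = TI2VidOneStagePipeline(
    checkpoint_path=f"{data_dir}/ltx/ltx-2-19b-dev.safetensors",
    gemma_root=f"{data_dir}/gemma",
    loras=[],
)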

You are computing gradients, that's why. You need to use torch.inference_mode() (or set torch.set_grad_enabled(False) globally).
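For reference, a minimal sketch of what that looks like in generate_video; the arguments are exactly the ones from the original post, only the context manager is new:

    @modal.method()
    def generate_video(self, prompt: str):
        # inference_mode() disables autograd bookkeeping, so the activation graph
        # is never built and the gradient-related VRAM overhead goes away.
        with torch.inference_mode():
            video, audio = self.pipeline(
                prompt=prompt,
                negative_prompt="",
                seed=42,
                height=512,
                width=768,
                num_frames=121,
                frame_rate=25.0,
                num_inference_steps=40,
                cfg_guidance_scale=3.0,
                images="",
            )
        # encode_video(...) and the rest of the method stay exactly as before.

Calling torch.set_grad_enabled(False) once at import time works too, but the context manager keeps the change local to the sampling call.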

Thank You!!!! You're the best

Fredtt3 changed discussion status to closed
