"""Gradio demo for text-to-video generation via the Hugging Face InferenceClient.

Supports multiple inference providers (hf-inference, fal-ai, novita, replicate)
and lets users bring their own provider API key (BYOK).
"""

import json
import os
import tempfile
import uuid

import gradio as gr
from huggingface_hub import InferenceClient

# Access token from environment variable
ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")


def generate_video(
    prompt,
    negative_prompt,
    num_frames,
    fps,
    width,
    height,
    num_inference_steps,
    guidance_scale,
    motion_bucket_id,
    seed,
    provider,
    custom_api_key,
    custom_model,
    model_search_term,
    selected_model,
):
    """Generate a video based on the provided parameters."""
    print(f"Received prompt: {prompt}")
    print(f"Negative prompt: {negative_prompt}")
    print(f"Num frames: {num_frames}, FPS: {fps}")
    print(f"Width: {width}, Height: {height}")
    print(f"Steps: {num_inference_steps}, Guidance Scale: {guidance_scale}")
    print(f"Motion Bucket ID: {motion_bucket_id}, Seed: {seed}")
    print(f"Selected provider: {provider}")
    print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
    print(f"Selected model (custom_model): {custom_model}")
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")

    # Use the custom API key if provided, otherwise fall back to ACCESS_TOKEN.
    token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN

    # Log which token source is in use (without printing the actual token).
    if custom_api_key.strip() != "":
        print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
    else:
        print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")

    # Initialize the Inference Client with the provider and appropriate token.
    client = InferenceClient(token=token_to_use, provider=provider)
    print(f"Hugging Face Inference Client initialized with {provider} provider.")

    # A seed of -1 means "random"; otherwise make sure it is an integer.
    if seed == -1:
        seed = None
    else:
        seed = int(seed)

    # Determine which model to use, prioritizing custom_model if provided.
    model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
    print(f"Model selected for inference: {model_to_use}")

    # Create a unique ID for this generation.
    generation_id = uuid.uuid4().hex[:8]
    print(f"Generation ID: {generation_id}")

    # Parameters each provider is known to accept for text-to-video.
    provider_param_support = {
        "hf-inference": {
            "supported": ["prompt", "model", "negative_prompt", "num_frames",
                          "num_inference_steps", "guidance_scale", "seed"],
            "extra_info": "HF Inference doesn't support 'fps', 'width', 'height', or 'motion_bucket_id' parameters",
        },
        "fal-ai": {
            "supported": ["prompt", "model", "negative_prompt", "num_frames",
                          "num_inference_steps", "guidance_scale", "seed"],
            "extra_info": "Fal-AI doesn't support 'fps', 'width', 'height', or 'motion_bucket_id' parameters",
        },
        "novita": {
            "supported": ["prompt", "model", "negative_prompt", "num_frames",
                          "num_inference_steps", "guidance_scale", "seed",
                          "fps", "width", "height"],
            "extra_info": "Novita may not support 'motion_bucket_id' parameter",
        },
        "replicate": {
            "supported": ["prompt", "model", "negative_prompt", "num_frames",
                          "num_inference_steps", "guidance_scale", "seed",
                          "fps", "width", "height"],
            "extra_info": "Replicate parameters vary by specific model",
        },
    }

    # Get supported parameters for the current provider.
    supported_params = provider_param_support.get(provider, {}).get("supported", [])
    provider_info = provider_param_support.get(provider, {}).get("extra_info", "No specific information available")

    print(f"Provider info: {provider_info}")
    print(f"Supported parameters: {supported_params}")
parameters: {supported_params}") # Create a parameters dictionary with only supported parameters parameters = {} if "negative_prompt" in supported_params: parameters["negative_prompt"] = negative_prompt if "num_frames" in supported_params: parameters["num_frames"] = num_frames if "num_inference_steps" in supported_params: parameters["num_inference_steps"] = num_inference_steps if "guidance_scale" in supported_params: parameters["guidance_scale"] = guidance_scale if "seed" in supported_params and seed is not None: parameters["seed"] = seed if "fps" in supported_params: parameters["fps"] = fps if "width" in supported_params: parameters["width"] = width if "height" in supported_params: parameters["height"] = height if "motion_bucket_id" in supported_params: parameters["motion_bucket_id"] = motion_bucket_id # Now that we have a clean parameter set, handle provider-specific logic print(f"Final parameters for {provider}: {parameters}") # For Replicate provider - uses post method if provider == "replicate": print("Using Replicate provider, using post method...") try: response = client.post( model=model_to_use, input={ "prompt": prompt, **parameters }, ) # Replicate typically returns a URL to the generated video if isinstance(response, dict) and "output" in response: video_url = response["output"] print(f"Video generated, URL: {video_url}") return video_url else: return str(response) except Exception as e: print(f"Error during Replicate video generation: {e}") return f"Error: {str(e)}" # For all other providers, use the standard text_to_video method try: print(f"Sending request to {provider} provider with model {model_to_use}.") # Use the text_to_video method of the InferenceClient with only supported parameters video_data = client.text_to_video( prompt=prompt, model=model_to_use, **parameters ) # Save the video to a temporary file temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") temp_file.write(video_data) video_path = temp_file.name temp_file.close() print(f"Video saved to temporary file: {video_path}") return video_path except Exception as e: print(f"Error during video generation: {e}") return f"Error: {str(e)}" # Function to validate provider selection based on BYOK def validate_provider(api_key, provider): # If no custom API key is provided, only "hf-inference" can be used if not api_key.strip() and provider != "hf-inference": return gr.update(value="hf-inference") return gr.update(value=provider) # Define the GRADIO UI with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo: with gr.Row(): with gr.Column(scale=2): # Main video output area video_output = gr.Video(label="Generated Video", height=400) # Basic input components prompt_box = gr.Textbox( value="A beautiful sunset over a calm ocean", placeholder="Enter a prompt for your video", label="Prompt", lines=3 ) # Generate button generate_button = gr.Button("🎬 Generate Video", variant="primary") # Advanced settings in an accordion with gr.Accordion("Advanced Settings", open=False): with gr.Row(): with gr.Column(): negative_prompt = gr.Textbox( label="Negative Prompt", placeholder="What should NOT be in the video", value="poor quality, distortion, blurry, low resolution, grainy", lines=2 ) with gr.Row(): width = gr.Slider( minimum=256, maximum=1024, value=512, step=64, label="Width" ) height = gr.Slider( minimum=256, maximum=1024, value=512, step=64, label="Height" ) with gr.Row(): num_frames = gr.Slider( minimum=8, maximum=64, value=16, step=1, label="Number of Frames" ) fps = gr.Slider( minimum=1, maximum=30, value=8, step=1, 
label="Frames Per Second" ) # Adding the sliders from the right column to the left column with gr.Row(): num_inference_steps = gr.Slider( minimum=1, maximum=100, value=25, step=1, label="Inference Steps" ) guidance_scale = gr.Slider( minimum=1.0, maximum=20.0, value=7.5, step=0.5, label="Guidance Scale" ) with gr.Row(): motion_bucket_id = gr.Slider( minimum=1, maximum=255, value=127, step=1, label="Motion Bucket ID (for SVD models)" ) seed = gr.Slider( minimum=-1, maximum=2147483647, value=-1, step=1, label="Seed (-1 for random)" ) with gr.Column(): # Provider selection providers_list = [ "hf-inference", # Default Hugging Face Inference "fal-ai", # Fal AI provider "novita", # Novita provider "replicate", # Replicate provider ] provider_radio = gr.Radio( choices=providers_list, value="hf-inference", label="Inference Provider", info="Select an inference provider. Note: Requires provider-specific API key except for hf-inference" ) # BYOK textbox byok_textbox = gr.Textbox( value="", label="BYOK (Bring Your Own Key)", info="Enter a provider API key here. When empty, only 'hf-inference' provider can be used.", placeholder="Enter your provider API token", type="password" # Hide the API key for security ) # Model selection components (moved from left column) model_search_box = gr.Textbox( label="Filter Models", placeholder="Search for a model...", lines=1 ) models_list = [ "Lightricks/LTX-Video", "Wan-AI/Wan2.1-T2V-14B", "tencent/HunyuanVideo", "Wan-AI/Wan2.1-T2V-1.3B", "genmo/mochi-1-preview", "THUDM/CogVideoX-5b" ] featured_model_radio = gr.Radio( label="Select a model below", choices=models_list, value="Lightricks/LTX-Video", interactive=True ) custom_model_box = gr.Textbox( value="", label="Custom Model", info="(Optional) Provide a custom Hugging Face model path. Overrides any selected featured model.", placeholder="damo-vilab/text-to-video-ms-1.7b" ) gr.Markdown("[See all available models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-to-video&sort=trending)") # Set up the generation click event generate_button.click( fn=generate_video, inputs=[ prompt_box, negative_prompt, num_frames, fps, width, height, num_inference_steps, guidance_scale, motion_bucket_id, seed, provider_radio, byok_textbox, custom_model_box, model_search_box, featured_model_radio ], outputs=video_output ) # Connect the model filter to update the radio choices def filter_models(search_term): print(f"Filtering models with search term: {search_term}") filtered = [m for m in models_list if search_term.lower() in m.lower()] print(f"Filtered models: {filtered}") return gr.update(choices=filtered) model_search_box.change( fn=filter_models, inputs=model_search_box, outputs=featured_model_radio ) # Connect the featured model radio to update the custom model box def set_custom_model_from_radio(selected): """ This function will get triggered whenever someone picks a model from the 'Featured Models' radio. We will update the Custom Model text box with that selection automatically. 
""" print(f"Featured model selected: {selected}") return selected featured_model_radio.change( fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box ) # Connect the BYOK textbox to validate provider selection byok_textbox.change( fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio ) # Also validate provider when the radio changes to ensure consistency provider_radio.change( fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio ) # Launch the app if __name__ == "__main__": print("Launching the demo application.") demo.launch(show_api=True)