import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import tempfile
import uuid
# Access token from the environment; used whenever no BYOK key is supplied
ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded." if ACCESS_TOKEN else "Warning: HF_TOKEN is not set.")

def generate_video(
    prompt,
    negative_prompt,
    num_frames,
    fps,
    width,
    height,
    num_inference_steps,
    guidance_scale,
    motion_bucket_id,
    seed,
    provider,
    custom_api_key,
    custom_model,
    model_search_term,
    selected_model,
):
    """Generate a video based on the provided parameters."""
print(f"Received prompt: {prompt}") | |
print(f"Negative prompt: {negative_prompt}") | |
print(f"Num frames: {num_frames}, FPS: {fps}") | |
print(f"Width: {width}, Height: {height}") | |
print(f"Steps: {num_inference_steps}, Guidance Scale: {guidance_scale}") | |
print(f"Motion Bucket ID: {motion_bucket_id}, Seed: {seed}") | |
print(f"Selected provider: {provider}") | |
print(f"Custom API Key provided: {bool(custom_api_key.strip())}") | |
print(f"Selected model (custom_model): {custom_model}") | |
print(f"Model search term: {model_search_term}") | |
print(f"Selected model from radio: {selected_model}") | |
    # Determine which token to use: the custom API key if provided, otherwise ACCESS_TOKEN
    token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN

    # Log which token source is in use (without printing the actual token)
    if custom_api_key.strip() != "":
        print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
    else:
        print("USING DEFAULT API KEY: environment variable HF_TOKEN is being used for authentication")

    # Initialize the Inference Client with the chosen provider and token
    client = InferenceClient(token=token_to_use, provider=provider)
    print(f"Hugging Face Inference Client initialized with {provider} provider.")
    # A seed of -1 means "random": pass None so the provider picks one
    if seed == -1:
        seed = None
    else:
        seed = int(seed)  # ensure the seed is an integer

    # Determine which model to use, prioritizing custom_model if provided
    model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
    print(f"Model selected for inference: {model_to_use}")

    # Create a unique ID for this generation (used only for log correlation)
    generation_id = uuid.uuid4().hex[:8]
    print(f"Generation ID: {generation_id}")
    # Define supported parameters for each provider
    provider_param_support = {
        "hf-inference": {
            "supported": ["prompt", "model", "negative_prompt", "num_frames", "num_inference_steps", "guidance_scale", "seed"],
            "extra_info": "HF Inference doesn't support 'fps', 'width', 'height', or 'motion_bucket_id' parameters",
        },
        "fal-ai": {
            "supported": ["prompt", "model", "negative_prompt", "num_frames", "num_inference_steps", "guidance_scale", "seed"],
            "extra_info": "Fal-AI doesn't support 'fps', 'width', 'height', or 'motion_bucket_id' parameters",
        },
        "novita": {
            "supported": ["prompt", "model", "negative_prompt", "num_frames", "num_inference_steps", "guidance_scale", "seed", "fps", "width", "height"],
            "extra_info": "Novita may not support the 'motion_bucket_id' parameter",
        },
        "replicate": {
            "supported": ["prompt", "model", "negative_prompt", "num_frames", "num_inference_steps", "guidance_scale", "seed", "fps", "width", "height"],
            "extra_info": "Replicate parameters vary by specific model",
        },
    }

    # Get supported parameters for the current provider
    supported_params = provider_param_support.get(provider, {}).get("supported", [])
    provider_info = provider_param_support.get(provider, {}).get("extra_info", "No specific information available")
    print(f"Provider info: {provider_info}")
    print(f"Supported parameters: {supported_params}")
    # Build the parameters dictionary, keeping only provider-supported keys
    candidate_params = {
        "negative_prompt": negative_prompt,
        "num_frames": num_frames,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "fps": fps,
        "width": width,
        "height": height,
        "motion_bucket_id": motion_bucket_id,
    }
    parameters = {k: v for k, v in candidate_params.items() if k in supported_params}
    if "seed" in supported_params and seed is not None:
        parameters["seed"] = seed

    # Now that we have a clean parameter set, handle provider-specific logic
    print(f"Final parameters for {provider}: {parameters}")
    # The Replicate provider goes through the lower-level post() helper rather
    # than text_to_video(). Note: post() takes a `json` payload (an `input=`
    # keyword is not part of its signature) and returns raw bytes that must be
    # decoded as JSON.
    if provider == "replicate":
        print("Using Replicate provider via the post method...")
        try:
            response = client.post(
                model=model_to_use,
                json={"prompt": prompt, **parameters},
            )
            result = json.loads(response)
            # Replicate typically returns a URL to the generated video
            if isinstance(result, dict) and "output" in result:
                video_url = result["output"]
                print(f"Video generated, URL: {video_url}")
                return video_url
            return str(result)
        except Exception as e:
            print(f"Error during Replicate video generation: {e}")
            return f"Error: {str(e)}"
    # For all other providers, use the standard text_to_video method
    try:
        print(f"Sending request to {provider} provider with model {model_to_use}.")
        video_data = client.text_to_video(
            prompt=prompt,
            model=model_to_use,
            **parameters,
        )

        # text_to_video returns raw MP4 bytes; save them to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
            temp_file.write(video_data)
            video_path = temp_file.name
        print(f"Video saved to temporary file: {video_path}")
        return video_path
    except Exception as e:
        print(f"Error during video generation: {e}")
        return f"Error: {str(e)}"

# Function to validate provider selection based on BYOK
def validate_provider(api_key, provider):
    # If no custom API key is provided, only "hf-inference" can be used
    if not api_key.strip() and provider != "hf-inference":
        return gr.update(value="hf-inference")
    return gr.update(value=provider)
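
# Gating behavior in practice: with an empty key box, picking e.g. "replicate"
# snaps the radio straight back to "hf-inference"; once any key is entered,
# the user's provider choice is left untouched.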

# Define the Gradio UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    # Set a title for the application
    gr.Markdown("# 🎬 Serverless-VideoGen-Hub")
    gr.Markdown("Generate videos using Hugging Face Serverless Inference")

    with gr.Row():
        with gr.Column(scale=2):
            # Main video output area
            video_output = gr.Video(label="Generated Video", height=400)

            # Basic input components
            prompt_box = gr.Textbox(
                value="A beautiful sunset over a calm ocean",
                placeholder="Enter a prompt for your video",
                label="Prompt",
                lines=3,
            )

            # Generate button
            generate_button = gr.Button("🎬 Generate Video", variant="primary")
            # Advanced settings in an accordion
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    with gr.Column():
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="What should NOT be in the video",
                            value="poor quality, distortion, blurry, low resolution, grainy",
                            lines=2,
                        )
                        with gr.Row():
                            width = gr.Slider(
                                minimum=256,
                                maximum=1024,
                                value=512,
                                step=64,
                                label="Width",
                            )
                            height = gr.Slider(
                                minimum=256,
                                maximum=1024,
                                value=512,
                                step=64,
                                label="Height",
                            )
                        with gr.Row():
                            num_frames = gr.Slider(
                                minimum=8,
                                maximum=64,
                                value=16,
                                step=1,
                                label="Number of Frames",
                            )
                            fps = gr.Slider(
                                minimum=1,
                                maximum=30,
                                value=8,
                                step=1,
                                label="Frames Per Second",
                            )
                        with gr.Row():
                            num_inference_steps = gr.Slider(
                                minimum=1,
                                maximum=100,
                                value=25,
                                step=1,
                                label="Inference Steps",
                            )
                            guidance_scale = gr.Slider(
                                minimum=1.0,
                                maximum=20.0,
                                value=7.5,
                                step=0.5,
                                label="Guidance Scale",
                            )
                        with gr.Row():
                            motion_bucket_id = gr.Slider(
                                minimum=1,
                                maximum=255,
                                value=127,
                                step=1,
                                label="Motion Bucket ID (for SVD models)",
                            )
                            seed = gr.Slider(
                                minimum=-1,
                                maximum=2147483647,
                                value=-1,
                                step=1,
                                label="Seed (-1 for random)",
                            )
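                        # Note: per the provider capability table in
                        # generate_video, none of the four listed providers
                        # forwards motion_bucket_id, so that slider only takes
                        # effect if a provider adds support for it.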
        with gr.Column():
            # Provider selection
            providers_list = [
                "hf-inference",  # default Hugging Face Inference
                "fal-ai",        # Fal AI provider
                "novita",        # Novita provider
                "replicate",     # Replicate provider
            ]
            provider_radio = gr.Radio(
                choices=providers_list,
                value="hf-inference",
                label="Inference Provider",
                info="Select an inference provider. All providers except hf-inference require their own API key.",
            )

            # BYOK textbox
            byok_textbox = gr.Textbox(
                value="",
                label="BYOK (Bring Your Own Key)",
                info="Enter a provider API key here. When empty, only the 'hf-inference' provider can be used.",
                placeholder="Enter your provider API token",
                type="password",  # hide the API key for security
            )

            # Model selection components
            model_search_box = gr.Textbox(
                label="Filter Models",
                placeholder="Search for a model...",
                lines=1,
            )
            models_list = [
                "Lightricks/LTX-Video",
                "Wan-AI/Wan2.1-T2V-14B",
                "tencent/HunyuanVideo",
                "Wan-AI/Wan2.1-T2V-1.3B",
                "genmo/mochi-1-preview",
                "THUDM/CogVideoX-5b",
            ]
            featured_model_radio = gr.Radio(
                label="Select a model below",
                choices=models_list,
                value="Lightricks/LTX-Video",
                interactive=True,
            )
            custom_model_box = gr.Textbox(
                value="",
                label="Custom Model",
                info="(Optional) Provide a custom Hugging Face model path. Overrides any selected featured model.",
                placeholder="damo-vilab/text-to-video-ms-1.7b",
            )
            gr.Markdown("[See all available models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-to-video&sort=trending)")
    # Set up the generation click event
    generate_button.click(
        fn=generate_video,
        inputs=[
            prompt_box,
            negative_prompt,
            num_frames,
            fps,
            width,
            height,
            num_inference_steps,
            guidance_scale,
            motion_bucket_id,
            seed,
            provider_radio,
            byok_textbox,
            custom_model_box,
            model_search_box,
            featured_model_radio,
        ],
        outputs=video_output,
    )
    # Connect the model filter to update the radio choices
    def filter_models(search_term):
        print(f"Filtering models with search term: {search_term}")
        filtered = [m for m in models_list if search_term.lower() in m.lower()]
        print(f"Filtered models: {filtered}")
        return gr.update(choices=filtered)

    model_search_box.change(
        fn=filter_models,
        inputs=model_search_box,
        outputs=featured_model_radio,
    )
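    # The filter is a case-insensitive substring match, so clearing the search
    # box (an empty string matches everything) restores the full model list.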
    # Connect the featured model radio to update the custom model box
    def set_custom_model_from_radio(selected):
        """Triggered when a model is picked from the 'Featured Models' radio;
        copies the selection into the Custom Model textbox automatically."""
        print(f"Featured model selected: {selected}")
        return selected

    featured_model_radio.change(
        fn=set_custom_model_from_radio,
        inputs=featured_model_radio,
        outputs=custom_model_box,
    )
    # Connect the BYOK textbox to validate provider selection
    byok_textbox.change(
        fn=validate_provider,
        inputs=[byok_textbox, provider_radio],
        outputs=provider_radio,
    )

    # Also validate provider when the radio changes to ensure consistency
    provider_radio.change(
        fn=validate_provider,
        inputs=[byok_textbox, provider_radio],
        outputs=provider_radio,
    )

# Launch the app
if __name__ == "__main__":
    print("Launching the demo application.")
    demo.launch(show_api=True)
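
# Typical local run (the token value below is a placeholder):
#   HF_TOKEN=hf_xxxxxxxx python app.py
# Gradio prints a local URL (http://127.0.0.1:7860 by default) to open in a browser.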