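"""Serverless-VideoGen-Hub.

A Gradio app that generates videos through Hugging Face serverless inference
providers (hf-inference, fal-ai, novita, replicate), with optional
bring-your-own-key (BYOK) authentication.
"""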
import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import tempfile
import uuid

# Access token from environment variable
ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded." if ACCESS_TOKEN else "Warning: HF_TOKEN is not set; only BYOK requests will authenticate.")


def generate_video(
    prompt,
    negative_prompt,
    num_frames,
    fps,
    width,
    height,
    num_inference_steps,
    guidance_scale,
    motion_bucket_id,
    seed,
    provider,
    custom_api_key,
    custom_model,
    model_search_term,
    selected_model,
):
    """Generate a video based on the provided parameters."""
    print(f"Received prompt: {prompt}")
    print(f"Negative prompt: {negative_prompt}")
    print(f"Num frames: {num_frames}, FPS: {fps}")
    print(f"Width: {width}, Height: {height}")
    print(f"Steps: {num_inference_steps}, Guidance Scale: {guidance_scale}")
    print(f"Motion Bucket ID: {motion_bucket_id}, Seed: {seed}")
    print(f"Selected provider: {provider}")
    print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
    print(f"Selected model (custom_model): {custom_model}")
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")
    # Determine which token to use: the custom API key if provided, otherwise ACCESS_TOKEN
    token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN

    # Log which token source is in use (without printing the actual token)
    if custom_api_key.strip() != "":
        print("USING CUSTOM API KEY: BYOK token provided by the user is being used for authentication")
    else:
        print("USING DEFAULT API KEY: environment variable HF_TOKEN is being used for authentication")

    # Initialize the Inference Client with the selected provider and token
    client = InferenceClient(token=token_to_use, provider=provider)
    print(f"Hugging Face Inference Client initialized with {provider} provider.")

    # A seed of -1 means random; otherwise make sure the seed is an integer
    if seed == -1:
        seed = None
    else:
        seed = int(seed)

    # Determine which model to use, prioritizing custom_model if provided
    model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
    print(f"Model selected for inference: {model_to_use}")

    # Create a unique ID for this generation (useful for tracing logs)
    generation_id = uuid.uuid4().hex[:8]
    print(f"Generation ID: {generation_id}")
    # Parameters each provider is known to accept for text-to-video requests
    provider_param_support = {
        "hf-inference": {
            "supported": ["prompt", "model", "negative_prompt", "num_frames", "num_inference_steps", "guidance_scale", "seed"],
            "extra_info": "HF Inference doesn't support 'fps', 'width', 'height', or 'motion_bucket_id' parameters",
        },
        "fal-ai": {
            "supported": ["prompt", "model", "negative_prompt", "num_frames", "num_inference_steps", "guidance_scale", "seed"],
            "extra_info": "Fal-AI doesn't support 'fps', 'width', 'height', or 'motion_bucket_id' parameters",
        },
        "novita": {
            "supported": ["prompt", "model", "negative_prompt", "num_frames", "num_inference_steps", "guidance_scale", "seed", "fps", "width", "height"],
            "extra_info": "Novita may not support the 'motion_bucket_id' parameter",
        },
        "replicate": {
            "supported": ["prompt", "model", "negative_prompt", "num_frames", "num_inference_steps", "guidance_scale", "seed", "fps", "width", "height"],
            "extra_info": "Replicate parameters vary by specific model",
        },
    }

    # Look up the supported parameters for the current provider
    supported_params = provider_param_support.get(provider, {}).get("supported", [])
    provider_info = provider_param_support.get(provider, {}).get("extra_info", "No specific information available")
    print(f"Provider info: {provider_info}")
    print(f"Supported parameters: {supported_params}")
    # Build the request payload, keeping only parameters the provider supports;
    # anything outside the allowlist above is silently dropped
    parameters = {}
    if "negative_prompt" in supported_params:
        parameters["negative_prompt"] = negative_prompt
    if "num_frames" in supported_params:
        parameters["num_frames"] = num_frames
    if "num_inference_steps" in supported_params:
        parameters["num_inference_steps"] = num_inference_steps
    if "guidance_scale" in supported_params:
        parameters["guidance_scale"] = guidance_scale
    if "seed" in supported_params and seed is not None:
        parameters["seed"] = seed
    if "fps" in supported_params:
        parameters["fps"] = fps
    if "width" in supported_params:
        parameters["width"] = width
    if "height" in supported_params:
        parameters["height"] = height
    if "motion_bucket_id" in supported_params:
        parameters["motion_bucket_id"] = motion_bucket_id

    # Now that we have a clean parameter set, handle provider-specific logic
    print(f"Final parameters for {provider}: {parameters}")
    # Replicate is handled separately via a raw POST; its HTTP API expects the
    # generation arguments wrapped in an "input" object
    if provider == "replicate":
        print("Using Replicate provider via the post method...")
        try:
            raw_response = client.post(
                model=model_to_use,
                json={"input": {"prompt": prompt, **parameters}},
            )
            # client.post returns raw bytes; decode the JSON body
            response = json.loads(raw_response)
            # Replicate typically returns a URL to the generated video
            if isinstance(response, dict) and "output" in response:
                video_url = response["output"]
                print(f"Video generated, URL: {video_url}")
                return video_url
            return str(response)
        except Exception as e:
            print(f"Error during Replicate video generation: {e}")
            return f"Error: {str(e)}"
    # All other providers go through the standard text_to_video method
    try:
        print(f"Sending request to {provider} provider with model {model_to_use}.")
        video_data = client.text_to_video(
            prompt=prompt,
            model=model_to_use,
            **parameters,
        )

        # Persist the returned bytes to a temporary .mp4 so Gradio can serve it
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
            temp_file.write(video_data)
            video_path = temp_file.name
        print(f"Video saved to temporary file: {video_path}")
        return video_path
    except Exception as e:
        print(f"Error during video generation: {e}")
        return f"Error: {str(e)}"


# Validate provider selection based on whether a BYOK key is present
def validate_provider(api_key, provider):
    # Without a custom API key, only "hf-inference" can be used
    if not api_key.strip() and provider != "hf-inference":
        return gr.update(value="hf-inference")
    return gr.update(value=provider)
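
# For example (hypothetical key value), an empty key forces the default provider:
#   validate_provider("", "replicate")        -> gr.update(value="hf-inference")
#   validate_provider("sk-xxxx", "replicate") -> gr.update(value="replicate")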

# Define the Gradio UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    # Title for the application
    gr.Markdown("# 🎬 Serverless-VideoGen-Hub")
    gr.Markdown("Generate videos using Hugging Face Serverless Inference")

    with gr.Row():
        with gr.Column(scale=2):
            # Main video output area
            video_output = gr.Video(label="Generated Video", height=400)

            # Basic input components
            prompt_box = gr.Textbox(
                value="A beautiful sunset over a calm ocean",
                placeholder="Enter a prompt for your video",
                label="Prompt",
                lines=3,
            )

            # Generate button
            generate_button = gr.Button("🎬 Generate Video", variant="primary")
            # Advanced settings in an accordion
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    with gr.Column():
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="What should NOT be in the video",
                            value="poor quality, distortion, blurry, low resolution, grainy",
                            lines=2,
                        )
                        with gr.Row():
                            width = gr.Slider(
                                minimum=256,
                                maximum=1024,
                                value=512,
                                step=64,
                                label="Width",
                            )
                            height = gr.Slider(
                                minimum=256,
                                maximum=1024,
                                value=512,
                                step=64,
                                label="Height",
                            )
                        with gr.Row():
                            num_frames = gr.Slider(
                                minimum=8,
                                maximum=64,
                                value=16,
                                step=1,
                                label="Number of Frames",
                            )
                            fps = gr.Slider(
                                minimum=1,
                                maximum=30,
                                value=8,
                                step=1,
                                label="Frames Per Second",
                            )
                        with gr.Row():
                            num_inference_steps = gr.Slider(
                                minimum=1,
                                maximum=100,
                                value=25,
                                step=1,
                                label="Inference Steps",
                            )
                            guidance_scale = gr.Slider(
                                minimum=1.0,
                                maximum=20.0,
                                value=7.5,
                                step=0.5,
                                label="Guidance Scale",
                            )
                        with gr.Row():
                            motion_bucket_id = gr.Slider(
                                minimum=1,
                                maximum=255,
                                value=127,
                                step=1,
                                label="Motion Bucket ID (for SVD models)",
                            )
                            seed = gr.Slider(
                                minimum=-1,
                                maximum=2147483647,
                                value=-1,
                                step=1,
                                label="Seed (-1 for random)",
                            )
        with gr.Column():
            # Provider selection
            providers_list = [
                "hf-inference",  # Default Hugging Face Inference
                "fal-ai",        # Fal AI provider
                "novita",        # Novita provider
                "replicate",     # Replicate provider
            ]
            provider_radio = gr.Radio(
                choices=providers_list,
                value="hf-inference",
                label="Inference Provider",
                info="Select an inference provider. All providers except hf-inference require a provider-specific API key.",
            )

            # BYOK textbox
            byok_textbox = gr.Textbox(
                value="",
                label="BYOK (Bring Your Own Key)",
                info="Enter a provider API key here. When empty, only the 'hf-inference' provider can be used.",
                placeholder="Enter your provider API token",
                type="password",  # Hide the API key for security
            )

            # Model selection components
            model_search_box = gr.Textbox(
                label="Filter Models",
                placeholder="Search for a model...",
                lines=1,
            )
            models_list = [
                "Lightricks/LTX-Video",
                "Wan-AI/Wan2.1-T2V-14B",
                "tencent/HunyuanVideo",
                "Wan-AI/Wan2.1-T2V-1.3B",
                "genmo/mochi-1-preview",
                "THUDM/CogVideoX-5b",
            ]
            featured_model_radio = gr.Radio(
                label="Select a model below",
                choices=models_list,
                value="Lightricks/LTX-Video",
                interactive=True,
            )
            custom_model_box = gr.Textbox(
                value="",
                label="Custom Model",
                info="(Optional) Provide a custom Hugging Face model path. Overrides any selected featured model.",
                placeholder="damo-vilab/text-to-video-ms-1.7b",
            )
            gr.Markdown("[See all available models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-to-video&sort=trending)")

    # Set up the generation click event
    generate_button.click(
        fn=generate_video,
        inputs=[
            prompt_box,
            negative_prompt,
            num_frames,
            fps,
            width,
            height,
            num_inference_steps,
            guidance_scale,
            motion_bucket_id,
            seed,
            provider_radio,
            byok_textbox,
            custom_model_box,
            model_search_box,
            featured_model_radio,
        ],
        outputs=video_output,
    )
    # Connect the model filter to update the radio choices
    def filter_models(search_term):
        print(f"Filtering models with search term: {search_term}")
        filtered = [m for m in models_list if search_term.lower() in m.lower()]
        print(f"Filtered models: {filtered}")
        return gr.update(choices=filtered)

    model_search_box.change(
        fn=filter_models,
        inputs=model_search_box,
        outputs=featured_model_radio,
    )
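    # For example, typing "wan" narrows the radio to the two Wan-AI checkpoints.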
    # Connect the featured model radio to update the custom model box
    def set_custom_model_from_radio(selected):
        """Triggered whenever a model is picked from the featured-models radio;
        mirrors the selection into the Custom Model text box."""
        print(f"Featured model selected: {selected}")
        return selected

    featured_model_radio.change(
        fn=set_custom_model_from_radio,
        inputs=featured_model_radio,
        outputs=custom_model_box,
    )
    # Validate the provider whenever the BYOK key changes
    byok_textbox.change(
        fn=validate_provider,
        inputs=[byok_textbox, provider_radio],
        outputs=provider_radio,
    )

    # Also validate when the provider radio changes, to keep the two consistent
    provider_radio.change(
        fn=validate_provider,
        inputs=[byok_textbox, provider_radio],
        outputs=provider_radio,
    )
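
# To run locally (assumes `pip install gradio huggingface_hub`; HF_TOKEN is
# optional if you supply a BYOK key in the UI):
#   HF_TOKEN=hf_xxx python app.py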

# Launch the app
if __name__ == "__main__":
    print("Launching the demo application.")
    demo.launch(show_api=True)