# goan.py
# Main application entry point for the goan video generation UI.
# --- Python Standard Library Imports ---
import os
import argparse
import atexit
# --- Environment Setup ---
# Set HF_HOME before importing the Hugging Face libraries so model downloads cache under ./hf_download.
os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download')))
# --- Third-Party Imports ---
import gradio as gr
import torch
# --- Local Application Imports ---
# Import managers for different UI sections and shared state
from ui import layout as layout_manager
from ui import metadata as metadata_manager
from ui import queue as queue_manager  # Aliased for clarity
from ui import workspace as workspace_manager
from ui import shared_state  # State shared across the UI managers
# --- Diffusers and Helper Imports ---
from diffusers import AutoencoderKLHunyuanVideo
from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer, SiglipImageProcessor, SiglipVisionModel
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, DynamicSwapInstaller
from diffusers_helper.gradio.progress_bar import make_progress_bar_css
# --- Argument Parsing ---
parser = argparse.ArgumentParser(description="goan: FramePack-based Video Generation UI")
parser.add_argument('--share', action='store_true', default=False, help="Enable Gradio sharing link.")
parser.add_argument("--server", type=str, default='127.0.0.1', help="Server name to bind to.")
parser.add_argument("--port", type=int, required=False, help="Port to run the server on.")
parser.add_argument("--inbrowser", action='store_true', default=False, help="Launch in browser automatically.")
# Additional output folders (beyond the default output folder) that Gradio may serve files from.
parser.add_argument("--allowed_output_paths", type=str, default="", help="Comma-separated list of additional output folders Gradio is allowed to access. E.g., '~/my_outputs, /mnt/external_drive/vids'")
args = parser.parse_args()
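# Example invocation (illustrative):
#   python goan.py --inbrowser --port 7860 --allowed_output_paths "~/my_outputs,/mnt/external_drive/vids"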
print(f"goan launching with args: {args}")
# --- Model Loading ---
print("Initializing models...")
free_mem_gb = get_cuda_free_memory_gb(gpu)
high_vram = free_mem_gb > 60
print(f'Free VRAM {free_mem_gb} GB, High-VRAM Mode: {high_vram}')
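# Above the 60 GB threshold every model can stay resident on the GPU; below it, weights stay on the
# CPU and DynamicSwap streams them to the GPU on demand (see the placement logic below).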
# Populate shared_state.models with loaded model instances
shared_state.models = {
    'text_encoder': LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu(),
    'text_encoder_2': CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu(),
    'tokenizer': LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer'),
    'tokenizer_2': CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2'),
    'vae': AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu(),
    'feature_extractor': SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor'),
    'image_encoder': SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu(),
    'transformer': HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePackI2V_HY', torch_dtype=torch.bfloat16).cpu(),
    'high_vram': high_vram,  # Memory-mode flag consumed by the worker
}
print("Models loaded to CPU. Configuring...")
for model_name in ['vae', 'text_encoder', 'text_encoder_2', 'image_encoder', 'transformer']:
    shared_state.models[model_name].eval()
if not high_vram:
    # Slicing/tiling keeps VAE decode memory bounded on low-VRAM GPUs.
    shared_state.models['vae'].enable_slicing()
    shared_state.models['vae'].enable_tiling()
shared_state.models['transformer'].high_quality_fp32_output_for_inference = False
for model_name, dtype in [('transformer', torch.bfloat16), ('vae', torch.float16), ('image_encoder', torch.float16), ('text_encoder', torch.float16), ('text_encoder_2', torch.float16)]:
    shared_state.models[model_name].to(dtype=dtype)
# Inference only: gradients are never needed.
for model_obj in shared_state.models.values():
    if isinstance(model_obj, torch.nn.Module):
        model_obj.requires_grad_(False)
if not high_vram:
    # DynamicSwap streams layers between CPU and GPU on demand during inference.
    print("Low VRAM mode: Installing DynamicSwap.")
    DynamicSwapInstaller.install_model(shared_state.models['transformer'], device=gpu)
    DynamicSwapInstaller.install_model(shared_state.models['text_encoder'], device=gpu)
else:
    print("High VRAM mode: Moving all models to GPU.")
    for model_name in ['text_encoder', 'text_encoder_2', 'image_encoder', 'vae', 'transformer']:
        shared_state.models[model_name].to(gpu)
print("Model configuration and placement complete.")
# --- UI Helper Functions ---
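# Gradio normally probes each output video for browser playability; bypass the probe so
# generated and preview MP4s are always displayed as-is.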
def patched_video_is_playable(video_filepath):
    return True
gr.processing_utils.video_is_playable = patched_video_is_playable
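# Quick sanity check of the estimate below (illustrative): 5 s at a latent window size of 9
# gives round((5 * 30) / (9 * 4)) = round(4.17) = 4 segments.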
def ui_update_total_segments(total_seconds_ui, latent_window_size_ui):
    """Estimate the total segment count for display in the UI (30 fps output, ~latent_window_size * 4 frames per segment)."""
    try:
        total_segments = int(max(round((total_seconds_ui * 30) / (latent_window_size_ui * 4)), 1))
        return f"Calculated Total Segments: {total_segments}"
    except (TypeError, ValueError):
        return "Segments: Invalid input"
# --- UI Creation and Event Wiring ---
print("Creating UI layout...")
# Create the UI by calling the layout manager. This returns the block and a dictionary of components.
ui_components = layout_manager.create_ui()
block = ui_components['block']
# Define lists of components for easier wiring of events
creative_ui_keys = ['prompt_ui', 'n_prompt_ui', 'total_second_length_ui', 'seed_ui', 'preview_frequency_ui', 'segments_to_decode_csv_ui', 'gs_ui', 'gs_schedule_shape_ui', 'gs_final_ui', 'steps_ui', 'cfg_ui', 'rs_ui']
environment_ui_keys = ['use_teacache_ui', 'use_fp32_transformer_output_checkbox_ui', 'gpu_memory_preservation_ui', 'mp4_crf_ui', 'output_folder_ui_ctrl', 'latent_window_size_ui']
full_workspace_ui_keys = creative_ui_keys + environment_ui_keys
creative_ui_components = [ui_components[key] for key in creative_ui_keys]
full_workspace_ui_components = [ui_components[key] for key in full_workspace_ui_keys]
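# A queued task is fully described by the start image from the gallery plus every workspace control.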
task_defining_ui_inputs = [ui_components['input_image_gallery_ui']] + full_workspace_ui_components
# Define output lists for complex Gradio calls
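# process_queue_outputs_list feeds the long-running queue processor; queue_df_select_outputs_list
# covers everything refreshed when a queue row is selected for editing.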
process_queue_outputs_list = [ui_components[key] for key in ['app_state', 'queue_df_display_ui', 'last_finished_video_ui', 'current_task_preview_image_ui', 'current_task_progress_desc_ui', 'current_task_progress_bar_ui', 'process_queue_button', 'abort_task_button', 'reset_ui_button']]
queue_df_select_outputs_list = [ui_components[key] for key in ['app_state', 'queue_df_display_ui', 'input_image_gallery_ui'] + full_workspace_ui_keys + ['add_task_button', 'cancel_edit_task_button', 'last_finished_video_ui']]
# Wire up all the UI events to their handler functions in the respective managers
with block:
    # Workspace Manager Events
    ui_components['save_workspace_button'].click(fn=workspace_manager.save_workspace, inputs=full_workspace_ui_components, outputs=None)
    ui_components['load_workspace_button'].click(fn=workspace_manager.load_workspace, inputs=None, outputs=full_workspace_ui_components)
    ui_components['save_as_default_button'].click(fn=workspace_manager.save_as_default_workspace, inputs=full_workspace_ui_components, outputs=None)
    # Metadata Manager Events
    ui_components['input_image_gallery_ui'].upload(fn=metadata_manager.handle_image_upload_for_metadata, inputs=[ui_components['input_image_gallery_ui']], outputs=[ui_components['metadata_modal']])
    ui_components['confirm_metadata_btn'].click(fn=metadata_manager.apply_and_hide_modal, inputs=[ui_components['input_image_gallery_ui']], outputs=[ui_components['metadata_modal']] + creative_ui_components)
    ui_components['cancel_metadata_btn'].click(fn=lambda: gr.update(visible=False), inputs=None, outputs=ui_components['metadata_modal'])
    # Queue Manager Events
    ui_components['add_task_button'].click(fn=queue_manager.add_or_update_task_in_queue, inputs=[ui_components['app_state']] + task_defining_ui_inputs, outputs=[ui_components['app_state'], ui_components['queue_df_display_ui'], ui_components['add_task_button'], ui_components['cancel_edit_task_button']])
    ui_components['process_queue_button'].click(fn=queue_manager.process_task_queue_main_loop, inputs=[ui_components['app_state']], outputs=process_queue_outputs_list)
    ui_components['cancel_edit_task_button'].click(fn=queue_manager.cancel_edit_mode_action, inputs=[ui_components['app_state']], outputs=[ui_components['app_state'], ui_components['queue_df_display_ui'], ui_components['add_task_button'], ui_components['cancel_edit_task_button']])
    ui_components['abort_task_button'].click(fn=queue_manager.abort_current_task_processing_action, inputs=[ui_components['app_state']], outputs=[ui_components['app_state'], ui_components['abort_task_button']])
    ui_components['clear_queue_button_ui'].click(fn=queue_manager.clear_task_queue_action, inputs=[ui_components['app_state']], outputs=[ui_components['app_state'], ui_components['queue_df_display_ui']])
    # The JS callback converts the returned base64 zip into a client-side file download.
    ui_components['save_queue_button_ui'].click(fn=queue_manager.save_queue_to_zip, inputs=[ui_components['app_state']], outputs=[ui_components['app_state'], ui_components['save_queue_zip_b64_output']]).then(fn=None, inputs=[ui_components['save_queue_zip_b64_output']], outputs=None, js="""(b64) => { if(!b64) return; const blob = new Blob([Uint8Array.from(atob(b64), c => c.charCodeAt(0))], {type: 'application/zip'}); const url = URL.createObjectURL(blob); const a = document.createElement('a'); a.href=url; a.download='goan_queue.zip'; a.click(); URL.revokeObjectURL(url); }""")
    ui_components['load_queue_button_ui'].upload(fn=queue_manager.load_queue_from_zip, inputs=[ui_components['app_state'], ui_components['load_queue_button_ui']], outputs=[ui_components['app_state'], ui_components['queue_df_display_ui']])
    ui_components['queue_df_display_ui'].select(fn=queue_manager.handle_queue_action_on_select, inputs=[ui_components['app_state']] + task_defining_ui_inputs, outputs=queue_df_select_outputs_list)
    # Other UI Event Handlers
    ui_components['gs_schedule_shape_ui'].change(fn=lambda choice: gr.update(interactive=(choice != "Off")), inputs=[ui_components['gs_schedule_shape_ui']], outputs=[ui_components['gs_final_ui']])
    for ctrl_key in ['total_second_length_ui', 'latent_window_size_ui']:
        ui_components[ctrl_key].change(fn=ui_update_total_segments, inputs=[ui_components['total_second_length_ui'], ui_components['latent_window_size_ui']], outputs=[ui_components['total_segments_display_ui']])
    # Hidden state holding the image path to restore after a page refresh.
    refresh_image_path_state = gr.State(None)
    # Reset button: persist the current UI state and image, then reload the page.
    ui_components['reset_ui_button'].click(fn=workspace_manager.save_ui_and_image_for_refresh, inputs=task_defining_ui_inputs, outputs=None).then(fn=None, inputs=None, outputs=None, js="() => { window.location.reload(); }")
    # --- Application Startup and Shutdown ---
    autoload_outputs = [ui_components[k] for k in ['app_state', 'queue_df_display_ui', 'process_queue_button', 'abort_task_button', 'last_finished_video_ui']]
    # block.load chain: on page load, restore the workspace, input image, and queue,
    # then re-attach to a running task's progress stream if one is in flight.
    (block.load(fn=workspace_manager.load_workspace_on_start, inputs=[], outputs=[refresh_image_path_state] + full_workspace_ui_components)
     .then(fn=workspace_manager.load_image_from_path, inputs=[refresh_image_path_state], outputs=[ui_components['input_image_gallery_ui']])
     .then(fn=queue_manager.autoload_queue_on_start_action, inputs=[ui_components['app_state']], outputs=autoload_outputs)
     .then(lambda s_val: shared_state.global_state_for_autosave.update(s_val), inputs=[ui_components['app_state']], outputs=None)
     .then(fn=ui_update_total_segments, inputs=[ui_components['total_second_length_ui'], ui_components['latent_window_size_ui']], outputs=[ui_components['total_segments_display_ui']])
     # Automatically re-attach the progress UI on load if a task is still processing.
     .then(
         fn=queue_manager.process_task_queue_main_loop,
         inputs=[ui_components['app_state']],
         outputs=process_queue_outputs_list,  # Reuse the existing output list
         js="""
         (app_state_val) => {
             // Runs after autoload_queue_on_start_action completes. If a task is still
             // processing, re-invoke the Python generator so this page re-attaches to its output stream.
             if (app_state_val.queue_state && app_state_val.queue_state.processing) {
                 console.log("Gradio: Auto-reconnecting to ongoing task output stream.");
                 // Return a non-null, non-falsey value to trigger the Python function.
                 return "reconnect_stream";
             }
             console.log("Gradio: No ongoing task detected for auto-reconnection.");
             return null;  // Returning null skips calling the Python function.
         }
         """
     )
    )
# Register the atexit handler with the global_state_for_autosave
atexit.register(queue_manager.autosave_queue_on_exit_action, shared_state.global_state_for_autosave)
# --- Application Launch ---
if __name__ == "__main__":
print("Starting goan FramePack UI...")
# Determine the initial output folder for allowed_paths based on saved settings or default
initial_output_folder_path = workspace_manager.get_initial_output_folder_from_settings()
expanded_outputs_folder_for_launch = os.path.abspath(initial_output_folder_path)
# Prepare the list of allowed paths for Gradio
final_allowed_paths = [expanded_outputs_folder_for_launch]
if args.allowed_output_paths:
custom_cli_paths = [
os.path.abspath(os.path.expanduser(p.strip()))
for p in args.allowed_output_paths.split(',')
if p.strip()
]
final_allowed_paths.extend(custom_cli_paths)
final_allowed_paths = list(set(final_allowed_paths)) # Remove duplicates
print(f"Gradio allowed paths: {final_allowed_paths}")
block.launch(server_name=args.server, server_port=args.port, share=args.share, inbrowser=args.inbrowser, allowed_paths=final_allowed_paths) |