rahul7star commited on
Commit
2830b98
·
verified ·
1 Parent(s): 1030ba2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -0
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # goan.py
2
+ # Main application entry point for the goan video generation UI.
3
+
4
+ # --- Python Standard Library Imports ---
5
+ import os
6
+ import gradio as gr
7
+ import torch
8
+ import argparse
9
+ import atexit
10
+ import spaces
11
+ # --- Local Application Imports ---
12
+ # Import managers for different UI sections and shared state
13
+ from ui import layout as layout_manager
14
+ from ui import metadata as metadata_manager
15
+ from ui import queue as queue_manager # Renamed for clarity
16
+ from ui import workspace as workspace_manager
17
+ from ui import shared_state # Ensure this is used consistently
18
+
19
+ # --- Diffusers and Helper Imports ---
20
+ from diffusers import AutoencoderKLHunyuanVideo
21
+ from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer, SiglipImageProcessor, SiglipVisionModel
22
+ from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
23
+ from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, DynamicSwapInstaller
24
+ from diffusers_helper.gradio.progress_bar import make_progress_bar_css
25
+
26
+ # --- Environment Setup ---
27
+ os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download')))
28
+
29
+ # --- Argument Parsing ---
30
+ parser = argparse.ArgumentParser(description="goan: FramePack-based Video Generation UI")
31
+ parser.add_argument('--share', action='store_true', default=False, help="Enable Gradio sharing link.")
32
+ parser.add_argument("--server", type=str, default='127.0.0.1', help="Server name to bind to.")
33
+ parser.add_argument("--port", type=int, required=False, help="Port to run the server on.")
34
+ parser.add_argument("--inbrowser", action='store_true', default=False, help="Launch in browser automatically.")
35
+ # Add the allowed_output_paths argument here
36
+ parser.add_argument("--allowed_output_paths", type=str, default="", help="Comma-separated list of additional output folders Gradio is allowed to access. E.g., '~/my_outputs, /mnt/external_drive/vids'")
37
+ args = parser.parse_args()
38
+ print(f"goan launching with args: {args}")
39
+
40
+ # --- Model Loading ---
41
+ print("Initializing models...")
42
+ free_mem_gb = get_cuda_free_memory_gb(gpu)
43
+ high_vram = free_mem_gb > 60
44
+ print(f'Free VRAM {free_mem_gb} GB, High-VRAM Mode: {high_vram}')
45
+
46
+ # Populate shared_state.models with loaded model instances
47
+ shared_state.models = {
48
+ 'text_encoder': LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu(),
49
+ 'text_encoder_2': CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu(),
50
+ 'tokenizer': LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer'),
51
+ 'tokenizer_2': CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2'),
52
+ 'vae': AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu(),
53
+ 'feature_extractor': SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor'),
54
+ 'image_encoder': SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu(),
55
+ 'transformer': HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePackI2V_HY', torch_dtype=torch.bfloat16).cpu(),
56
+ 'high_vram': high_vram # Renamed key to match worker's expected param
57
+ }
58
+ print("Models loaded to CPU. Configuring...")
59
+ for model_name in ['vae', 'text_encoder', 'text_encoder_2', 'image_encoder', 'transformer']:
60
+ shared_state.models[model_name].eval()
61
+ if not high_vram:
62
+ shared_state.models['vae'].enable_slicing(); shared_state.models['vae'].enable_tiling()
63
+ shared_state.models['transformer'].high_quality_fp32_output_for_inference = False
64
+ for model_name, dtype in [('transformer', torch.bfloat16), ('vae', torch.float16), ('image_encoder', torch.float16), ('text_encoder', torch.float16), ('text_encoder_2', torch.float16)]:
65
+ shared_state.models[model_name].to(dtype=dtype)
66
+ for model_obj in shared_state.models.values(): # Iterate over values, not keys
67
+ if isinstance(model_obj, torch.nn.Module): model_obj.requires_grad_(False) # Use model_obj here
68
+ if not high_vram:
69
+ print("Low VRAM mode: Installing DynamicSwap.")
70
+ DynamicSwapInstaller.install_model(shared_state.models['transformer'], device=gpu)
71
+ DynamicSwapInstaller.install_model(shared_state.models['text_encoder'], device=gpu)
72
+ else:
73
+ print("High VRAM mode: Moving all models to GPU.")
74
+ for model_name in ['text_encoder', 'text_encoder_2', 'image_encoder', 'vae', 'transformer']:
75
+ shared_state.models[model_name].to(gpu)
76
+ print("Model configuration and placement complete.")
77
+
78
+
79
+ # --- UI Helper Functions ---
80
+ def patched_video_is_playable(video_filepath): return True
81
+ gr.processing_utils.video_is_playable = patched_video_is_playable
82
+
83
+ def ui_update_total_segments(total_seconds_ui, latent_window_size_ui):
84
+ """Calculates and formats the total segment count for display in the UI."""
85
+ try:
86
+ total_segments = int(max(round((total_seconds_ui * 30) / (latent_window_size_ui * 4)), 1))
87
+ return f"Calculated Total Segments: {total_segments}"
88
+ except (TypeError, ValueError): return "Segments: Invalid input"
89
+
90
+
91
+ # --- UI Creation and Event Wiring ---
92
+ print("Creating UI layout...")
93
+ # Create the UI by calling the layout manager. This returns the block and a dictionary of components.
94
+ ui_components = layout_manager.create_ui()
95
+ block = ui_components['block']
96
+
97
+ # Define lists of components for easier wiring of events
98
+ creative_ui_keys = ['prompt_ui', 'n_prompt_ui', 'total_second_length_ui', 'seed_ui', 'preview_frequency_ui', 'segments_to_decode_csv_ui', 'gs_ui', 'gs_schedule_shape_ui', 'gs_final_ui', 'steps_ui', 'cfg_ui', 'rs_ui']
99
+ environment_ui_keys = ['use_teacache_ui', 'use_fp32_transformer_output_checkbox_ui', 'gpu_memory_preservation_ui', 'mp4_crf_ui', 'output_folder_ui_ctrl', 'latent_window_size_ui']
100
+ full_workspace_ui_keys = creative_ui_keys + environment_ui_keys
101
+
102
+ creative_ui_components = [ui_components[key] for key in creative_ui_keys]
103
+ full_workspace_ui_components = [ui_components[key] for key in full_workspace_ui_keys]
104
+ task_defining_ui_inputs = [ui_components['input_image_gallery_ui']] + full_workspace_ui_components
105
+
106
+ # Define output lists for complex Gradio calls
107
+ process_queue_outputs_list = [ui_components[key] for key in ['app_state', 'queue_df_display_ui', 'last_finished_video_ui', 'current_task_preview_image_ui', 'current_task_progress_desc_ui', 'current_task_progress_bar_ui', 'process_queue_button', 'abort_task_button', 'reset_ui_button']]
108
+ queue_df_select_outputs_list = [ui_components[key] for key in ['app_state', 'queue_df_display_ui', 'input_image_gallery_ui'] + full_workspace_ui_keys + ['add_task_button', 'cancel_edit_task_button', 'last_finished_video_ui']]
109
+
110
+ # Wire up all the UI events to their handler functions in the respective managers
111
+ with block:
112
+ # Workspace Manager Events
113
+ ui_components['save_workspace_button'].click(fn=workspace_manager.save_workspace, inputs=full_workspace_ui_components, outputs=None)
114
+ ui_components['load_workspace_button'].click(fn=workspace_manager.load_workspace, inputs=None, outputs=full_workspace_ui_components)
115
+ ui_components['save_as_default_button'].click(fn=workspace_manager.save_as_default_workspace, inputs=full_workspace_ui_components, outputs=None)
116
+
117
+ # Metadata Manager Events
118
+ ui_components['input_image_gallery_ui'].upload(fn=metadata_manager.handle_image_upload_for_metadata, inputs=[ui_components['input_image_gallery_ui']], outputs=[ui_components['metadata_modal']])
119
+ ui_components['confirm_metadata_btn'].click(fn=metadata_manager.apply_and_hide_modal, inputs=[ui_components['input_image_gallery_ui']], outputs=[ui_components['metadata_modal']] + creative_ui_components)
120
+ ui_components['cancel_metadata_btn'].click(fn=lambda: gr.update(visible=False), inputs=None, outputs=ui_components['metadata_modal'])
121
+
122
+ # Queue Manager Events
123
+ ui_components['add_task_button'].click(fn=queue_manager.add_or_update_task_in_queue, inputs=[ui_components['app_state']] + task_defining_ui_inputs, outputs=[ui_components['app_state'], ui_components['queue_df_display_ui'], ui_components['add_task_button'], ui_components['cancel_edit_task_button']])
124
+ ui_components['process_queue_button'].click(fn=queue_manager.process_task_queue_main_loop, inputs=[ui_components['app_state']], outputs=process_queue_outputs_list)
125
+ ui_components['cancel_edit_task_button'].click(fn=queue_manager.cancel_edit_mode_action, inputs=[ui_components['app_state']], outputs=[ui_components['app_state'], ui_components['queue_df_display_ui'], ui_components['add_task_button'], ui_components['cancel_edit_task_button']])
126
+ ui_components['abort_task_button'].click(fn=queue_manager.abort_current_task_processing_action, inputs=[ui_components['app_state']], outputs=[ui_components['app_state'], ui_components['abort_task_button']])
127
+ ui_components['clear_queue_button_ui'].click(fn=queue_manager.clear_task_queue_action, inputs=[ui_components['app_state']], outputs=[ui_components['app_state'], ui_components['queue_df_display_ui']])
128
+ ui_components['save_queue_button_ui'].click(fn=queue_manager.save_queue_to_zip, inputs=[ui_components['app_state']], outputs=[ui_components['app_state'], ui_components['save_queue_zip_b64_output']]).then(fn=None, inputs=[ui_components['save_queue_zip_b64_output']], outputs=None, js="""(b64) => { if(!b64) return; const blob = new Blob([Uint8Array.from(atob(b64), c => c.charCodeAt(0))], {type: 'application/zip'}); const url = URL.createObjectURL(blob); const a = document.createElement('a'); a.href=url; a.download='goan_queue.zip'; a.click(); URL.revokeObjectURL(url); }""")
129
+ ui_components['load_queue_button_ui'].upload(fn=queue_manager.load_queue_from_zip, inputs=[ui_components['app_state'], ui_components['load_queue_button_ui']], outputs=[ui_components['app_state'], ui_components['queue_df_display_ui']])
130
+ ui_components['queue_df_display_ui'].select(fn=queue_manager.handle_queue_action_on_select, inputs=[ui_components['app_state']] + task_defining_ui_inputs, outputs=queue_df_select_outputs_list)
131
+
132
+ # Other UI Event Handlers
133
+ ui_components['gs_schedule_shape_ui'].change(fn=lambda choice: gr.update(interactive=(choice != "Off")), inputs=[ui_components['gs_schedule_shape_ui']], outputs=[ui_components['gs_final_ui']])
134
+ for ctrl_key in ['total_second_length_ui', 'latent_window_size_ui']:
135
+ ui_components[ctrl_key].change(fn=ui_update_total_segments, inputs=[ui_components['total_second_length_ui'], ui_components['latent_window_size_ui']], outputs=[ui_components['total_segments_display_ui']])
136
+
137
+ refresh_image_path_state = gr.State(None)
138
+ # The reset_ui_button's functionality remains the same: save state then reload page
139
+ ui_components['reset_ui_button'].click(fn=workspace_manager.save_ui_and_image_for_refresh, inputs=task_defining_ui_inputs, outputs=None).then(fn=None, inputs=None, outputs=None, js="() => { window.location.reload(); }")
140
+
141
+ # --- Application Startup and Shutdown ---
142
+ autoload_outputs = [ui_components[k] for k in ['app_state', 'queue_df_display_ui', 'process_queue_button', 'abort_task_button', 'last_finished_video_ui']]
143
+
144
+ # This is the crucial block.load chain to ensure re-attachment
145
+ (block.load(fn=workspace_manager.load_workspace_on_start, inputs=[], outputs=[refresh_image_path_state] + full_workspace_ui_components)
146
+ .then(fn=workspace_manager.load_image_from_path, inputs=[refresh_image_path_state], outputs=[ui_components['input_image_gallery_ui']])
147
+ .then(fn=queue_manager.autoload_queue_on_start_action, inputs=[ui_components['app_state']], outputs=autoload_outputs)
148
+ .then(lambda s_val: shared_state.global_state_for_autosave.update(s_val), inputs=[ui_components['app_state']], outputs=None)
149
+ .then(fn=ui_update_total_segments, inputs=[ui_components['total_second_length_ui'], ui_components['latent_window_size_ui']], outputs=[ui_components['total_segments_display_ui']])
150
+ # *** ADDED: Automatic re-attachment of progress UI on load if processing ***
151
+ .then(
152
+ fn=queue_manager.process_task_queue_main_loop,
153
+ inputs=[ui_components['app_state']],
154
+ outputs=process_queue_outputs_list, # Use the existing output list
155
+ js="""
156
+ (app_state_val) => {
157
+ // This JS runs after autoload_queue_on_start_action completes.
158
+ // If a task is processing, we want to re-invoke the Python generator.
159
+ if (app_state_val.queue_state && app_state_val.queue_state.processing) {
160
+ console.log("Gradio: Auto-reconnecting to ongoing task output stream.");
161
+ // Return a non-null, non-falsey value to trigger the Python function.
162
+ return "reconnect_stream";
163
+ }
164
+ console.log("Gradio: No ongoing task detected for auto-reconnection.");
165
+ return null; // Return null to skip calling the Python function
166
+ }
167
+ """
168
+ )
169
+ )
170
+
171
+ # Register the atexit handler with the global_state_for_autosave
172
+ atexit.register(queue_manager.autosave_queue_on_exit_action, shared_state.global_state_for_autosave)
173
+
174
+ # --- Application Launch ---
175
+ if __name__ == "__main__":
176
+ print("Starting goan FramePack UI...")
177
+ # Determine the initial output folder for allowed_paths based on saved settings or default
178
+ initial_output_folder_path = workspace_manager.get_initial_output_folder_from_settings()
179
+ expanded_outputs_folder_for_launch = os.path.abspath(initial_output_folder_path)
180
+
181
+ # Prepare the list of allowed paths for Gradio
182
+ final_allowed_paths = [expanded_outputs_folder_for_launch]
183
+
184
+ if args.allowed_output_paths:
185
+ custom_cli_paths = [
186
+ os.path.abspath(os.path.expanduser(p.strip()))
187
+ for p in args.allowed_output_paths.split(',')
188
+ if p.strip()
189
+ ]
190
+ final_allowed_paths.extend(custom_cli_paths)
191
+
192
+ final_allowed_paths = list(set(final_allowed_paths)) # Remove duplicates
193
+
194
+ print(f"Gradio allowed paths: {final_allowed_paths}")
195
+ block.launch(share=True)