markury committed
Commit cf345e6
1 Parent(s): b17912c
Files changed (2):
  1. app.py +228 -4
  2. requirements.txt +7 -0
app.py CHANGED
@@ -1,7 +1,231 @@
+import torch
 import gradio as gr
+import spaces
+from diffusers.utils import export_to_video
+from diffusers import AutoencoderKLWan, WanPipeline
+from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
+from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 
-def greet(name):
-    return "Hello " + name + "!!"
+# Define model options
+MODEL_OPTIONS = {
+    "Wan2.1-T2V-1.3B": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+    "Wan2.1-T2V-14B": "Wan-AI/Wan2.1-T2V-14B-Diffusers"
+}
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+# Define scheduler options
+SCHEDULER_OPTIONS = {
+    "UniPCMultistepScheduler": UniPCMultistepScheduler,
+    "FlowMatchEulerDiscreteScheduler": FlowMatchEulerDiscreteScheduler
+}
+
+@spaces.GPU(duration=300)  # Allow up to 5 minutes of GPU access per generation
+def generate_video(
+    model_choice,
+    prompt,
+    negative_prompt,
+    lora_id,
+    lora_scale,
+    scheduler_type,
+    flow_shift,
+    height,
+    width,
+    num_frames,
+    guidance_scale,
+    num_inference_steps,
+    output_fps
+):
+    """Generate a video using the selected Wan model and parameters."""
+    try:
+        # Get the model ID for the selected checkpoint
+        model_id = MODEL_OPTIONS[model_choice]
+
+        # Load the model components (the VAE is kept in float32 for stability)
+        vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+        pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+
+        # Set the scheduler
+        scheduler_class = SCHEDULER_OPTIONS[scheduler_type]
+
+        if scheduler_type == "UniPCMultistepScheduler":
+            pipe.scheduler = scheduler_class.from_config(
+                pipe.scheduler.config,
+                prediction_type="flow_prediction",
+                use_flow_sigmas=True,
+                flow_shift=flow_shift
+            )
+        else:
+            pipe.scheduler = scheduler_class(shift=flow_shift)
+
+        # Enable CPU offload to keep peak VRAM low; it handles device placement
+        # itself, so an explicit pipe.to("cuda") is unnecessary here
+        pipe.enable_model_cpu_offload()
+
+        # Load and fuse LoRA weights if a repository ID was provided
+        if lora_id and lora_id.strip():
+            try:
+                pipe.load_lora_weights(lora_id)
+
+                # Fuse the LoRA at the requested scale if supported
+                if hasattr(pipe, "fuse_lora"):
+                    pipe.fuse_lora(lora_scale=lora_scale)
+            except Exception as e:
+                raise gr.Error(f"Error loading/fusing LoRA: {e}")
+
+        # Generate the video frames
+        output = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            num_frames=num_frames,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps
+        ).frames[0]
+
+        # Export the frames to an MP4 file
+        temp_file = "output.mp4"
+        export_to_video(output, temp_file, fps=output_fps)
+
+        return temp_file
+    except gr.Error:
+        raise
+    except Exception as e:
+        raise gr.Error(f"Error generating video: {e}")
+
+# Create the Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# Wan Video Generation with ZeroGPU")
+    gr.Markdown("Generate high-quality videos using the Wan model with optional LoRA adaptations.")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            model_choice = gr.Dropdown(
+                choices=list(MODEL_OPTIONS.keys()),
+                value="Wan2.1-T2V-1.3B",
+                label="Model"
+            )
+
+            prompt = gr.Textbox(
+                label="Prompt",
+                value="steamboat willie style, golden era animation, an anthropomorphic cat character wearing a hat removes it and performs a courteous bow",
+                lines=3
+            )
+
+            negative_prompt = gr.Textbox(
+                label="Negative Prompt",
+                value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+                lines=3
+            )
+
+            with gr.Row():
+                lora_id = gr.Textbox(
+                    label="LoRA ID (e.g., benjamin-paine/steamboat-willie-1.3b)",
+                    value="benjamin-paine/steamboat-willie-1.3b"
+                )
+                lora_scale = gr.Slider(
+                    label="LoRA Scale",
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.75,
+                    step=0.05
+                )
+
+            with gr.Row():
+                scheduler_type = gr.Dropdown(
+                    choices=list(SCHEDULER_OPTIONS.keys()),
+                    value="UniPCMultistepScheduler",
+                    label="Scheduler"
+                )
+                flow_shift = gr.Slider(
+                    label="Flow Shift",
+                    minimum=1.0,
+                    maximum=12.0,
+                    value=3.0,
+                    step=0.5,
+                    info="2.0-5.0 for smaller videos, 7.0-12.0 for larger videos"
+                )
+
+            with gr.Row():
+                height = gr.Slider(
+                    label="Height",
+                    minimum=256,
+                    maximum=1024,
+                    value=480,
+                    step=32
+                )
+                width = gr.Slider(
+                    label="Width",
+                    minimum=256,
+                    maximum=1792,
+                    value=832,
+                    step=32
+                )
+
+            with gr.Row():
+                num_frames = gr.Slider(
+                    label="Number of Frames (4k+1 is recommended, e.g. 81)",
+                    minimum=17,
+                    maximum=129,
+                    value=81,
+                    step=4
+                )
+                output_fps = gr.Slider(
+                    label="Output FPS",
+                    minimum=8,
+                    maximum=30,
+                    value=16,
+                    step=1
+                )
+
+            with gr.Row():
+                guidance_scale = gr.Slider(
+                    label="Guidance Scale (CFG)",
+                    minimum=1.0,
+                    maximum=15.0,
+                    value=5.0,
+                    step=0.5
+                )
+                num_inference_steps = gr.Slider(
+                    label="Inference Steps",
+                    minimum=10,
+                    maximum=100,
+                    value=32,
+                    step=1
+                )
+
+            generate_btn = gr.Button("Generate Video")
+
+        with gr.Column(scale=1):
+            output_video = gr.Video(label="Generated Video")
+
+    generate_btn.click(
+        fn=generate_video,
+        inputs=[
+            model_choice,
+            prompt,
+            negative_prompt,
+            lora_id,
+            lora_scale,
+            scheduler_type,
+            flow_shift,
+            height,
+            width,
+            num_frames,
+            guidance_scale,
+            num_inference_steps,
+            output_fps
+        ],
+        outputs=output_video
+    )
+
+    gr.Markdown("""
+    ## Tips for best results:
+    - For smaller-resolution videos, try lower flow shift values (2.0-5.0)
+    - For larger-resolution videos, try higher flow shift values (7.0-12.0)
+    - The number of frames should be of the form 4k+1 (e.g., 49, 65, 81)
+    - The model is memory-intensive, so adjust resolution to the available VRAM
+    - The LoRA ID should be a Hugging Face repository containing safetensors files
+    """)
+
+demo.launch()
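Note: because generate_btn.click wires a single function to the UI, the Space can also be driven programmatically. A minimal sketch using gradio_client follows; the Space ID is a placeholder, and the /generate_video endpoint name (Gradio derives it from the function name) should be verified with client.view_api().

from gradio_client import Client

# Placeholder Space ID; substitute the actual repo once the Space is published.
client = Client("markury/wan-video-generation")

# Positional arguments mirror the inputs=[...] list wired to generate_btn.click.
# The api_name is an assumption; confirm it with client.view_api().
video_path = client.predict(
    "Wan2.1-T2V-1.3B",          # model_choice
    "a cat performs a bow",     # prompt
    "",                         # negative_prompt
    "",                         # lora_id (empty string skips LoRA loading)
    0.75,                       # lora_scale
    "UniPCMultistepScheduler",  # scheduler_type
    3.0,                        # flow_shift
    480,                        # height
    832,                        # width
    81,                         # num_frames (4k+1)
    5.0,                        # guidance_scale
    32,                         # num_inference_steps
    16,                         # output_fps
    api_name="/generate_video",
)
print(video_path)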
requirements.txt ADDED
@@ -0,0 +1,7 @@
+git+https://github.com/huggingface/diffusers.git
+transformers
+accelerate
+safetensors
+torch>=2.0.1
+gradio
+spaces
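Note: diffusers is pinned to the GitHub main branch, presumably because the Wan classes imported in app.py (AutoencoderKLWan, WanPipeline) were not yet available in a stable PyPI release at the time of this commit. A quick post-install sanity check:

# Fails with ImportError on diffusers builds that predate Wan support,
# in which case the git+ pin in requirements.txt did not take effect.
import diffusers
from diffusers import AutoencoderKLWan, WanPipeline

print(f"diffusers {diffusers.__version__}: Wan classes available")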