markury committed on
Commit 10b0bca · 1 Parent(s): 08b4ec0

fix(wip): second pass

Files changed (1)
  1. app.py +61 -38
app.py CHANGED
@@ -7,6 +7,7 @@ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 import os
 import tempfile
+from typing import List, Union, Optional

 # Define model options
 MODEL_OPTIONS = {
@@ -43,7 +44,7 @@ def generate_video(
     second_pass_flow_shift,
     second_pass_cfg,
     show_both_outputs
-):
+) -> Union[str, List[str]]:
     # Get model ID from selection
     model_id = MODEL_OPTIONS[model_choice]

@@ -98,23 +99,16 @@ def generate_video(
         num_frames=num_frames,
         guidance_scale=guidance_scale,
         num_inference_steps=num_inference_steps,
-        output_type="latent" if enable_second_pass else "pt",  # Only return latents if doing second pass
+        # For Wan, we may need to approach this differently for the latents
+        output_type="pt",  # Always get PyTorch tensors for the first pass
         return_dict=True
     )

-    # Get the latents from the first pass output
-    latents = first_pass.frames[0]
+    # Get the frames or latents from the first pass output
+    first_pass_frames = first_pass.frames[0]

-    # If we're not doing a second pass or need to display both outputs, decode the first pass
+    # Output the first pass video if needed
     if not enable_second_pass or (enable_second_pass and show_both_outputs):
-        # Decode the latents to frames with the VAE (only needed if we requested latents)
-        if enable_second_pass:
-            print("Decoding first pass latents...")
-            with torch.no_grad():
-                first_pass_frames = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
-        else:
-            first_pass_frames = latents
-
         # Export first pass to video
         first_pass_file = "output_first_pass.mp4"
         export_to_video(first_pass_frames, first_pass_file, fps=output_fps)
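Note on this hunk: with `output_type="pt"` the pipeline returns decoded pixel frames rather than latents, and diffusers' `export_to_video` helper generally expects a list of HxWx3 numpy arrays (or PIL images) in [0, 1], not a raw tensor. A minimal, hedged conversion sketch, assuming the frames come back as a 4-D float tensor in [0, 1] (the exact layout depends on the pipeline):

```python
import numpy as np
import torch

def tensor_to_frame_list(frames: torch.Tensor) -> list:
    """Convert a (F, C, H, W) or (F, H, W, C) float tensor in [0, 1]
    to the list of H x W x 3 arrays that export_to_video accepts."""
    if frames.ndim == 4 and frames.shape[1] in (1, 3):
        frames = frames.permute(0, 2, 3, 1)  # channels-first -> channels-last
    return [f.float().clamp(0, 1).cpu().numpy() for f in frames]
```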
@@ -125,6 +119,14 @@ def generate_video(
     if enable_second_pass:
         print("Running second pass with scale factor:", second_pass_scale)

+        # For second pass, we need to first encode the frames to get latents
+        print("Encoding first pass frames to latents...")
+        with torch.no_grad():
+            # Move frames to the same device as the VAE
+            first_pass_frames = first_pass_frames.to(pipe.vae.device)
+            # Encode to get latents
+            latents = pipe.vae.encode(first_pass_frames).latent_dist.sample()
+
         # Resize latents for second pass (upscale)
         new_height = int(height * second_pass_scale)
         new_width = int(width * second_pass_scale)
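One consistency point worth flagging: the decode path removed in the earlier hunk divided latents by `pipe.vae.config.scaling_factor`, so a symmetric encode would multiply by it; the new encode call samples the latent distribution without that normalization. Also, video VAEs typically expect 5-D `(batch, channels, frames, height, width)` input. A hedged sketch of the usual diffusers-style round trip, assuming `pipe.vae` exposes `encode()`/`decode()` and `config.scaling_factor`:

```python
import torch

with torch.no_grad():
    # Encode pixels -> latents, then normalize into the model's latent scale
    posterior = pipe.vae.encode(first_pass_frames)
    latents = posterior.latent_dist.sample() * pipe.vae.config.scaling_factor

    # ... run the second denoising pass on `latents` here ...

    # Decode back, undoing the same normalization
    frames = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
```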
@@ -135,10 +137,18 @@ def generate_video(

         print(f"Upscaling latents from {height}x{width} to {new_height}x{new_width}")

+        # Get latent dimensions
+        latent_height = latents.shape[2]  # Should be height//8
+        latent_width = latents.shape[3]  # Should be width//8
+
+        # Calculate new latent dimensions
+        new_latent_height = new_height // 8
+        new_latent_width = new_width // 8
+
         # Upscale latents using interpolate
         upscaled_latents = torch.nn.functional.interpolate(
             latents,
-            size=(num_frames, new_height // 8, new_width // 8),  # VAE downsamples by factor of 8
+            size=(num_frames, new_latent_height, new_latent_width),
             mode="trilinear",
             align_corners=False
         )
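Two assumptions are baked into this interpolate call: `mode="trilinear"` requires a 5-D `(N, C, D, H, W)` input, and the target depth is pinned to `num_frames`, which only holds if the VAE does no temporal compression (several video VAEs compress time as well as space). A hedged sketch that derives the target size from the latent tensor itself instead:

```python
import torch
import torch.nn.functional as F

# `latents` is assumed to be 5-D: (batch, channels, depth, height, width)
b, c, d, h, w = latents.shape
upscaled_latents = F.interpolate(
    latents,
    # keep the latent depth as-is; scale only the spatial dims
    size=(d, int(h * second_pass_scale), int(w * second_pass_scale)),
    mode="trilinear",
    align_corners=False,
)
```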
@@ -183,15 +193,18 @@ def generate_video(
         output_files.append(second_pass_file)

     # Return the appropriate video output(s)
-    if enable_second_pass and not show_both_outputs:
-        return second_pass_file
-    elif enable_second_pass and show_both_outputs:
-        return [first_pass_file, second_pass_file]
+    if enable_second_pass and show_both_outputs and len(output_files) > 1:
+        return output_files
+    elif len(output_files) > 0:
+        return output_files[-1]  # Return the last generated output (either first or second pass)
     else:
-        return first_pass_file
+        return "No video was generated. Please check the logs for errors."

-# Create the Gradio interface
+# Create the Gradio interface
 with gr.Blocks() as demo:
+    # Import gr.update for visibility control
+    from gradio import update
+
     gr.HTML("""
     <p align="center">
     <svg version="1.1" viewBox="0 0 1200 295" xmlns="http://www.w3.org/2000/svg" xmlns:v="https://vecta.io/nano" width="400">
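A small UI hazard in the new fallback branch: `generate_video` is wired to `gr.Video` outputs, and a plain message string returned there is treated as a file path rather than shown to the user. A hedged alternative with the same structure, using `gr.Error` (which Gradio surfaces in the UI):

```python
import gradio as gr

# Hypothetical rewrite of the return block; names match the commit's code.
if not output_files:
    raise gr.Error("No video was generated. Please check the logs for errors.")
if enable_second_pass and show_both_outputs and len(output_files) > 1:
    return output_files
return output_files[-1]
```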
@@ -364,34 +377,33 @@ with gr.Blocks() as demo:
     output_video = gr.Video(label="Generated Video")
     second_output_video = gr.Video(label="Second Pass Video", visible=False)

-    # Show/hide second video based on checkbox
-    def update_second_video_visibility(enable_pass, show_both):
-        return {"visible": enable_pass and show_both}
+    # Control visibility through the UI changes directly
+    def toggle_second_video(enable_pass, show_both):
+        return gr.update(visible=enable_pass and show_both)

+    # Update visibility when checkboxes change
     enable_second_pass.change(
-        fn=update_second_video_visibility,
+        fn=toggle_second_video,
         inputs=[enable_second_pass, show_both_outputs],
         outputs=[second_output_video]
     )

     show_both_outputs.change(
-        fn=update_second_video_visibility,
+        fn=toggle_second_video,
         inputs=[enable_second_pass, show_both_outputs],
         outputs=[second_output_video]
     )

-    # Updated function to handle the second pass and multiple outputs
-    def process_generation(*args):
-        result = generate_video(*args)
-        if isinstance(result, list) and len(result) > 1:
-            return [result[0], result[1], {"visible": True}]
-        elif isinstance(result, list) and len(result) == 1:
-            return [result[0], None, {"visible": False}]
+    # Define a visibility update function separately
+    def update_second_video_visibility(enable_pass, show_both):
+        if enable_pass and show_both:
+            return gr.update(visible=True)
         else:
-            return [result, None, {"visible": False}]
-
+            return gr.update(visible=False)
+
+    # Process generation without trying to update visibility in the same function
     generate_btn.click(
-        fn=process_generation,
+        fn=generate_video,
         inputs=[
             model_choice,
             prompt,
@@ -416,12 +428,23 @@ with gr.Blocks() as demo:
             show_both_outputs
         ],
         outputs=[
-            output_video,
-            second_output_video,
-            second_output_video  # Update visibility
+            output_video if not show_both_outputs else [output_video, second_output_video]
         ]
     )

+    # Update visibility when options change
+    enable_second_pass.change(
+        fn=update_second_video_visibility,
+        inputs=[enable_second_pass, show_both_outputs],
+        outputs=[second_output_video]
+    )
+
+    show_both_outputs.change(
+        fn=update_second_video_visibility,
+        inputs=[enable_second_pass, show_both_outputs],
+        outputs=[second_output_video]
+    )
+
     gr.Markdown("""
     ## Tips for best results:
     - For smaller resolution videos, try lower values of flow shift (2.0-5.0)
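Two things in this last hunk look like leftover WIP rather than intent: the `outputs=` ternary tests the `show_both_outputs` Checkbox component itself (the expression is evaluated once when the UI is built, against the component object rather than its value, and Gradio expects components, not nested lists, in `outputs`), and `update_second_video_visibility` is now registered on the same `.change` events that `toggle_second_video` already handles, so each checkbox flip fires two redundant updates. A hedged sketch of the conventional pattern, keeping the commit's names; `generation_inputs` is a hypothetical stand-in for the same list of input components wired up in the commit:

```python
# Hypothetical wrapper: pads generate_video's string-or-list result
# to a fixed two-slot output list.
def run_and_split(*args):
    result = generate_video(*args)
    if isinstance(result, list) and len(result) > 1:
        return result[0], result[1]
    return (result[0] if isinstance(result, list) else result), None

generate_btn.click(
    fn=run_and_split,
    inputs=generation_inputs,  # hypothetical: the commit's input list
    outputs=[output_video, second_output_video],
)
```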
 