alex commited on
Commit
dab4621
·
1 Parent(s): 282214b
Files changed (1) hide show
  1. app.py +12 -22
app.py CHANGED
@@ -32,10 +32,7 @@ try:
32
 
33
  sh(f"pip install {flash_attention_wheel}")
34
  print("Attempting to download and install FlashAttention wheel...")
35
- # sh("pip install flash-attn")
36
- sh("pip install --no-build-isolation transformer_engine-2.5.0+f05f12c9-cp310-cp310-linux_x86_64.whl")
37
 
38
- # tell Python to re-scan site-packages now that the egg-link exists
39
  import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
40
 
41
  flash_attention_installed = True
@@ -54,7 +51,6 @@ try:
54
  sh(f"pip install {te_wheel}")
55
  print("Attempting to download and install Transformer Engine wheel...")
56
 
57
- # tell Python to re-scan site-packages now that the egg-link exists
58
  import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
59
 
60
  except Exception as e:
@@ -123,10 +119,10 @@ def calculate_required_time(steps, max_duration):
123
  warmup_s = 60
124
 
125
  max_duration_duration_mapping = {
126
- 1: 8,
127
- 2: 8,
128
- 3: 12,
129
- 4: 20,
130
  }
131
  each_step_s = max_duration_duration_mapping[max_duration]
132
  duration_s = (each_step_s * steps) + warmup_s
@@ -247,12 +243,6 @@ def run_pipeline(prompt_text, steps, image_paths, audio_file_path, max_duration
247
  filename = f"gen_{uuid.uuid4().hex[:10]}"
248
  width, height = 832, 480
249
 
250
- duration_frame_mapping = {
251
- 1:25,
252
- 2:45,
253
- 3:70,
254
- 4:97
255
- }
256
 
257
  # Run inference
258
  runner.inference_loop(
@@ -265,7 +255,7 @@ def run_pipeline(prompt_text, steps, image_paths, audio_file_path, max_duration
265
  width,
266
  height,
267
  steps,
268
- frames = int(duration_frame_mapping[max_duration]),
269
  tea_cache_l1_thresh = 0.0,
270
  )
271
 
@@ -345,7 +335,7 @@ with gr.Blocks(css=css) as demo:
345
  default_steps = 10
346
  default_max_duration = 3
347
 
348
- max_duration = gr.Slider(minimum=2, maximum=4, value=default_max_duration, step=1, label="Max Duration")
349
  steps_input = gr.Slider(minimum=10, maximum=50, value=default_steps, step=5, label="Diffusion Steps")
350
 
351
 
@@ -376,7 +366,7 @@ with gr.Blocks(css=css) as demo:
376
  gr.Markdown("")
377
  time_required = gr.Markdown(get_required_time_string(default_steps, default_max_duration))
378
  run_btn = gr.Button("🎬 Action", variant="primary")
379
-
380
  gr.Examples(
381
  examples=[
382
 
@@ -386,7 +376,7 @@ with gr.Blocks(css=css) as demo:
386
  10,
387
  ["./examples/naomi.png"],
388
  "./examples/science.wav",
389
- 3,
390
  ],
391
 
392
  [
@@ -394,15 +384,15 @@ with gr.Blocks(css=css) as demo:
394
  10,
395
  ["./examples/art.png"],
396
  "./examples/art.wav",
397
- 2,
398
  ],
399
 
400
  [
401
  "A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead as she grips a blazing torch tightly in her hand. She speaks with intensity.",
402
- 5,
403
  ["./examples/naomi.png"],
404
  "./examples/dream.mp3",
405
- 5,
406
  ],
407
 
408
  [
@@ -410,7 +400,7 @@ with gr.Blocks(css=css) as demo:
410
  40,
411
  ["./examples/amber.png", "./examples/jacket.png"],
412
  "./examples/fictional.wav",
413
- 4,
414
  ],
415
 
416
  ],
 
32
 
33
  sh(f"pip install {flash_attention_wheel}")
34
  print("Attempting to download and install FlashAttention wheel...")
 
 
35
 
 
36
  import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
37
 
38
  flash_attention_installed = True
 
51
  sh(f"pip install {te_wheel}")
52
  print("Attempting to download and install Transformer Engine wheel...")
53
 
 
54
  import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
55
 
56
  except Exception as e:
 
119
  warmup_s = 60
120
 
121
  max_duration_duration_mapping = {
122
+ 20: 8,
123
+ 45: 8,
124
+ 70: 12,
125
+ 95: 20,
126
  }
127
  each_step_s = max_duration_duration_mapping[max_duration]
128
  duration_s = (each_step_s * steps) + warmup_s
 
243
  filename = f"gen_{uuid.uuid4().hex[:10]}"
244
  width, height = 832, 480
245
 
 
 
 
 
 
 
246
 
247
  # Run inference
248
  runner.inference_loop(
 
255
  width,
256
  height,
257
  steps,
258
+ frames = int(max_duration),
259
  tea_cache_l1_thresh = 0.0,
260
  )
261
 
 
335
  default_steps = 10
336
  default_max_duration = 3
337
 
338
+ max_duration = gr.Slider(minimum=45, maximum=95, value=default_max_duration, step=25, label="Frames")
339
  steps_input = gr.Slider(minimum=10, maximum=50, value=default_steps, step=5, label="Diffusion Steps")
340
 
341
 
 
366
  gr.Markdown("")
367
  time_required = gr.Markdown(get_required_time_string(default_steps, default_max_duration))
368
  run_btn = gr.Button("🎬 Action", variant="primary")
369
+
370
  gr.Examples(
371
  examples=[
372
 
 
376
  10,
377
  ["./examples/naomi.png"],
378
  "./examples/science.wav",
379
+ 70,
380
  ],
381
 
382
  [
 
384
  10,
385
  ["./examples/art.png"],
386
  "./examples/art.wav",
387
+ 45,
388
  ],
389
 
390
  [
391
  "A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead as she grips a blazing torch tightly in her hand. She speaks with intensity.",
392
+ 10,
393
  ["./examples/naomi.png"],
394
  "./examples/dream.mp3",
395
+ 95,
396
  ],
397
 
398
  [
 
400
  40,
401
  ["./examples/amber.png", "./examples/jacket.png"],
402
  "./examples/fictional.wav",
403
+ 70,
404
  ],
405
 
406
  ],