jbilcke-hf HF Staff commited on
Commit
5e47316
·
verified ·
1 Parent(s): 16bf4ef

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +11 -2
handler.py CHANGED
@@ -703,9 +703,18 @@ class EndpointHandler:
703
  first_pass_steps = max(1, int(config.num_inference_steps * 0.7))
704
  second_pass_steps = max(1, config.num_inference_steps - first_pass_steps)
705
 
706
- # Generate valid timesteps for each pass
707
  first_pass["timesteps"] = generate_valid_timesteps(first_pass_steps, ALLOWED_TIMESTEPS)
708
- second_pass["timesteps"] = generate_valid_timesteps(second_pass_steps, ALLOWED_TIMESTEPS)
 
 
 
 
 
 
 
 
 
709
  else:
710
  # Case 3: Use default optimized timesteps
711
  first_pass["timesteps"] = DEFAULT_FIRST_PASS_TIMESTEPS
 
703
  first_pass_steps = max(1, int(config.num_inference_steps * 0.7))
704
  second_pass_steps = max(1, config.num_inference_steps - first_pass_steps)
705
 
706
+ # Generate valid timesteps for first pass (starts at 1.0)
707
  first_pass["timesteps"] = generate_valid_timesteps(first_pass_steps, ALLOWED_TIMESTEPS)
708
+
709
+ # Second pass should start at high noise level (like 0.9094) but NOT 1.0
710
+ # Find a good starting point in the allowed timesteps (around 0.9094)
711
+ start_idx = 5 # This is 0.9094 in ALLOWED_TIMESTEPS
712
+ if second_pass_steps == 1:
713
+ second_pass["timesteps"] = [ALLOWED_TIMESTEPS[start_idx]] # Just 0.9094
714
+ else:
715
+ # Start from 0.9094 and go down, taking second_pass_steps timesteps
716
+ end_idx = min(len(ALLOWED_TIMESTEPS), start_idx + second_pass_steps)
717
+ second_pass["timesteps"] = ALLOWED_TIMESTEPS[start_idx:end_idx]
718
  else:
719
  # Case 3: Use default optimized timesteps
720
  first_pass["timesteps"] = DEFAULT_FIRST_PASS_TIMESTEPS