jbilcke-hf HF Staff commited on
Commit
16bf4ef
·
verified ·
1 Parent(s): 43c6799

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +36 -17
handler.py CHANGED
@@ -45,22 +45,26 @@ ALLOWED_TIMESTEPS = [1.0, 0.9937, 0.9875, 0.9812, 0.975, 0.9094, 0.725, 0.4219]
45
  # Check environment variable for pipeline support
46
  support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
47
 
48
- def generate_valid_timesteps(num_steps: int, allowed_timesteps: List[float]) -> List[float]:
49
  """Generate valid timesteps by selecting from the allowed timesteps list"""
50
  if num_steps >= len(allowed_timesteps):
51
  return allowed_timesteps
52
 
53
  if num_steps == 1:
54
- # For single step, use the highest timestep (most noisy)
55
- return [allowed_timesteps[0]]
56
 
57
- # Select evenly spaced timesteps from the allowed list
58
- indices = []
59
- for i in range(num_steps):
60
- idx = int(i * (len(allowed_timesteps) - 1) / (num_steps - 1))
61
- indices.append(idx)
62
-
63
- return [allowed_timesteps[i] for i in indices]
 
 
 
 
64
 
65
  @dataclass
66
  class GenerationConfig:
@@ -680,13 +684,28 @@ class EndpointHandler:
680
  second_pass["timesteps"] = config.second_pass_timesteps or DEFAULT_SECOND_PASS_TIMESTEPS
681
  elif config.num_inference_steps != 8:
682
  # Case 2: Use num_inference_steps with valid timesteps only
683
- # Split steps between passes (70% first pass, 30% second pass)
684
- first_pass_steps = max(1, int(config.num_inference_steps * 0.7))
685
- second_pass_steps = max(1, config.num_inference_steps - first_pass_steps)
686
-
687
- # Generate valid timesteps for each pass
688
- first_pass["timesteps"] = generate_valid_timesteps(first_pass_steps, ALLOWED_TIMESTEPS)
689
- second_pass["timesteps"] = generate_valid_timesteps(second_pass_steps, ALLOWED_TIMESTEPS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
690
  else:
691
  # Case 3: Use default optimized timesteps
692
  first_pass["timesteps"] = DEFAULT_FIRST_PASS_TIMESTEPS
 
45
  # Check environment variable for pipeline support
46
  support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
47
 
48
+ def generate_valid_timesteps(num_steps: int, allowed_timesteps: List[float], start_high: bool = True) -> List[float]:
49
  """Generate valid timesteps by selecting from the allowed timesteps list"""
50
  if num_steps >= len(allowed_timesteps):
51
  return allowed_timesteps
52
 
53
  if num_steps == 1:
54
+ # For single step, use the highest timestep (most noisy) if start_high is True
55
+ return [allowed_timesteps[0] if start_high else allowed_timesteps[-1]]
56
 
57
+ if start_high:
58
+ # Select evenly spaced timesteps from the allowed list, starting from highest
59
+ indices = []
60
+ for i in range(num_steps):
61
+ idx = int(i * (len(allowed_timesteps) - 1) / (num_steps - 1))
62
+ indices.append(idx)
63
+ return [allowed_timesteps[i] for i in indices]
64
+ else:
65
+ # For cases where we don't need to start high (like second pass continuation)
66
+ # Take the last num_steps timesteps
67
+ return allowed_timesteps[-num_steps:]
68
 
69
  @dataclass
70
  class GenerationConfig:
 
684
  second_pass["timesteps"] = config.second_pass_timesteps or DEFAULT_SECOND_PASS_TIMESTEPS
685
  elif config.num_inference_steps != 8:
686
  # Case 2: Use num_inference_steps with valid timesteps only
687
+ if config.num_inference_steps <= 4:
688
+ # For very few steps, use a simple split with key timesteps
689
+ if config.num_inference_steps == 1:
690
+ first_pass["timesteps"] = [1.0]
691
+ second_pass["timesteps"] = [0.4219]
692
+ elif config.num_inference_steps == 2:
693
+ first_pass["timesteps"] = [1.0]
694
+ second_pass["timesteps"] = [0.9094, 0.4219]
695
+ elif config.num_inference_steps == 3:
696
+ first_pass["timesteps"] = [1.0, 0.9094]
697
+ second_pass["timesteps"] = [0.9094, 0.4219]
698
+ else: # 4 steps
699
+ first_pass["timesteps"] = [1.0, 0.975, 0.9094]
700
+ second_pass["timesteps"] = [0.9094, 0.725, 0.4219]
701
+ else:
702
+ # For more steps, split them properly
703
+ first_pass_steps = max(1, int(config.num_inference_steps * 0.7))
704
+ second_pass_steps = max(1, config.num_inference_steps - first_pass_steps)
705
+
706
+ # Generate valid timesteps for each pass
707
+ first_pass["timesteps"] = generate_valid_timesteps(first_pass_steps, ALLOWED_TIMESTEPS)
708
+ second_pass["timesteps"] = generate_valid_timesteps(second_pass_steps, ALLOWED_TIMESTEPS)
709
  else:
710
  # Case 3: Use default optimized timesteps
711
  first_pass["timesteps"] = DEFAULT_FIRST_PASS_TIMESTEPS