Upload handler.py
Browse files- handler.py +36 -17
handler.py
CHANGED
@@ -45,22 +45,26 @@ ALLOWED_TIMESTEPS = [1.0, 0.9937, 0.9875, 0.9812, 0.975, 0.9094, 0.725, 0.4219]
|
|
45 |
# Check environment variable for pipeline support
|
46 |
support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
|
47 |
|
48 |
-
def generate_valid_timesteps(num_steps: int, allowed_timesteps: List[float]) -> List[float]:
|
49 |
"""Generate valid timesteps by selecting from the allowed timesteps list"""
|
50 |
if num_steps >= len(allowed_timesteps):
|
51 |
return allowed_timesteps
|
52 |
|
53 |
if num_steps == 1:
|
54 |
-
# For single step, use the highest timestep (most noisy)
|
55 |
-
return [allowed_timesteps[0]]
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
64 |
|
65 |
@dataclass
|
66 |
class GenerationConfig:
|
@@ -680,13 +684,28 @@ class EndpointHandler:
|
|
680 |
second_pass["timesteps"] = config.second_pass_timesteps or DEFAULT_SECOND_PASS_TIMESTEPS
|
681 |
elif config.num_inference_steps != 8:
|
682 |
# Case 2: Use num_inference_steps with valid timesteps only
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
690 |
else:
|
691 |
# Case 3: Use default optimized timesteps
|
692 |
first_pass["timesteps"] = DEFAULT_FIRST_PASS_TIMESTEPS
|
|
|
45 |
# Check environment variable for pipeline support
|
46 |
support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
|
47 |
|
48 |
+
def generate_valid_timesteps(num_steps: int, allowed_timesteps: List[float], start_high: bool = True) -> List[float]:
|
49 |
"""Generate valid timesteps by selecting from the allowed timesteps list"""
|
50 |
if num_steps >= len(allowed_timesteps):
|
51 |
return allowed_timesteps
|
52 |
|
53 |
if num_steps == 1:
|
54 |
+
# For single step, use the highest timestep (most noisy) if start_high is True
|
55 |
+
return [allowed_timesteps[0] if start_high else allowed_timesteps[-1]]
|
56 |
|
57 |
+
if start_high:
|
58 |
+
# Select evenly spaced timesteps from the allowed list, starting from highest
|
59 |
+
indices = []
|
60 |
+
for i in range(num_steps):
|
61 |
+
idx = int(i * (len(allowed_timesteps) - 1) / (num_steps - 1))
|
62 |
+
indices.append(idx)
|
63 |
+
return [allowed_timesteps[i] for i in indices]
|
64 |
+
else:
|
65 |
+
# For cases where we don't need to start high (like second pass continuation)
|
66 |
+
# Take the last num_steps timesteps
|
67 |
+
return allowed_timesteps[-num_steps:]
|
68 |
|
69 |
@dataclass
|
70 |
class GenerationConfig:
|
|
|
684 |
second_pass["timesteps"] = config.second_pass_timesteps or DEFAULT_SECOND_PASS_TIMESTEPS
|
685 |
elif config.num_inference_steps != 8:
|
686 |
# Case 2: Use num_inference_steps with valid timesteps only
|
687 |
+
if config.num_inference_steps <= 4:
|
688 |
+
# For very few steps, use a simple split with key timesteps
|
689 |
+
if config.num_inference_steps == 1:
|
690 |
+
first_pass["timesteps"] = [1.0]
|
691 |
+
second_pass["timesteps"] = [0.4219]
|
692 |
+
elif config.num_inference_steps == 2:
|
693 |
+
first_pass["timesteps"] = [1.0]
|
694 |
+
second_pass["timesteps"] = [0.9094, 0.4219]
|
695 |
+
elif config.num_inference_steps == 3:
|
696 |
+
first_pass["timesteps"] = [1.0, 0.9094]
|
697 |
+
second_pass["timesteps"] = [0.9094, 0.4219]
|
698 |
+
else: # 4 steps
|
699 |
+
first_pass["timesteps"] = [1.0, 0.975, 0.9094]
|
700 |
+
second_pass["timesteps"] = [0.9094, 0.725, 0.4219]
|
701 |
+
else:
|
702 |
+
# For more steps, split them properly
|
703 |
+
first_pass_steps = max(1, int(config.num_inference_steps * 0.7))
|
704 |
+
second_pass_steps = max(1, config.num_inference_steps - first_pass_steps)
|
705 |
+
|
706 |
+
# Generate valid timesteps for each pass
|
707 |
+
first_pass["timesteps"] = generate_valid_timesteps(first_pass_steps, ALLOWED_TIMESTEPS)
|
708 |
+
second_pass["timesteps"] = generate_valid_timesteps(second_pass_steps, ALLOWED_TIMESTEPS)
|
709 |
else:
|
710 |
# Case 3: Use default optimized timesteps
|
711 |
first_pass["timesteps"] = DEFAULT_FIRST_PASS_TIMESTEPS
|