Update src/pipeline.py
Browse files- src/pipeline.py +1 -0
src/pipeline.py
CHANGED
@@ -84,6 +84,7 @@ def load_pipeline() -> Pipeline:
|
|
84 |
# pipeline.vae = torch.compile(vae)
|
85 |
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
|
86 |
quantize_(pipeline.vae, float8_weight_only())
|
|
|
87 |
|
88 |
warmup_ = "controllable varied focus thai warriors entertainment blue golden pink soft tough padthai"
|
89 |
for _ in range(2):
|
|
|
84 |
# pipeline.vae = torch.compile(vae)
|
85 |
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
|
86 |
quantize_(pipeline.vae, float8_weight_only())
|
87 |
+
pipeline.to(memory_format=torch.channels_last)
|
88 |
|
89 |
warmup_ = "controllable varied focus thai warriors entertainment blue golden pink soft tough padthai"
|
90 |
for _ in range(2):
|