Update src/pipeline.py
Browse files- src/pipeline.py +2 -2
src/pipeline.py
CHANGED
@@ -44,7 +44,6 @@ ckpt_revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146"
|
|
44 |
def load_pipeline() -> Pipeline:
|
45 |
# model_name = "manbeast3b/flux.1-schnell-full1"
|
46 |
# text_enc_revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146"
|
47 |
-
|
48 |
# text_encoder_2 = T5EncoderModel.from_pretrained(
|
49 |
# model_name,
|
50 |
# revision=text_enc_revision,
|
@@ -83,7 +82,7 @@ def load_pipeline() -> Pipeline:
|
|
83 |
).to("cuda")
|
84 |
# pipeline.vae = torch.compile(vae)
|
85 |
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
|
86 |
-
|
87 |
pipeline.to(memory_format=torch.channels_last)
|
88 |
|
89 |
warmup_ = "controllable varied focus thai warriors entertainment blue golden pink soft tough padthai"
|
@@ -96,6 +95,7 @@ def load_pipeline() -> Pipeline:
|
|
96 |
num_inference_steps=4,
|
97 |
max_sequence_length=256
|
98 |
)
|
|
|
99 |
# pipeline("")
|
100 |
return pipeline
|
101 |
|
|
|
44 |
def load_pipeline() -> Pipeline:
|
45 |
# model_name = "manbeast3b/flux.1-schnell-full1"
|
46 |
# text_enc_revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146"
|
|
|
47 |
# text_encoder_2 = T5EncoderModel.from_pretrained(
|
48 |
# model_name,
|
49 |
# revision=text_enc_revision,
|
|
|
82 |
).to("cuda")
|
83 |
# pipeline.vae = torch.compile(vae)
|
84 |
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
|
85 |
+
|
86 |
pipeline.to(memory_format=torch.channels_last)
|
87 |
|
88 |
warmup_ = "controllable varied focus thai warriors entertainment blue golden pink soft tough padthai"
|
|
|
95 |
num_inference_steps=4,
|
96 |
max_sequence_length=256
|
97 |
)
|
98 |
+
quantize_(pipeline.vae, float8_weight_only())
|
99 |
# pipeline("")
|
100 |
return pipeline
|
101 |
|