from diffusers import AutoencoderKL, FluxPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel
import torch
import torch._dynamo
import os
from PIL import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, float8_weight_only

# preconfigs
# The allocator flag is read lazily at the first CUDA allocation, so setting
# it here, after `import torch` but before any model loads, still takes effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
torch._dynamo.config.suppress_errors = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
# torch.backends.cudnn.benchmark = True

# globals
Pipeline = FluxPipeline
ckpt_id = "manbeast3b/flux.1-schnell-full1"
ckpt_revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146"


def load_pipeline() -> Pipeline:
    # Optional: load the T5 text encoder and VAE from the same pinned
    # revision. Disabled; the stock pipeline components are used instead.
    # text_encoder_2 = T5EncoderModel.from_pretrained(
    #     ckpt_id,
    #     revision=ckpt_revision,
    #     subfolder="text_encoder_2",
    #     torch_dtype=torch.bfloat16,
    # ).to(memory_format=torch.channels_last)
    # vae = AutoencoderKL.from_pretrained(
    #     ckpt_id,
    #     revision=ckpt_revision,
    #     subfolder="vae",
    #     local_files_only=True,
    #     torch_dtype=torch.bfloat16,
    # ).to(memory_format=torch.channels_last)

    # Load the transformer straight out of the local HF hub cache so the
    # pinned snapshot revision is used without a network round-trip.
    hub_model_dir = os.path.join(
        HF_HUB_CACHE,
        f"models--{ckpt_id.replace('/', '--')}",
        "snapshots",
        ckpt_revision,
        "transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        hub_model_dir,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        # text_encoder_2=text_encoder_2,
        transformer=transformer,
        # vae=vae,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Compile only the transformer (the hot loop of the denoising steps).
    # pipeline.vae = torch.compile(vae)
    pipeline.transformer = torch.compile(
        pipeline.transformer, mode="max-autotune", fullgraph=True
    )
    pipeline.to(memory_format=torch.channels_last)

    # Two warmup passes: the first triggers compilation/autotuning, the
    # second runs the cached kernels so steady-state latency is reached.
    warmup_prompt = (
        "controllable varied focus thai warriors entertainment "
        "blue golden pink soft tough padthai"
    )
    for _ in range(2):
        pipeline(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    # Quantize the VAE weights to float8 after warmup to cut decode memory.
    quantize_(pipeline.vae, float8_weight_only())
    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image.Image:
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
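

# Minimal usage sketch (not part of the serving harness): exercises
# load_pipeline()/infer() end to end. The TextToImageRequest constructor
# arguments are assumptions inferred from the fields infer() reads
# (prompt/height/width); the real model in pipelines.models may differ.
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(  # hypothetical field names, see note above
        prompt="a watercolor fox in a snowy forest",
        height=1024,
        width=1024,
    )
    # Seeded generator on the same device the pipeline runs on.
    image = infer(request, pipe, Generator("cuda").manual_seed(0))
    image.save("sample.png")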