Why is the inference speed of flux.1-schnell-FP8 slower than that of BF16?

#42 by mengjj

Here is my code. Is there anything wrong with it?
import time

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize

def load_model(self, model_path, model_path_fp8, torch_dtype=torch.bfloat16):
    # Load the FP8 checkpoint into the transformer, then quantize its weights to qfloat8 with optimum-quanto.
    transformer = FluxTransformer2DModel.from_single_file(model_path_fp8, config='FLUX.1-schnell/transformer/config.json', torch_dtype=torch_dtype)
    quantize(transformer, weights=qfloat8)
    freeze(transformer)

    # Quantize the T5 text encoder the same way.
    text_encoder_2 = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder_2", torch_dtype=torch_dtype)
    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)

    # Build the pipeline without these two components, then attach the quantized versions.
    self.pipe = FluxPipeline.from_pretrained(model_path, transformer=None, text_encoder_2=None, torch_dtype=torch_dtype)
    self.pipe.transformer = transformer
    self.pipe.text_encoder_2 = text_encoder_2
    self.pipe.to("cuda")

def forward(self, args) -> dict:
    # Start from the stored defaults and override with any CLI arguments that were set.
    config = self.configs_s.copy()
    for key in ['model_path', 'model_path_fp8', 'prompt', 'height', 'width', 'guidance_scale', 'num_inference_steps', 'max_sequence_length', 'save_path']:
        if getattr(args, key, None) is not None:
            config[key] = getattr(args, key)

    # self.load_model(config['model_path'])

    # Time the text-to-image generation.
    t2i_start_time = time.time()
    image = self.pipe(
        prompt=config['prompt'],
        guidance_scale=config.get('guidance_scale', 0.0),
        height=config.get('height', 720),
        width=config.get('width', 1280),
        num_inference_steps=config.get('num_inference_steps', 4),
        max_sequence_length=config.get('max_sequence_length', 256),
        generator=torch.Generator("cpu").manual_seed(1997)
    ).images[0]
    t2i_end_time = time.time()
    print(f"Text2Image time is: {t2i_end_time - t2i_start_time}s")
    save_path = config.get('save_path', 't2i_2.png')
    image.save(save_path)

    # Time the post-processing step separately.
    ps_start_time = time.time()
    self.postprocess(save_path)
    ps_end_time = time.time()
    print(f"Post Process Time is: {ps_end_time - ps_start_time}s")

    return {"status": "done"}  # note: the original {"done"} was a set literal, not the dict the annotation promises
