#!/usr/bin/env python3 import time import argparse import torch from diffusers import FluxPipeline def benchmark_load_lora( base_model: str, lora_source: str, weight_name: str = None, adapter_name: str = None, dtype = torch.bfloat16, runs: int = 3, ): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Benchmarking on device {device}, torch.cuda.device_count()={torch.cuda.device_count()}.") print(f"1/4. Loading base Flux.1-dev model …") t0 = time.time() pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=dtype, use_safetensors=True) base_load_s = time.time() - t0 print(f" Base model loaded in {base_load_s:.3f} s") print("2/4. Moving pipeline to GPU …") t1 = time.time() pipe = pipe.to(device) torch.cuda.synchronize(device) move_s = time.time() - t1 print(f" to('cuda') took {move_s:.3f} s") # Warm‑up LoRA caching (optional) for i in range(runs): print(f"3.{i+1}/4. Running load_lora_weights (run {i+1}/{runs}) …") start = time.time() adapter_name = "lora" pipe.load_lora_weights(lora_source, adapter_name=adapter_name) torch.cuda.synchronize(device) duration = time.time() - start print(f" → run {i+1}: load_lora_weights took {duration:.3f} s") if i < runs - 1: print(" Unloading LoRA …") pipe.unload_lora_weights(reset_to_overwritten_params=True) torch.cuda.synchronize(device) print("All runs complete.") avg = duration # last run print(f"☆ Final run time: {avg:.3f} s") print(f"― average over {runs} runs ≈ {avg:.3f} s") if __name__ == "__main__": parser = argparse.ArgumentParser( description="Benchmark Flux.1‑dev load_lora_weights timing" ) parser.add_argument("--model", default="black-forest-labs/FLUX.1-dev") parser.add_argument("--lora", required=True, help="LoRA adapter repo ID or local folder / file path") parser.add_argument("--runs", type=int, default=3) args = parser.parse_args() benchmark_load_lora( base_model=args.model, lora_source=args.lora, runs=args.runs )