varad-simpli commited on
Commit
9be23cd
·
verified ·
1 Parent(s): 463cf7e

Upload benchmark_load_lora.py

Browse files
Files changed (1) hide show
  1. benchmark_load_lora.py +64 -0
benchmark_load_lora.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import time
3
+ import argparse
4
+ import torch
5
+ from diffusers import FluxPipeline
6
+
7
+ def benchmark_load_lora(
8
+ base_model: str,
9
+ lora_source: str,
10
+ weight_name: str = None,
11
+ adapter_name: str = None,
12
+ dtype = torch.bfloat16,
13
+ runs: int = 3,
14
+ ):
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ print(f"Benchmarking on device {device}, torch.cuda.device_count()={torch.cuda.device_count()}.")
17
+
18
+ print(f"1/4. Loading base Flux.1-dev model …")
19
+ t0 = time.time()
20
+ pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=dtype, use_safetensors=True)
21
+ base_load_s = time.time() - t0
22
+ print(f" Base model loaded in {base_load_s:.3f} s")
23
+
24
+ print("2/4. Moving pipeline to GPU …")
25
+ t1 = time.time()
26
+ pipe = pipe.to(device)
27
+ torch.cuda.synchronize(device)
28
+ move_s = time.time() - t1
29
+ print(f" to('cuda') took {move_s:.3f} s")
30
+
31
+ # Warm‑up LoRA caching (optional)
32
+ for i in range(runs):
33
+ print(f"3.{i+1}/4. Running load_lora_weights (run {i+1}/{runs}) …")
34
+ start = time.time()
35
+ adapter_name = "lora"
36
+ pipe.load_lora_weights(lora_source, adapter_name=adapter_name)
37
+ torch.cuda.synchronize(device)
38
+ duration = time.time() - start
39
+ print(f" → run {i+1}: load_lora_weights took {duration:.3f} s")
40
+
41
+ if i < runs - 1:
42
+ print(" Unloading LoRA …")
43
+ pipe.unload_lora_weights(reset_to_overwritten_params=True)
44
+ torch.cuda.synchronize(device)
45
+
46
+ print("All runs complete.")
47
+ avg = duration # last run
48
+ print(f"☆ Final run time: {avg:.3f} s")
49
+ print(f"― average over {runs} runs ≈ {avg:.3f} s")
50
+
51
+ if __name__ == "__main__":
52
+ parser = argparse.ArgumentParser(
53
+ description="Benchmark Flux.1‑dev load_lora_weights timing"
54
+ )
55
+ parser.add_argument("--model", default="black-forest-labs/FLUX.1-dev")
56
+ parser.add_argument("--lora", required=True, help="LoRA adapter repo ID or local folder / file path")
57
+ parser.add_argument("--runs", type=int, default=3)
58
+ args = parser.parse_args()
59
+
60
+ benchmark_load_lora(
61
+ base_model=args.model,
62
+ lora_source=args.lora,
63
+ runs=args.runs
64
+ )