cbensimon (HF Staff) committed
Commit 4e94151 · 1 Parent(s): 39286c5
Files changed (3)
  1. aoti.py +19 -0
  2. app.py +3 -2
  3. optimization.py +0 -67
aoti.py ADDED
@@ -0,0 +1,19 @@
+ """
+ """
+
+ import torch
+ from huggingface_hub import hf_hub_download
+ from spaces.zero.torch.aoti import ZeroGPUCompiledModel
+ from spaces.zero.torch.aoti import ZeroGPUWeights
+
+ import fa3
+
+
+ def aoti_load(module: torch.nn.Module, repo_id: str):
+     repeated_blocks = module._repeated_blocks
+     aoti_files = {name: hf_hub_download(repo_id, f'{name}.pt2') for name in repeated_blocks}
+     for block_name, aoti_file in aoti_files.items():
+         for block in module.modules():
+             if block.__class__.__name__ == block_name:
+                 weights = ZeroGPUWeights(block.state_dict())
+                 block.forward = ZeroGPUCompiledModel(aoti_file, weights)
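
For context: aoti_load downloads one precompiled AOTInductor archive per repeated block class and rebinds each matching block's forward to it, pairing the shared compiled graph with that block's own weights. A minimal sketch of the mapping it builds, assuming _repeated_blocks on the Flux transformer names its two block classes (the values shown are assumptions, not from the commit):

    # Illustrative sketch only, not part of the commit.
    from huggingface_hub import hf_hub_download

    # Assumed contents of module._repeated_blocks for a Flux transformer
    repeated_blocks = ['FluxTransformerBlock', 'FluxSingleTransformerBlock']

    # One '<ClassName>.pt2' AOTInductor archive per block class in the repo
    aoti_files = {name: hf_hub_download('zerogpu-aoti/FLUX.1', f'{name}.pt2')
                  for name in repeated_blocks}

All instances of a block class share the same .pt2 archive; only the weights differ per instance, which is why the loop rebuilds ZeroGPUWeights(block.state_dict()) for every match. The `import fa3` line presumably makes the FlashAttention-3 attention processor module importable when the compiled blocks run.
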
app.py CHANGED
@@ -5,11 +5,12 @@ import spaces
  import torch
  from diffusers import FluxPipeline
 
- from optimization import optimize_pipeline_
+ from aoti import aoti_load
 
 
  pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', torch_dtype=torch.bfloat16).to('cuda')
- optimize_pipeline_(pipeline, "prompt")
+ pipeline.transformer.fuse_qkv_projections()
+ aoti_load(pipeline.transformer, 'zerogpu-aoti/FLUX.1')
 
 
  @spaces.GPU
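
The app change removes startup compilation entirely: fuse_qkv_projections() is still called (it must match how the archives were exported) and aoti_load then attaches the prebuilt kernels. A sketch of the resulting app.py, assuming a generate entrypoint behind the @spaces.GPU decorator shown in context (the function body is illustrative, not part of the diff):

    # Sketch of the resulting app.py; generate() is an assumed entrypoint.
    import spaces
    import torch
    from diffusers import FluxPipeline

    from aoti import aoti_load

    pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev',
                                            torch_dtype=torch.bfloat16).to('cuda')
    pipeline.transformer.fuse_qkv_projections()  # must match the compiled archives
    aoti_load(pipeline.transformer, 'zerogpu-aoti/FLUX.1')

    @spaces.GPU
    def generate(prompt: str):
        # Runs on a ZeroGPU worker with the AOT-compiled blocks in place
        return pipeline(prompt).images[0]
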
optimization.py DELETED
@@ -1,67 +0,0 @@
- """
- """
-
- from typing import Any
- from typing import Callable
- from typing import ParamSpec
-
- import spaces
- import torch
- from spaces.zero.torch.aoti import ZeroGPUCompiledModel
- from spaces.zero.torch.aoti import ZeroGPUWeights
-
- from fa3 import FlashFusedFluxAttnProcessor3_0
-
-
- P = ParamSpec('P')
-
-
- INDUCTOR_CONFIGS = {
-     'conv_1x1_as_mm': True,
-     'epilogue_fusion': False,
-     'coordinate_descent_tuning': True,
-     'coordinate_descent_check_all_directions': True,
-     'max_autotune': True,
-     'triton.cudagraphs': True,
- }
-
-
- def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
-
-     blocks_A = pipeline.transformer.transformer_blocks
-     blocks_B = pipeline.transformer.single_transformer_blocks
-
-     @spaces.GPU(duration=1500)
-     def compile_transformer_block_AB():
-
-         with spaces.aoti_capture(blocks_A[0]) as call_A:
-             pipeline(*args, **kwargs)
-
-         with spaces.aoti_capture(blocks_B[0]) as call_B:
-             pipeline(*args, **kwargs)
-
-         exported_A = torch.export.export(
-             mod=blocks_A[0],
-             args=call_A.args,
-             kwargs=call_A.kwargs,
-         )
-
-         exported_B = torch.export.export(
-             mod=blocks_B[0],
-             args=call_B.args,
-             kwargs=call_B.kwargs,
-         )
-
-         return (
-             spaces.aoti_compile(exported_A, INDUCTOR_CONFIGS).archive_file,
-             spaces.aoti_compile(exported_B, INDUCTOR_CONFIGS).archive_file,
-         )
-
-     pipeline.transformer.fuse_qkv_projections()
-     pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
-
-     archive_file_A, archive_file_B = compile_transformer_block_AB()
-     for blocks, archive_file in zip((blocks_A, blocks_B), (archive_file_A, archive_file_B)):
-         for block in blocks:
-             weights = ZeroGPUWeights(block.state_dict())
-             block.forward = ZeroGPUCompiledModel(archive_file, weights)
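
The deleted module did all of this work at Space startup: it captured real calls to the first block of each type, exported them with torch.export.export, AOT-compiled them under a @spaces.GPU(duration=1500) budget, and patched every block's forward. The new flow assumes the same archives were produced once, offline, and published to zerogpu-aoti/FLUX.1. A hypothetical one-off export script along those lines, reusing the deleted file's capture/export/compile steps and INDUCTOR_CONFIGS (the upload step is an assumption; no such script appears in this commit):

    # Hypothetical export script, not part of the commit.
    import spaces
    import torch
    from diffusers import FluxPipeline
    from huggingface_hub import upload_file

    from fa3 import FlashFusedFluxAttnProcessor3_0

    INDUCTOR_CONFIGS = {  # copied from the deleted optimization.py
        'conv_1x1_as_mm': True,
        'epilogue_fusion': False,
        'coordinate_descent_tuning': True,
        'coordinate_descent_check_all_directions': True,
        'max_autotune': True,
        'triton.cudagraphs': True,
    }

    pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev',
                                            torch_dtype=torch.bfloat16).to('cuda')
    pipeline.transformer.fuse_qkv_projections()
    pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

    def export_block(block, *args, **kwargs):
        # Capture one real call to the block, export it, then AOT-compile it
        with spaces.aoti_capture(block) as call:
            pipeline(*args, **kwargs)
        exported = torch.export.export(mod=block, args=call.args, kwargs=call.kwargs)
        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS).archive_file

    for block in (pipeline.transformer.transformer_blocks[0],
                  pipeline.transformer.single_transformer_blocks[0]):
        archive = export_block(block, 'prompt')
        # Assumes archive_file is a local path/file object that upload_file accepts;
        # the '<ClassName>.pt2' naming is what aoti_load expects to download.
        upload_file(path_or_fileobj=archive,
                    path_in_repo=f'{block.__class__.__name__}.pt2',
                    repo_id='zerogpu-aoti/FLUX.1')

Net effect: the up-to-1500-second compilation budget is paid once offline instead of on every cold start of the Space.
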