Fabrice-TIERCELIN committed on
Commit 3cb3fb8 · verified · 1 Parent(s): 0575a63

Create utils/lora_utils.py

Files changed (1)
  1. utils/lora_utils.py +234 -0
utils/lora_utils.py ADDED
@@ -0,0 +1,234 @@
+ import os
+ import torch
+ from safetensors.torch import load_file
+ from tqdm import tqdm
+
+
+ def merge_lora_to_state_dict(
+     state_dict: dict[str, torch.Tensor], lora_file: str, multiplier: float, device: torch.device
+ ) -> dict[str, torch.Tensor]:
+     """
+     Merge LoRA weights into the state dict of a model, dispatching on the detected LoRA file format.
+     """
+     lora_sd = load_file(lora_file)
+
+     # Check the format of the LoRA file
+     keys = list(lora_sd.keys())
+     if keys[0].startswith("lora_unet_"):
+         print("Musubi Tuner LoRA detected")
+         return merge_musubi_tuner(lora_sd, state_dict, multiplier, device)
+
+     transformer_prefixes = ["diffusion_model", "transformer"]  # to ignore Text Encoder modules
+     lora_suffix = None
+     prefix = None
+     for key in keys:
+         if lora_suffix is None and "lora_A" in key:
+             lora_suffix = "lora_A"
+         if prefix is None:
+             pfx = key.split(".")[0]
+             if pfx in transformer_prefixes:
+                 prefix = pfx
+         if lora_suffix is not None and prefix is not None:
+             break
+
+     if lora_suffix == "lora_A" and prefix is not None:
+         print("Diffusion-pipe (?) LoRA detected")
+         return merge_diffusion_pipe_or_something(lora_sd, state_dict, "lora_unet_", multiplier, device)
+
+     print(f"LoRA file format not recognized: {os.path.basename(lora_file)}")
+     return state_dict
+
+
+ def merge_diffusion_pipe_or_something(
+     lora_sd: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor], prefix: str, multiplier: float, device: torch.device
+ ) -> dict[str, torch.Tensor]:
+     """
+     Convert Diffusion-pipe (Diffusers-style) LoRA weights to the Musubi Tuner format, then merge them.
+     Copied from the Musubi Tuner repo.
+     """
+     # convert from diffusers(?) to default LoRA
+     # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
+     # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
+
+     # note: Diffusers has no alpha, so alpha is set to rank
+     new_weights_sd = {}
+     lora_dims = {}
+     for key, weight in lora_sd.items():
+         diffusers_prefix, key_body = key.split(".", 1)
+         if diffusers_prefix not in ("diffusion_model", "transformer"):
+             print(f"unexpected key: {key} in diffusers format")
+             continue
+
+         new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
+         new_weights_sd[new_key] = weight
+
+         lora_name = new_key.split(".")[0]  # before first dot
+         if lora_name not in lora_dims and "lora_down" in new_key:
+             lora_dims[lora_name] = weight.shape[0]
+
+     # add alpha with rank
+     for lora_name, dim in lora_dims.items():
+         new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
+
+     return merge_musubi_tuner(new_weights_sd, state_dict, multiplier, device)
+
+
+ def merge_musubi_tuner(
+     lora_sd: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor], multiplier: float, device: torch.device
+ ) -> dict[str, torch.Tensor]:
+     """
+     Merge Musubi Tuner-format LoRA weights into the state dict of a model.
+     """
+     # Check whether the LoRA is for FramePack or for HunyuanVideo
+     is_hunyuan = False
+     for key in lora_sd.keys():
+         if "double_blocks" in key or "single_blocks" in key:
+             is_hunyuan = True
+             break
+     if is_hunyuan:
+         print("HunyuanVideo LoRA detected, converting to FramePack format")
+         lora_sd = convert_hunyuan_to_framepack(lora_sd)
+
+     # Merge LoRA weights into the state dict
+     print(f"Merging LoRA weights into state dict. multiplier: {multiplier}")
+
+     # Create module map
+     name_to_original_key = {}
+     for key in state_dict.keys():
+         if key.endswith(".weight"):
+             lora_name = key.rsplit(".", 1)[0]  # remove trailing ".weight"
+             lora_name = "lora_unet_" + lora_name.replace(".", "_")
+             if lora_name not in name_to_original_key:
+                 name_to_original_key[lora_name] = key
+
+     # Merge LoRA weights
+     keys = [k for k in lora_sd.keys() if "lora_down" in k]
+     for key in tqdm(keys, desc="Merging LoRA weights"):
+         up_key = key.replace("lora_down", "lora_up")
+         alpha_key = key[: key.index("lora_down")] + "alpha"
+
+         # find original key for this lora
+         module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
+         if module_name not in name_to_original_key:
+             print(f"No module found for LoRA weight: {key}")
+             continue
+
+         original_key = name_to_original_key[module_name]
+
+         down_weight = lora_sd[key]
+         up_weight = lora_sd[up_key]
+
+         dim = down_weight.size()[0]
+         alpha = lora_sd.get(alpha_key, dim)
+         scale = alpha / dim
+
+         weight = state_dict[original_key]
+         original_device = weight.device
+         if original_device != device:
+             weight = weight.to(device)  # to make calculation faster
+
+         down_weight = down_weight.to(device)
+         up_weight = up_weight.to(device)
+
+         # W <- W + U * D
+         if len(weight.size()) == 2:
+             # linear
+             if len(up_weight.size()) == 4:  # conv-shaped LoRA weights on a linear layer: squeeze to 2D
+                 up_weight = up_weight.squeeze(3).squeeze(2)
+                 down_weight = down_weight.squeeze(3).squeeze(2)
+             weight = weight + multiplier * (up_weight @ down_weight) * scale
+         elif down_weight.size()[2:4] == (1, 1):
+             # conv2d 1x1
+             weight = (
+                 weight
+                 + multiplier
+                 * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                 * scale
+             )
+         else:
+             # conv2d 3x3
+             conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+             # logger.info(conved.size(), weight.size(), module.stride, module.padding)
+             weight = weight + multiplier * conved * scale
+
+         weight = weight.to(original_device)  # move back to original device
+         state_dict[original_key] = weight
+
+     return state_dict
+
+
+ def convert_hunyuan_to_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+     """
+     Convert HunyuanVideo LoRA weights to FramePack format.
+     """
+     new_lora_sd = {}
+     for key, weight in lora_sd.items():
+         if "double_blocks" in key:
+             key = key.replace("double_blocks", "transformer_blocks")
+             key = key.replace("img_mod_linear", "norm1_linear")
+             key = key.replace("img_attn_qkv", "attn_to_QKV")  # split later
+             key = key.replace("img_attn_proj", "attn_to_out_0")
+             key = key.replace("img_mlp_fc1", "ff_net_0_proj")
+             key = key.replace("img_mlp_fc2", "ff_net_2")
+             key = key.replace("txt_mod_linear", "norm1_context_linear")
+             key = key.replace("txt_attn_qkv", "attn_add_QKV_proj")  # split later
+             key = key.replace("txt_attn_proj", "attn_to_add_out")
+             key = key.replace("txt_mlp_fc1", "ff_context_net_0_proj")
+             key = key.replace("txt_mlp_fc2", "ff_context_net_2")
+         elif "single_blocks" in key:
+             key = key.replace("single_blocks", "single_transformer_blocks")
+             key = key.replace("linear1", "attn_to_QKVM")  # split later
+             key = key.replace("linear2", "proj_out")
+             key = key.replace("modulation_linear", "norm_linear")
+         else:
+             print(f"Unsupported module name: {key}, only double_blocks and single_blocks are supported")
+             continue
+
+         if "QKVM" in key:
+             # split QKVM into Q, K, V, M
+             key_q = key.replace("QKVM", "q")
+             key_k = key.replace("QKVM", "k")
+             key_v = key.replace("QKVM", "v")
+             key_m = key.replace("attn_to_QKVM", "proj_mlp")
+             if "_down" in key or "alpha" in key:
+                 # copy QKVM weight or alpha to Q, K, V, M
+                 assert "alpha" in key or weight.size(1) == 3072, f"QKVM weight size mismatch: {key}. {weight.size()}"
+                 new_lora_sd[key_q] = weight
+                 new_lora_sd[key_k] = weight
+                 new_lora_sd[key_v] = weight
+                 new_lora_sd[key_m] = weight
+             elif "_up" in key:
+                 # split QKVM weight into Q, K, V, M
+                 assert weight.size(0) == 21504, f"QKVM weight size mismatch: {key}. {weight.size()}"
+                 new_lora_sd[key_q] = weight[:3072]
+                 new_lora_sd[key_k] = weight[3072 : 3072 * 2]
+                 new_lora_sd[key_v] = weight[3072 * 2 : 3072 * 3]
+                 new_lora_sd[key_m] = weight[3072 * 3 :]  # 21504 - 3072 * 3 = 12288
+             else:
+                 print(f"Unsupported module name: {key}")
+                 continue
+         elif "QKV" in key:
+             # split QKV into Q, K, V
+             key_q = key.replace("QKV", "q")
+             key_k = key.replace("QKV", "k")
+             key_v = key.replace("QKV", "v")
+             if "_down" in key or "alpha" in key:
+                 # copy QKV weight or alpha to Q, K, V
+                 assert "alpha" in key or weight.size(1) == 3072, f"QKV weight size mismatch: {key}. {weight.size()}"
+                 new_lora_sd[key_q] = weight
+                 new_lora_sd[key_k] = weight
+                 new_lora_sd[key_v] = weight
+             elif "_up" in key:
+                 # split QKV weight into Q, K, V
+                 assert weight.size(0) == 3072 * 3, f"QKV weight size mismatch: {key}. {weight.size()}"
+                 new_lora_sd[key_q] = weight[:3072]
+                 new_lora_sd[key_k] = weight[3072 : 3072 * 2]
+                 new_lora_sd[key_v] = weight[3072 * 2 :]
+             else:
+                 print(f"Unsupported module name: {key}")
+                 continue
+         else:
+             # no split needed
+             new_lora_sd[key] = weight
+
+     return new_lora_sd
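
For reference, a minimal usage sketch of the new helper follows (not part of the commit). The checkpoint and LoRA file names are hypothetical placeholders, and loading and saving the transformer weights with safetensors is an assumption about the surrounding application; the helper itself only needs a state dict, a LoRA .safetensors path, a multiplier, and a device.

    import torch
    from safetensors.torch import load_file, save_file
    from utils.lora_utils import merge_lora_to_state_dict

    # Hypothetical file names; replace with a real transformer checkpoint and LoRA file.
    model_path = "transformer.safetensors"
    lora_path = "my_lora.safetensors"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state_dict = load_file(model_path)  # base transformer weights as {key: tensor}

    # Merge at 80% strength; unrecognized LoRA formats are logged and the state dict is returned unchanged.
    state_dict = merge_lora_to_state_dict(state_dict, lora_path, 0.8, device)

    save_file(state_dict, "transformer_with_lora.safetensors")

The function mutates the passed-in state dict and also returns it, so the reassignment above is optional but keeps the call site explicit.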