Spaces:
Paused
Paused
merge lora into model
Browse files- app.py +13 -15
- midi_model.py +8 -0
app.py
CHANGED
@@ -142,12 +142,7 @@ def get_duration(model_name, tab, mid_seq, continuation_state, continuation_sele
|
|
142 |
def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
|
143 |
key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
|
144 |
seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
145 |
-
model
|
146 |
-
if lora_name is None and model.peft_loaded():
|
147 |
-
model.disable_adapters()
|
148 |
-
elif lora_name is not None:
|
149 |
-
model.enable_adapters()
|
150 |
-
model.set_adapter(lora_name)
|
151 |
model.to(device=opt.device)
|
152 |
tokenizer = model.tokenizer
|
153 |
bpm = int(bpm)
|
@@ -258,7 +253,7 @@ def finish_run(model_name, mid_seq):
|
|
258 |
if mid_seq is None:
|
259 |
outputs = [None] * OUTPUT_BATCH_SIZE
|
260 |
return *outputs, []
|
261 |
-
tokenizer = models[model_name]
|
262 |
outputs = []
|
263 |
end_msgs = [create_msg("progress", [0, 0])]
|
264 |
if not os.path.exists("outputs"):
|
@@ -282,7 +277,7 @@ def render_audio(model_name, mid_seq, should_render_audio):
|
|
282 |
if (not should_render_audio) or mid_seq is None:
|
283 |
outputs = [None] * OUTPUT_BATCH_SIZE
|
284 |
return tuple(outputs)
|
285 |
-
tokenizer = models[model_name]
|
286 |
outputs = []
|
287 |
if not os.path.exists("outputs"):
|
288 |
os.mkdir("outputs")
|
@@ -293,13 +288,15 @@ def render_audio(model_name, mid_seq, should_render_audio):
|
|
293 |
audio_futures.append(audio_future)
|
294 |
for future in audio_futures:
|
295 |
outputs.append((44100, future.result()))
|
|
|
|
|
296 |
return tuple(outputs)
|
297 |
|
298 |
|
299 |
def undo_continuation(model_name, mid_seq, continuation_state):
|
300 |
if mid_seq is None or len(continuation_state) < 2:
|
301 |
return mid_seq, continuation_state, send_msgs([])
|
302 |
-
tokenizer = models[model_name]
|
303 |
if isinstance(continuation_state[-1], list):
|
304 |
mid_seq = continuation_state[-1]
|
305 |
else:
|
@@ -399,14 +396,15 @@ if __name__ == "__main__":
|
|
399 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
400 |
state_dict = ckpt.get("state_dict", ckpt)
|
401 |
model.load_state_dict(state_dict, strict=False)
|
402 |
-
for lora_name, lora_repo in loras.items():
|
403 |
-
model.load_adapter(lora_repo, lora_name)
|
404 |
-
if loras:
|
405 |
-
model.disable_adapters()
|
406 |
model.to(device="cpu", dtype=torch.float32).eval()
|
407 |
-
models[name] = model
|
408 |
for lora_name, lora_repo in loras.items():
|
409 |
-
|
|
|
|
|
|
|
|
|
|
|
410 |
|
411 |
load_javascript()
|
412 |
app = gr.Blocks()
|
|
|
142 |
def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
|
143 |
key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
|
144 |
seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
145 |
+
model = models[model_name]
|
|
|
|
|
|
|
|
|
|
|
146 |
model.to(device=opt.device)
|
147 |
tokenizer = model.tokenizer
|
148 |
bpm = int(bpm)
|
|
|
253 |
if mid_seq is None:
|
254 |
outputs = [None] * OUTPUT_BATCH_SIZE
|
255 |
return *outputs, []
|
256 |
+
tokenizer = models[model_name].tokenizer
|
257 |
outputs = []
|
258 |
end_msgs = [create_msg("progress", [0, 0])]
|
259 |
if not os.path.exists("outputs"):
|
|
|
277 |
if (not should_render_audio) or mid_seq is None:
|
278 |
outputs = [None] * OUTPUT_BATCH_SIZE
|
279 |
return tuple(outputs)
|
280 |
+
tokenizer = models[model_name].tokenizer
|
281 |
outputs = []
|
282 |
if not os.path.exists("outputs"):
|
283 |
os.mkdir("outputs")
|
|
|
288 |
audio_futures.append(audio_future)
|
289 |
for future in audio_futures:
|
290 |
outputs.append((44100, future.result()))
|
291 |
+
if OUTPUT_BATCH_SIZE == 1:
|
292 |
+
return outputs[0]
|
293 |
return tuple(outputs)
|
294 |
|
295 |
|
296 |
def undo_continuation(model_name, mid_seq, continuation_state):
|
297 |
if mid_seq is None or len(continuation_state) < 2:
|
298 |
return mid_seq, continuation_state, send_msgs([])
|
299 |
+
tokenizer = models[model_name].tokenizer
|
300 |
if isinstance(continuation_state[-1], list):
|
301 |
mid_seq = continuation_state[-1]
|
302 |
else:
|
|
|
396 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
397 |
state_dict = ckpt.get("state_dict", ckpt)
|
398 |
model.load_state_dict(state_dict, strict=False)
|
|
|
|
|
|
|
|
|
399 |
model.to(device="cpu", dtype=torch.float32).eval()
|
400 |
+
models[name] = model
|
401 |
for lora_name, lora_repo in loras.items():
|
402 |
+
model = MIDIModel(config=MIDIModelConfig.from_name(config))
|
403 |
+
model.load_state_dict(state_dict, strict=False)
|
404 |
+
print(f"loading lora {lora_repo} for {name}")
|
405 |
+
model = model.load_merge_lora(lora_repo)
|
406 |
+
model.to(device="cpu", dtype=torch.float32).eval()
|
407 |
+
models[f"{name} with {lora_name} lora"] = model
|
408 |
|
409 |
load_javascript()
|
410 |
app = gr.Blocks()
|
midi_model.py
CHANGED
@@ -5,6 +5,7 @@ import torch
|
|
5 |
import torch.nn as nn
|
6 |
import torch.nn.functional as F
|
7 |
import tqdm
|
|
|
8 |
from transformers import LlamaModel, LlamaConfig
|
9 |
from transformers.integrations import PeftAdapterMixin
|
10 |
|
@@ -75,6 +76,13 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
|
|
75 |
def peft_loaded(self):
|
76 |
return self._hf_peft_config_loaded
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
def forward_token(self, hidden_state, x=None):
|
79 |
"""
|
80 |
|
|
|
5 |
import torch.nn as nn
|
6 |
import torch.nn.functional as F
|
7 |
import tqdm
|
8 |
+
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
|
9 |
from transformers import LlamaModel, LlamaConfig
|
10 |
from transformers.integrations import PeftAdapterMixin
|
11 |
|
|
|
76 |
def peft_loaded(self):
|
77 |
return self._hf_peft_config_loaded
|
78 |
|
79 |
+
def load_merge_lora(self, model_id):
|
80 |
+
peft_config = PeftConfig.from_pretrained(model_id)
|
81 |
+
model = LoraModel(self, peft_config, adapter_name="default")
|
82 |
+
adapter_state_dict = load_peft_weights(model_id, device=self.device)
|
83 |
+
set_peft_model_state_dict(self, adapter_state_dict, "default")
|
84 |
+
return model.merge_and_unload()
|
85 |
+
|
86 |
def forward_token(self, hidden_state, x=None):
|
87 |
"""
|
88 |
|