Spaces:
Paused
Paused
hf
Browse files- app.py +9 -14
- app_onnx.py +12 -19
- midi_model.py +53 -23
- midi_tokenizer.py +29 -0
app.py
CHANGED
@@ -365,19 +365,19 @@ if __name__ == "__main__":
|
|
365 |
synthesizer = MidiSynthesizer(soundfont_path)
|
366 |
models_info = {
|
367 |
"generic pretrain model (tv2o-medium) by skytnt": [
|
368 |
-
"skytnt/midi-model-tv2o-medium",
|
369 |
"jpop": "skytnt/midi-model-tv2om-jpop-lora",
|
370 |
"touhou": "skytnt/midi-model-tv2om-touhou-lora"
|
371 |
}
|
372 |
],
|
373 |
"generic pretrain model (tv2o-large) by asigalov61": [
|
374 |
-
"asigalov61/Music-Llama",
|
375 |
],
|
376 |
"generic pretrain model (tv2o-medium) by asigalov61": [
|
377 |
-
"asigalov61/Music-Llama-Medium",
|
378 |
],
|
379 |
"generic pretrain model (tv1-medium) by skytnt": [
|
380 |
-
"skytnt/midi-model",
|
381 |
]
|
382 |
}
|
383 |
models = {}
|
@@ -388,20 +388,15 @@ if __name__ == "__main__":
|
|
388 |
torch.backends.cudnn.allow_tf32 = True
|
389 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
390 |
torch.backends.cuda.enable_flash_sdp(True)
|
391 |
-
for name, (repo_id,
|
392 |
-
|
393 |
-
model
|
394 |
-
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
395 |
-
state_dict = ckpt.get("state_dict", ckpt)
|
396 |
-
model.load_state_dict(state_dict, strict=False)
|
397 |
-
model.to(device="cpu", dtype=torch.float32).eval()
|
398 |
models[name] = model
|
399 |
for lora_name, lora_repo in loras.items():
|
400 |
-
model = MIDIModel
|
401 |
-
model.load_state_dict(state_dict, strict=False)
|
402 |
print(f"loading lora {lora_repo} for {name}")
|
403 |
model = model.load_merge_lora(lora_repo)
|
404 |
-
model.to(device="cpu", dtype=torch.float32)
|
405 |
models[f"{name} with {lora_name} lora"] = model
|
406 |
|
407 |
load_javascript()
|
|
|
365 |
synthesizer = MidiSynthesizer(soundfont_path)
|
366 |
models_info = {
|
367 |
"generic pretrain model (tv2o-medium) by skytnt": [
|
368 |
+
"skytnt/midi-model-tv2o-medium", {
|
369 |
"jpop": "skytnt/midi-model-tv2om-jpop-lora",
|
370 |
"touhou": "skytnt/midi-model-tv2om-touhou-lora"
|
371 |
}
|
372 |
],
|
373 |
"generic pretrain model (tv2o-large) by asigalov61": [
|
374 |
+
"asigalov61/Music-Llama", {}
|
375 |
],
|
376 |
"generic pretrain model (tv2o-medium) by asigalov61": [
|
377 |
+
"asigalov61/Music-Llama-Medium", {}
|
378 |
],
|
379 |
"generic pretrain model (tv1-medium) by skytnt": [
|
380 |
+
"skytnt/midi-model", {}
|
381 |
]
|
382 |
}
|
383 |
models = {}
|
|
|
388 |
torch.backends.cudnn.allow_tf32 = True
|
389 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
390 |
torch.backends.cuda.enable_flash_sdp(True)
|
391 |
+
for name, (repo_id, loras) in models_info.items():
|
392 |
+
model = MIDIModel.from_pretrained(repo_id)
|
393 |
+
model.to(device="cpu", dtype=torch.float32)
|
|
|
|
|
|
|
|
|
394 |
models[name] = model
|
395 |
for lora_name, lora_repo in loras.items():
|
396 |
+
model = MIDIModel.from_pretrained(repo_id)
|
|
|
397 |
print(f"loading lora {lora_repo} for {name}")
|
398 |
model = model.load_merge_lora(lora_repo)
|
399 |
+
model.to(device="cpu", dtype=torch.float32)
|
400 |
models[f"{name} with {lora_name} lora"] = model
|
401 |
|
402 |
load_javascript()
|
app_onnx.py
CHANGED
@@ -432,18 +432,12 @@ def hf_hub_download_retry(repo_id, filename):
|
|
432 |
raise err
|
433 |
|
434 |
|
435 |
-
def get_tokenizer(
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
else:
|
442 |
-
o = False
|
443 |
-
if tv not in ["v1", "v2"]:
|
444 |
-
raise ValueError(f"Unknown tokenizer version {tv}")
|
445 |
-
tokenizer = MIDITokenizer(tv)
|
446 |
-
tokenizer.set_optimise_midi(o)
|
447 |
return tokenizer
|
448 |
|
449 |
|
@@ -468,34 +462,33 @@ if __name__ == "__main__":
|
|
468 |
synthesizer = MidiSynthesizer(soundfont_path)
|
469 |
models_info = {
|
470 |
"generic pretrain model (tv2o-medium) by skytnt": [
|
471 |
-
"skytnt/midi-model-tv2o-medium", "",
|
472 |
"jpop": "skytnt/midi-model-tv2om-jpop-lora",
|
473 |
"touhou": "skytnt/midi-model-tv2om-touhou-lora"
|
474 |
}
|
475 |
],
|
476 |
"generic pretrain model (tv2o-large) by asigalov61": [
|
477 |
-
"asigalov61/Music-Llama", "",
|
478 |
],
|
479 |
"generic pretrain model (tv2o-medium) by asigalov61": [
|
480 |
-
"asigalov61/Music-Llama-Medium", "",
|
481 |
],
|
482 |
"generic pretrain model (tv1-medium) by skytnt": [
|
483 |
-
"skytnt/midi-model", "",
|
484 |
]
|
485 |
}
|
486 |
models = {}
|
487 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
488 |
device = "cuda"
|
489 |
|
490 |
-
for name, (repo_id, path,
|
491 |
model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
|
492 |
model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
|
493 |
-
tokenizer = get_tokenizer(
|
494 |
models[name] = [model_base_path, model_token_path, tokenizer]
|
495 |
for lora_name, lora_repo in loras.items():
|
496 |
model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_base.onnx")
|
497 |
model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_token.onnx")
|
498 |
-
tokenizer = get_tokenizer(config)
|
499 |
models[f"{name} with {lora_name} lora"] = [model_base_path, model_token_path, tokenizer]
|
500 |
|
501 |
load_javascript()
|
|
|
432 |
raise err
|
433 |
|
434 |
|
435 |
+
def get_tokenizer(repo_id):
|
436 |
+
config_path = hf_hub_download_retry(repo_id=repo_id, filename=f"config.json")
|
437 |
+
with open(config_path, "r") as f:
|
438 |
+
config = json.load(f)
|
439 |
+
tokenizer = MIDITokenizer(config["tokenizer"]["version"])
|
440 |
+
tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
|
|
|
|
|
|
|
|
|
|
|
|
|
441 |
return tokenizer
|
442 |
|
443 |
|
|
|
462 |
synthesizer = MidiSynthesizer(soundfont_path)
|
463 |
models_info = {
|
464 |
"generic pretrain model (tv2o-medium) by skytnt": [
|
465 |
+
"skytnt/midi-model-tv2o-medium", "", {
|
466 |
"jpop": "skytnt/midi-model-tv2om-jpop-lora",
|
467 |
"touhou": "skytnt/midi-model-tv2om-touhou-lora"
|
468 |
}
|
469 |
],
|
470 |
"generic pretrain model (tv2o-large) by asigalov61": [
|
471 |
+
"asigalov61/Music-Llama", "", {}
|
472 |
],
|
473 |
"generic pretrain model (tv2o-medium) by asigalov61": [
|
474 |
+
"asigalov61/Music-Llama-Medium", "", {}
|
475 |
],
|
476 |
"generic pretrain model (tv1-medium) by skytnt": [
|
477 |
+
"skytnt/midi-model", "", {}
|
478 |
]
|
479 |
}
|
480 |
models = {}
|
481 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
482 |
device = "cuda"
|
483 |
|
484 |
+
for name, (repo_id, path, loras) in models_info.items():
|
485 |
model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
|
486 |
model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
|
487 |
+
tokenizer = get_tokenizer(repo_id)
|
488 |
models[name] = [model_base_path, model_token_path, tokenizer]
|
489 |
for lora_name, lora_repo in loras.items():
|
490 |
model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_base.onnx")
|
491 |
model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_token.onnx")
|
|
|
492 |
models[f"{name} with {lora_name} lora"] = [model_base_path, model_token_path, tokenizer]
|
493 |
|
494 |
load_javascript()
|
midi_model.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
|
|
|
2 |
|
3 |
import numpy as np
|
4 |
import torch
|
@@ -6,21 +7,57 @@ import torch.nn as nn
|
|
6 |
import torch.nn.functional as F
|
7 |
import tqdm
|
8 |
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
|
9 |
-
from transformers import LlamaModel, LlamaConfig, DynamicCache
|
10 |
-
from transformers.integrations import PeftAdapterMixin
|
11 |
|
12 |
from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
|
13 |
|
14 |
config_name_list = ["tv1-medium", "tv2-medium", "tv2o-medium", "tv2-large", "tv2o-large"]
|
15 |
|
16 |
|
17 |
-
class MIDIModelConfig:
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
@staticmethod
|
26 |
def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096):
|
@@ -59,27 +96,20 @@ class MIDIModelConfig:
|
|
59 |
raise ValueError(f"Unknown model size {size}")
|
60 |
|
61 |
|
62 |
-
class MIDIModel(
|
|
|
|
|
63 |
def __init__(self, config: MIDIModelConfig, *args, **kwargs):
|
64 |
-
super(MIDIModel, self).__init__()
|
65 |
self.tokenizer = config.tokenizer
|
66 |
self.net = LlamaModel(config.net_config)
|
67 |
self.net_token = LlamaModel(config.net_token_config)
|
68 |
self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)
|
69 |
-
self.device = "cpu"
|
70 |
-
|
71 |
-
def to(self, *args, **kwargs):
|
72 |
-
if "device" in kwargs:
|
73 |
-
self.device = kwargs["device"]
|
74 |
-
return super(MIDIModel, self).to(*args, **kwargs)
|
75 |
-
|
76 |
-
def peft_loaded(self):
|
77 |
-
return self._hf_peft_config_loaded
|
78 |
|
79 |
def load_merge_lora(self, model_id):
|
80 |
peft_config = PeftConfig.from_pretrained(model_id)
|
81 |
model = LoraModel(self, peft_config, adapter_name="default")
|
82 |
-
adapter_state_dict = load_peft_weights(model_id, device=self.device)
|
83 |
set_peft_model_state_dict(self, adapter_state_dict, "default")
|
84 |
return model.merge_and_unload()
|
85 |
|
@@ -164,7 +194,7 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
|
|
164 |
with bar:
|
165 |
while cur_len < max_len:
|
166 |
end = [False] * batch_size
|
167 |
-
hidden = self.forward(input_tensor[:,past_len:], cache=cache1)[:, -1]
|
168 |
next_token_seq = None
|
169 |
event_names = [""] * batch_size
|
170 |
cache2 = DynamicCache()
|
|
|
1 |
+
import json
|
2 |
+
from typing import Union, Dict, Any
|
3 |
|
4 |
import numpy as np
|
5 |
import torch
|
|
|
7 |
import torch.nn.functional as F
|
8 |
import tqdm
|
9 |
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
|
10 |
+
from transformers import LlamaModel, LlamaConfig, DynamicCache, PretrainedConfig, PreTrainedModel
|
|
|
11 |
|
12 |
from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
|
13 |
|
14 |
config_name_list = ["tv1-medium", "tv2-medium", "tv2o-medium", "tv2-large", "tv2o-large"]
|
15 |
|
16 |
|
17 |
+
class MIDIModelConfig(PretrainedConfig):
|
18 |
+
model_type = "midi_model"
|
19 |
+
|
20 |
+
def __init__(self,
|
21 |
+
tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2, Dict]=None,
|
22 |
+
net_config: Union[LlamaConfig, Dict]=None,
|
23 |
+
net_token_config: Union[LlamaConfig, Dict]=None,
|
24 |
+
**kwargs):
|
25 |
+
super().__init__(**kwargs)
|
26 |
+
if tokenizer:
|
27 |
+
if isinstance(tokenizer, dict):
|
28 |
+
self.tokenizer = MIDITokenizer(tokenizer["version"])
|
29 |
+
self.tokenizer.set_optimise_midi(tokenizer["optimise_midi"])
|
30 |
+
else:
|
31 |
+
self.tokenizer = tokenizer
|
32 |
+
else:
|
33 |
+
self.tokenizer = MIDITokenizer()
|
34 |
+
if net_config:
|
35 |
+
if isinstance(net_config, dict):
|
36 |
+
self.net_config = LlamaConfig(**net_config)
|
37 |
+
else:
|
38 |
+
self.net_config = net_config
|
39 |
+
else:
|
40 |
+
self.net_config = LlamaConfig()
|
41 |
+
if net_token_config:
|
42 |
+
if isinstance(net_token_config, dict):
|
43 |
+
self.net_token_config = LlamaConfig(**net_token_config)
|
44 |
+
else:
|
45 |
+
self.net_token_config = net_token_config
|
46 |
+
else:
|
47 |
+
self.net_token_config = LlamaConfig()
|
48 |
+
self.n_embd = self.net_token_config.hidden_size
|
49 |
+
|
50 |
+
def to_dict(self) -> Dict[str, Any]:
|
51 |
+
d = super().to_dict()
|
52 |
+
d["tokenizer"] = self.tokenizer.to_dict()
|
53 |
+
return d
|
54 |
+
|
55 |
+
def __str__(self):
|
56 |
+
d = {
|
57 |
+
"net": self.net_config.to_json_string(use_diff=False),
|
58 |
+
"net_token": self.net_token_config.to_json_string(use_diff=False)
|
59 |
+
}
|
60 |
+
return json.dumps(d, indent=4)
|
61 |
|
62 |
@staticmethod
|
63 |
def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096):
|
|
|
96 |
raise ValueError(f"Unknown model size {size}")
|
97 |
|
98 |
|
99 |
+
class MIDIModel(PreTrainedModel):
|
100 |
+
config_class = MIDIModelConfig
|
101 |
+
|
102 |
def __init__(self, config: MIDIModelConfig, *args, **kwargs):
|
103 |
+
super(MIDIModel, self).__init__(config, *args, **kwargs)
|
104 |
self.tokenizer = config.tokenizer
|
105 |
self.net = LlamaModel(config.net_config)
|
106 |
self.net_token = LlamaModel(config.net_token_config)
|
107 |
self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
def load_merge_lora(self, model_id):
|
110 |
peft_config = PeftConfig.from_pretrained(model_id)
|
111 |
model = LoraModel(self, peft_config, adapter_name="default")
|
112 |
+
adapter_state_dict = load_peft_weights(model_id, device=str(self.device))
|
113 |
set_peft_model_state_dict(self, adapter_state_dict, "default")
|
114 |
return model.merge_and_unload()
|
115 |
|
|
|
194 |
with bar:
|
195 |
while cur_len < max_len:
|
196 |
end = [False] * batch_size
|
197 |
+
hidden = self.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
|
198 |
next_token_seq = None
|
199 |
event_names = [""] * batch_size
|
200 |
cache2 = DynamicCache()
|
midi_tokenizer.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import random
|
|
|
2 |
|
3 |
import PIL.Image
|
4 |
import numpy as np
|
@@ -33,6 +34,20 @@ class MIDITokenizerV1:
|
|
33 |
self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
|
34 |
self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
def set_optimise_midi(self, optimise_midi=True):
|
37 |
self.optimise_midi = optimise_midi
|
38 |
|
@@ -519,6 +534,20 @@ class MIDITokenizerV2:
|
|
519 |
self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
|
520 |
self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
|
521 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
522 |
def set_optimise_midi(self, optimise_midi=True):
|
523 |
self.optimise_midi = optimise_midi
|
524 |
|
|
|
1 |
import random
|
2 |
+
from typing import Dict, Any
|
3 |
|
4 |
import PIL.Image
|
5 |
import numpy as np
|
|
|
34 |
self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
|
35 |
self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
|
36 |
|
37 |
+
def to_dict(self) -> Dict[str, Any]:
|
38 |
+
d = {
|
39 |
+
"version":self.version,
|
40 |
+
"optimise_midi":self.optimise_midi,
|
41 |
+
"vocab_size": self.vocab_size,
|
42 |
+
"events": self.events,
|
43 |
+
"event_parameters": self.event_parameters,
|
44 |
+
"max_token_seq": self.max_token_seq,
|
45 |
+
"pad_id": self.pad_id,
|
46 |
+
"bos_id": self.bos_id,
|
47 |
+
"eos_id": self.eos_id,
|
48 |
+
}
|
49 |
+
return d
|
50 |
+
|
51 |
def set_optimise_midi(self, optimise_midi=True):
|
52 |
self.optimise_midi = optimise_midi
|
53 |
|
|
|
534 |
self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
|
535 |
self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
|
536 |
|
537 |
+
def to_dict(self) -> Dict[str, Any]:
|
538 |
+
d = {
|
539 |
+
"version":self.version,
|
540 |
+
"optimise_midi":self.optimise_midi,
|
541 |
+
"vocab_size": self.vocab_size,
|
542 |
+
"events": self.events,
|
543 |
+
"event_parameters": self.event_parameters,
|
544 |
+
"max_token_seq": self.max_token_seq,
|
545 |
+
"pad_id": self.pad_id,
|
546 |
+
"bos_id": self.bos_id,
|
547 |
+
"eos_id": self.eos_id,
|
548 |
+
}
|
549 |
+
return d
|
550 |
+
|
551 |
def set_optimise_midi(self, optimise_midi=True):
|
552 |
self.optimise_midi = optimise_midi
|
553 |
|