Sayoyo committed on
Commit 64750a4 · 1 Parent(s): 8a8cb3e

update examples

examples/default/input_params/output_20250426091716_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "anime, cute female vocals, kawaii pop, j-pop, childish, piano, guitar, synthesizer, fast, happy, cheerful, lighthearted",
3
+ "lyrics": "[Chorus]\nใญใ‡ใ€้ก”ใŒ่ตคใ„ใ‚ˆ๏ผŸ\nใฉใ†ใ—ใŸใฎ๏ผŸ ็†ฑใŒใ‚ใ‚‹ใฎ๏ผŸ\nใใ‚Œใจใ‚‚ๆ€’ใฃใฆใ‚‹ใฎ๏ผŸ\nใญใ‡ใ€่จ€ใฃใฆใ‚ˆ๏ผ\n\nใฉใ†ใ—ใฆใใ‚“ใช็›ฎใง่ฆ‹ใ‚‹ใฎ๏ผŸ\n็งใ€ๆ‚ชใ„ใ“ใจใ—ใŸ๏ผŸ\nไฝ•ใ‹้–“้•ใˆใŸใฎ๏ผŸ\nใŠ้ก˜ใ„ใ€ใ‚„ใ‚ใฆโ€ฆ ๆ€–ใ„ใ‹ใ‚‰โ€ฆ\nใ ใ‹ใ‚‰ใ€ใ‚„ใ‚ใฆใ‚ˆโ€ฆ\n\n[Bridge]\n็›ฎใ‚’้–‰ใ˜ใฆใ€ใใ‚‹ใฃใจ่ƒŒใ‚’ๅ‘ใ‘ใฆใ€\nไฝ•ใ‚‚่ฆ‹ใชใ‹ใฃใŸใƒ•ใƒชใ™ใ‚‹ใ‹ใ‚‰ใ€\nๆ€’ใ‚‰ใชใ„ใงโ€ฆ ่จฑใ—ใฆใ‚ˆโ€ฆ\n\n[Chorus]\nใญใ‡ใ€้ก”ใŒ่ตคใ„ใ‚ˆ๏ผŸ\nใฉใ†ใ—ใŸใฎ๏ผŸ ็†ฑใŒใ‚ใ‚‹ใฎ๏ผŸ\nใใ‚Œใจใ‚‚ๆ€’ใฃใฆใ‚‹ใฎ๏ผŸ\nใญใ‡ใ€่จ€ใฃใฆใ‚ˆ๏ผ\n\nใฉใ†ใ—ใฆใใ‚“ใช็›ฎใง่ฆ‹ใ‚‹ใฎ๏ผŸ\n็งใ€ๆ‚ชใ„ใ“ใจใ—ใŸ๏ผŸ\nไฝ•ใ‹้–“้•ใˆใŸใฎ๏ผŸ\nใŠ้ก˜ใ„ใ€ใ‚„ใ‚ใฆโ€ฆ ๆ€–ใ„ใ‹ใ‚‰โ€ฆ\nใ ใ‹ใ‚‰ใ€ใ‚„ใ‚ใฆใ‚ˆโ€ฆ\n\n[Bridge 2]\nๅพ…ใฃใฆใ€ใ‚‚ใ—็งใŒๆ‚ชใ„ใชใ‚‰ใ€\nใ”ใ‚ใ‚“ใชใ•ใ„ใฃใฆ่จ€ใ†ใ‹ใ‚‰ใ€\nใ‚ขใ‚คใ‚นใ‚ฏใƒชใƒผใƒ ใ‚ใ’ใ‚‹ใ‹ใ‚‰ใ€\nใ‚‚ใ†ๆ€’ใ‚‰ใชใ„ใง๏ผŸ\n\nOooohโ€ฆ ่จ€ใฃใฆใ‚ˆ๏ผ",
4
+ "audio_duration": 160,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.0282442569732666,
19
+ "diffusion": 12.104875326156616,
20
+ "latent2audio": 1.587641954421997
21
+ },
22
+ "actual_seeds": [
23
+ 4028738662
24
+ ]
25
+ }
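The added file records the inputs of one text2music run. As a minimal sketch of replaying it (assumptions: the ACEStepPipeline object accepts these generation fields as keyword arguments when called, which this commit only hints at through the parameter list in pipeline_ace_step.py; output-only fields such as timecosts and actual_seeds are dropped first):

import json

from pipeline_ace_step import ACEStepPipeline

# Load the saved example and keep only the fields that are generation inputs.
with open(
    "examples/default/input_params/output_20250426091716_0_input_params.json",
    encoding="utf-8",
) as f:
    params = json.load(f)

generation_keys = [
    "prompt", "lyrics", "audio_duration", "infer_step", "guidance_scale",
    "scheduler_type", "cfg_type", "omega_scale", "guidance_interval",
    "guidance_interval_decay", "min_guidance_scale", "use_erg_tag",
    "use_erg_lyric", "use_erg_diffusion", "oss_steps",
]
kwargs = {k: params[k] for k in generation_keys if k in params}

# __init__ arguments match the signature shown further down in this diff;
# missing checkpoints are downloaded from the Hub per the logic in that file.
pipeline = ACEStepPipeline(checkpoint_dir=None, dtype="bfloat16")
output_paths = pipeline(**kwargs)  # assumed entry point; not confirmed by this commit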
examples/zh_rap_lora/input_params/output_20250512120348_0_input_params.json CHANGED
@@ -22,7 +22,7 @@
22
  "latent2audio": 0.5694489479064941
23
  },
24
  "actual_seeds": [
25
- 721655639
26
  ],
27
  "retake_seeds": [
28
  1603201617
 
22
  "latent2audio": 0.5694489479064941
23
  },
24
  "actual_seeds": [
25
+ 226581098
26
  ],
27
  "retake_seeds": [
28
  1603201617
examples/zh_rap_lora/input_params/output_20250512160830_0_input_params.json DELETED
@@ -1,45 +0,0 @@
1
- {
2
- "lora_name_or_path": "/root/sag_train/data/ace_step_v1_chinese_rap_lora_80k",
3
- "task": "text2music",
4
- "prompt": "articulate, spoken word, young adult, rap music, male, clear, energetic, warm, relaxed, breathy, night club",
5
- "lyrics": "[verse]\n่ฟ™ ่ฟ™ ่ฐ ๅˆ ๅœจ ๆดพ ๅฏน ๅ– ๅคš\nๆˆ‘ ็š„ ่„‘ ่ข‹\nๅƒ ่ขซ ้ฉด ่ธข ่ฟ‡\nไธ ๅฏน ๅŠฒ\n่ˆŒ ๅคด ๆ‰“ ็ป“ ไธ ไผš ่ฏด\nไฝ  ๆฅ ๆŒ‘ ๆˆ˜ ๆˆ‘ ๅฐฑ ่ทช\nๅผ€ ๅฑ€ ็›ด ๆŽฅ ๅดฉ ๆบƒ\n\n[chorus]\nๅฐฑ ๅ’ช ไนฑ ๅ’ช ๅฟต ๅ’ช ้”™ ๅ’ช\nๅ˜ด ๅ’ช ็“ข ๅ’ช ๆˆ ๅ’ช ็‹— ๅ’ช\n่„‘ ๅ’ช ่ข‹ ๅ’ช ๅƒ ๅ’ช ๆต† ๅ’ช ็ณŠ ๅ’ช\n่ทŸ ๅ’ช ็€ ๅ’ช ่Š‚ ๅ’ช ๅฅ ๅ’ช\nๆŠŠ ๅ’ช ๆญŒ ๅ’ช ่ฏ ๅ’ช ๅ…จ ๅ’ช ๅฟ˜ ๅ’ช\nไธ€ ๅ’ช ๅผ  ๅ’ช ๅ˜ด ๅ’ช ๅฐฑ ๅ’ช ๅบŸ ๅ’ช\nๅช ๅ’ช ๅ‰ฉ ๅ’ช ไธ‹ ๅ’ช ๅฐด ๅ’ช ๅฐฌ ๅ’ช ๅ›ž ๅ’ช ๅฟ†\n่‰๏ผ\n\n[verse]\n้”™ ้”™ ้”™ ้”™ ไบ†\nไธ€ ๅฃ ๆฐ” ๅ…จ ๅฟต ้”™\n้”™ ้”™ ้”™ ้”™ ไบ†\n่ˆŒ ๅคด ๆ‰“ ็ป“ ็”ฉ ้”…\n็”ฉ ็”ฉ ็”ฉ ็”ฉ ้”…\n็”ฉ ้”… ็”ฉ ้”…\nๆ‹ ๅญ ๅ…จ ้ƒจ ไนฑ ๅฅ—\n่ง‚ ไผ— ็ฌ‘ ๅˆฐ ๅ ่ก€\n\n[verse]\nไฝ  ็š„ ๆญŒ ่ฏ ๆˆ‘ ็š„ ๅ™ฉ ๆขฆ\nๅ”ฑ ๅฎŒ ็›ด ๆŽฅ ็คพ ๆญป\n่ฐƒ ่ท‘ ๅˆฐ ๅค– ๅคช ็ฉบ\n่ง‚ ไผ— ่กจ ๆƒ… ่ฃ‚ ๅผ€\nไฝ  ็ฌ‘ ๆˆ‘ ่œ\nๆˆ‘ ็ฌ‘ ไฝ  ไธ ๆ‡‚\n่ฟ™ ๅซ ่‰บ ๆœฏ ่กจ ๆผ”\nไธ ๆœ ไฝ  ๆฅ๏ผ\n\n[verse]\n่ฟ™ ่ฟ™ ่ฐ ๅˆ ๅœจ ๆดพ ๅฏน ไธข ไบบ\nๆˆ‘ ็š„ ไธ– ็•Œ\nๅทฒ ็ป ๅฝป ๅบ• ๅดฉ ๆบƒ\nๆฒก ๆœ‰ ๅฎŒ ็พŽ\nๅช ๆœ‰ ็ฟป ่ฝฆ ็Žฐ ๅœบ\nไปฅ ๅŠ ่ง‚ ไผ— ็š„ ๅ˜ฒ ่ฎฝ\n\n[chorus]\nๅฐฑ ๅ’ช ไนฑ ๅ’ช ๅฟต ๅ’ช ้”™ ๅ’ช\nๅ˜ด ๅ’ช ็“ข ๅ’ช ๆˆ ๅ’ช ็‹— ๅ’ช\n่„‘ ๅ’ช ่ข‹ ๅ’ช ๅƒ ๅ’ช ๆต† ๅ’ช ็ณŠ ๅ’ช\n่ทŸ ๅ’ช ็€ ๅ’ช ่Š‚ ๅ’ช ๅฅ ๅ’ช\nๆŠŠ ๅ’ช ๆญŒ ๅ’ช ่ฏ ๅ’ช ๅ…จ ๅ’ช ๅฟ˜ ๅ’ช\nไธ€ ๅ’ช ๅผ  ๅ’ช ๅ˜ด ๅ’ช ๅฐฑ ๅ’ช ๅบŸ ๅ’ช\nๅช ๅ’ช ๅ‰ฉ ๅ’ช ไธ‹ ๅ’ช ๅฐด ๅ’ช ๅฐฌ ๅ’ช ๅ›ž ๅ’ช ๅฟ†\n่‰๏ผ\n\n[verse]\n้”™ ้”™ ้”™ ้”™ ไบ†\nไธ€ ๅฃ ๆฐ” ๅ…จ ๅฟต ้”™\n้”™ ้”™ ้”™ ้”™ ไบ†\n่ˆŒ ๅคด ๆ‰“ ็ป“ ็”ฉ ้”…\n็”ฉ ็”ฉ ็”ฉ ็”ฉ ้”…\n็”ฉ ้”… ็”ฉ ้”…\nๆ‹ ๅญ ๅ…จ ้ƒจ ไนฑ ๅฅ—\n่ง‚ ไผ— ็ฌ‘ ๅˆฐ ๅ ่ก€\n\n[verse]\nไฝ  ็š„ ๆญŒ ่ฏ ๆˆ‘ ็š„ ๅ™ฉ ๆขฆ\nๅ”ฑ ๅฎŒ ็›ด ๆŽฅ ็คพ ๆญป\n่ฐƒ ่ท‘ ๅˆฐ ๅค– ๅคช ็ฉบ\n่ง‚ ไผ— ่กจ ๆƒ… ่ฃ‚ ๅผ€\nไฝ  ็ฌ‘ ๆˆ‘ ่œ\nๆˆ‘ ็ฌ‘ ไฝ  ไธ ๆ‡‚\n่ฟ™ ๅซ ่‰บ ๆœฏ ่กจ ๆผ”\nไธ ๆœ ไฝ  ๆฅ๏ผ",
6
- "audio_duration": 169.12,
7
- "infer_step": 60,
8
- "guidance_scale": 15,
9
- "scheduler_type": "euler",
10
- "cfg_type": "apg",
11
- "omega_scale": 10,
12
- "guidance_interval": 0.5,
13
- "guidance_interval_decay": 0,
14
- "min_guidance_scale": 3,
15
- "use_erg_tag": true,
16
- "use_erg_lyric": false,
17
- "use_erg_diffusion": true,
18
- "oss_steps": [],
19
- "timecosts": {
20
- "preprocess": 0.041605472564697266,
21
- "diffusion": 14.009192705154419,
22
- "latent2audio": 1.55946946144104
23
- },
24
- "actual_seeds": [
25
- 547563805
26
- ],
27
- "retake_seeds": [
28
- 2702917060
29
- ],
30
- "retake_variance": 0.5,
31
- "guidance_scale_text": 0,
32
- "guidance_scale_lyric": 0,
33
- "repaint_start": 0,
34
- "repaint_end": 0,
35
- "edit_n_min": 0.0,
36
- "edit_n_max": 1.0,
37
- "edit_n_avg": 1,
38
- "src_audio_path": null,
39
- "edit_target_prompt": null,
40
- "edit_target_lyrics": null,
41
- "audio2audio_enable": false,
42
- "ref_audio_strength": 0.5,
43
- "ref_audio_input": null,
44
- "audio_path": "./outputs/output_20250512160830_0.wav"
45
- }
pipeline_ace_step.py CHANGED
@@ -12,9 +12,15 @@ import math
12
  from huggingface_hub import hf_hub_download, snapshot_download
13
 
14
  # from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
- from schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
16
- from schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
17
- from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
 
 
 
 
 
 
18
  from diffusers.utils.torch_utils import randn_tensor
19
  from transformers import UMT5EncoderModel, AutoTokenizer
20
 
@@ -22,23 +28,42 @@ from language_segmentation import LangSegment
22
  from music_dcae.music_dcae_pipeline import MusicDCAE
23
  from models.ace_step_transformer import ACEStepTransformer2DModel
24
  from models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer
25
- from apg_guidance import apg_forward, MomentumBuffer, cfg_forward, cfg_zero_star, cfg_double_condition_forward
 
 
 
 
 
 
26
  import torchaudio
27
  import torio
28
 
29
 
30
  torch.backends.cudnn.benchmark = False
31
- torch.set_float32_matmul_precision('high')
32
  torch.backends.cudnn.deterministic = True
33
  torch.backends.cuda.matmul.allow_tf32 = True
34
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
35
 
36
 
37
  SUPPORT_LANGUAGES = {
38
- "en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
39
- "pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
40
- "nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
41
- "ko": 6152, "hi": 6680
42
  }
43
 
44
  structure_pattern = re.compile(r"\[.*?\]")
@@ -56,7 +81,16 @@ REPO_ID = "ACE-Step/ACE-Step-v1-3.5B"
56
  # class ACEStepPipeline(DiffusionPipeline):
57
  class ACEStepPipeline:
58
 
59
- def __init__(self, checkpoint_dir=None, device_id=0, dtype="bfloat16", text_encoder_checkpoint_path=None, persistent_storage_path=None, torch_compile=False, **kwargs):
 
 
 
 
 
 
 
 
 
60
  if not checkpoint_dir:
61
  if persistent_storage_path is None:
62
  checkpoint_dir = os.path.join(os.path.dirname(__file__), "checkpoints")
@@ -64,7 +98,11 @@ class ACEStepPipeline:
64
  checkpoint_dir = os.path.join(persistent_storage_path, "checkpoints")
65
  ensure_directory_exists(checkpoint_dir)
66
  self.checkpoint_dir = checkpoint_dir
67
- device = torch.device(f"cuda:{device_id}") if torch.cuda.is_available() else torch.device("cpu")
 
 
 
 
68
  if device.type == "cpu" and torch.backends.mps.is_available():
69
  device = torch.device("mps")
70
  self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
@@ -74,17 +112,25 @@ class ACEStepPipeline:
74
  self.loaded = False
75
  self.torch_compile = torch_compile
76
  self.lora_path = "none"
77
-
78
  def load_lora(self, lora_name_or_path):
79
  if lora_name_or_path != self.lora_path and lora_name_or_path != "none":
80
  if not os.path.exists(lora_name_or_path):
81
- lora_download_path = snapshot_download(lora_name_or_path, cache_dir=self.checkpoint_dir)
 
 
82
  else:
83
  lora_download_path = lora_name_or_path
84
  if self.lora_path != "none":
85
  self.ace_step_transformer.unload_lora()
86
- self.ace_step_transformer.load_lora_adapter(os.path.join(lora_download_path, "pytorch_lora_weights.safetensors"), adapter_name="zh_rap_lora", with_alpha=True)
87
- logger.info(f"Loading lora weights from: {lora_name_or_path} download path is: {lora_download_path}")
 
 
 
 
 
 
88
  self.lora_path = lora_name_or_path
89
  elif self.lora_path != "none" and lora_name_or_path == "none":
90
  logger.info("No lora weights to load.")
@@ -99,55 +145,124 @@ class ACEStepPipeline:
99
  text_encoder_model_path = os.path.join(checkpoint_dir, "umt5-base")
100
 
101
  files_exist = (
102
- os.path.exists(os.path.join(dcae_model_path, "config.json")) and
103
- os.path.exists(os.path.join(dcae_model_path, "diffusion_pytorch_model.safetensors")) and
104
- os.path.exists(os.path.join(vocoder_model_path, "config.json")) and
105
- os.path.exists(os.path.join(vocoder_model_path, "diffusion_pytorch_model.safetensors")) and
106
- os.path.exists(os.path.join(ace_step_model_path, "config.json")) and
107
- os.path.exists(os.path.join(ace_step_model_path, "diffusion_pytorch_model.safetensors")) and
108
- os.path.exists(os.path.join(text_encoder_model_path, "config.json")) and
109
- os.path.exists(os.path.join(text_encoder_model_path, "model.safetensors")) and
110
- os.path.exists(os.path.join(text_encoder_model_path, "special_tokens_map.json")) and
111
- os.path.exists(os.path.join(text_encoder_model_path, "tokenizer_config.json")) and
112
- os.path.exists(os.path.join(text_encoder_model_path, "tokenizer.json"))
113
  )
114
 
115
  if not files_exist:
116
- logger.info(f"Checkpoint directory {checkpoint_dir} is not complete, downloading from Hugging Face Hub")
 
 
117
 
118
  # download music dcae model
119
  os.makedirs(dcae_model_path, exist_ok=True)
120
- hf_hub_download(repo_id=REPO_ID, subfolder="music_dcae_f8c8",
121
- filename="config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
122
- hf_hub_download(repo_id=REPO_ID, subfolder="music_dcae_f8c8",
123
- filename="diffusion_pytorch_model.safetensors", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
 
 
 
 
 
 
 
 
 
 
124
 
125
  # download vocoder model
126
  os.makedirs(vocoder_model_path, exist_ok=True)
127
- hf_hub_download(repo_id=REPO_ID, subfolder="music_vocoder",
128
- filename="config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
129
- hf_hub_download(repo_id=REPO_ID, subfolder="music_vocoder",
130
- filename="diffusion_pytorch_model.safetensors", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
 
 
 
 
 
 
 
 
 
 
131
 
132
  # download ace_step transformer model
133
  os.makedirs(ace_step_model_path, exist_ok=True)
134
- hf_hub_download(repo_id=REPO_ID, subfolder="ace_step_transformer",
135
- filename="config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
136
- hf_hub_download(repo_id=REPO_ID, subfolder="ace_step_transformer",
137
- filename="diffusion_pytorch_model.safetensors", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
 
 
 
 
 
 
 
 
 
 
138
 
139
  # download text encoder model
140
  os.makedirs(text_encoder_model_path, exist_ok=True)
141
- hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
142
- filename="config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
143
- hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
144
- filename="model.safetensors", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
145
- hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
146
- filename="special_tokens_map.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
147
- hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
148
- filename="tokenizer_config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
149
- hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
150
- filename="tokenizer.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
151
 
152
  logger.info("Models downloaded")
153
 
@@ -156,29 +271,131 @@ class ACEStepPipeline:
156
  ace_step_checkpoint_path = ace_step_model_path
157
  text_encoder_checkpoint_path = text_encoder_model_path
158
 
159
- self.music_dcae = MusicDCAE(dcae_checkpoint_path=dcae_checkpoint_path, vocoder_checkpoint_path=vocoder_checkpoint_path)
 
 
 
160
  self.music_dcae.to(device).eval().to(self.dtype)
161
 
162
- self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path, torch_dtype=self.dtype)
 
 
163
  self.ace_step_transformer.to(device).eval().to(self.dtype)
164
 
165
  lang_segment = LangSegment()
166
 
167
- lang_segment.setfilters([
168
- 'af', 'am', 'an', 'ar', 'as', 'az', 'be', 'bg', 'bn', 'br', 'bs', 'ca', 'cs', 'cy', 'da', 'de', 'dz', 'el',
169
- 'en', 'eo', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr', 'ga', 'gl', 'gu', 'he', 'hi', 'hr', 'ht', 'hu', 'hy',
170
- 'id', 'is', 'it', 'ja', 'jv', 'ka', 'kk', 'km', 'kn', 'ko', 'ku', 'ky', 'la', 'lb', 'lo', 'lt', 'lv', 'mg',
171
- 'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'nb', 'ne', 'nl', 'nn', 'no', 'oc', 'or', 'pa', 'pl', 'ps', 'pt', 'qu',
172
- 'ro', 'ru', 'rw', 'se', 'si', 'sk', 'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'ug', 'uk',
173
- 'ur', 'vi', 'vo', 'wa', 'xh', 'zh', 'zu'
174
- ])
175
  self.lang_segment = lang_segment
176
  self.lyric_tokenizer = VoiceBpeTokenizer()
177
- text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path, torch_dtype=self.dtype).eval()
 
 
178
  text_encoder_model = text_encoder_model.to(device).to(self.dtype)
179
  text_encoder_model.requires_grad_(False)
180
  self.text_encoder_model = text_encoder_model
181
- self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path)
 
 
182
  self.loaded = True
183
 
184
  # compile
@@ -188,7 +405,13 @@ class ACEStepPipeline:
188
  self.text_encoder_model = torch.compile(self.text_encoder_model)
189
 
190
  def get_text_embeddings(self, texts, device, text_max_length=256):
191
- inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
 
 
 
 
 
 
192
  inputs = {key: value.to(device) for key, value in inputs.items()}
193
  if self.text_encoder_model.device != device:
194
  self.text_encoder_model.to(device)
@@ -197,62 +420,87 @@ class ACEStepPipeline:
197
  last_hidden_states = outputs.last_hidden_state
198
  attention_mask = inputs["attention_mask"]
199
  return last_hidden_states, attention_mask
200
-
201
- def get_text_embeddings_null(self, texts, device, text_max_length=256, tau=0.01, l_min=8, l_max=10):
202
- inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
 
 
 
 
 
 
 
 
203
  inputs = {key: value.to(device) for key, value in inputs.items()}
204
  if self.text_encoder_model.device != device:
205
  self.text_encoder_model.to(device)
206
-
207
  def forward_with_temperature(inputs, tau=0.01, l_min=8, l_max=10):
208
  handlers = []
209
-
210
  def hook(module, input, output):
211
  output[:] *= tau
212
  return output
213
-
214
  for i in range(l_min, l_max):
215
- handler = self.text_encoder_model.encoder.block[i].layer[0].SelfAttention.q.register_forward_hook(hook)
 
 
 
 
216
  handlers.append(handler)
217
-
218
  with torch.no_grad():
219
  outputs = self.text_encoder_model(**inputs)
220
  last_hidden_states = outputs.last_hidden_state
221
-
222
  for hook in handlers:
223
  hook.remove()
224
-
225
  return last_hidden_states
226
-
227
  last_hidden_states = forward_with_temperature(inputs, tau, l_min, l_max)
228
  return last_hidden_states
229
 
230
  def set_seeds(self, batch_size, manual_seeds=None):
231
- seeds = None
232
  if manual_seeds is not None:
233
  if isinstance(manual_seeds, str):
234
  if "," in manual_seeds:
235
- seeds = list(map(int, manual_seeds.split(",")))
236
  elif manual_seeds.isdigit():
237
- seeds = int(manual_seeds)
238
-
239
- random_generators = [torch.Generator(device=self.device) for _ in range(batch_size)]
 
 
 
 
 
 
 
 
240
  actual_seeds = []
241
  for i in range(batch_size):
242
- seed = None
243
- if seeds is None:
244
- seed = torch.randint(0, 2**32, (1,)).item()
245
- if isinstance(seeds, int):
246
- seed = seeds
247
- if isinstance(seeds, list):
248
- seed = seeds[i]
249
- random_generators[i].manual_seed(seed)
250
- actual_seeds.append(seed)
 
 
 
 
 
251
  return random_generators, actual_seeds
252
 
253
  def get_lang(self, text):
254
  language = "en"
255
- try:
256
  _ = self.lang_segment.getTexts(text)
257
  langCounts = self.lang_segment.getCounts()
258
  language = langCounts[0][0]
@@ -286,7 +534,9 @@ class ACEStepPipeline:
286
  else:
287
  token_idx = self.lyric_tokenizer.encode(line, lang)
288
  if debug:
289
- toks = self.lyric_tokenizer.batch_decode([[tok_id] for tok_id in token_idx])
 
 
290
  logger.info(f"debbug {line} --> {lang} --> {toks}")
291
  lyric_token_idx = lyric_token_idx + token_idx + [2]
292
  except Exception as e:
@@ -315,11 +565,13 @@ class ACEStepPipeline:
315
  attention_mask=None,
316
  momentum_buffer=None,
317
  momentum_buffer_tar=None,
318
- return_src_pred=True
319
  ):
320
  noise_pred_src = None
321
  if return_src_pred:
322
- src_latent_model_input = torch.cat([zt_src, zt_src]) if do_classifier_free_guidance else zt_src
 
 
323
  timestep = t.expand(src_latent_model_input.shape[0])
324
  # source
325
  noise_pred_src = self.ace_step_transformer(
@@ -334,7 +586,9 @@ class ACEStepPipeline:
334
  ).sample
335
 
336
  if do_classifier_free_guidance:
337
- noise_pred_with_cond_src, noise_pred_uncond_src = noise_pred_src.chunk(2)
 
 
338
  if cfg_type == "apg":
339
  noise_pred_src = apg_forward(
340
  pred_cond=noise_pred_with_cond_src,
@@ -349,7 +603,9 @@ class ACEStepPipeline:
349
  cfg_strength=guidance_scale,
350
  )
351
 
352
- tar_latent_model_input = torch.cat([zt_tar, zt_tar]) if do_classifier_free_guidance else zt_tar
 
 
353
  timestep = t.expand(tar_latent_model_input.shape[0])
354
  # target
355
  noise_pred_tar = self.ace_step_transformer(
@@ -419,26 +675,52 @@ class ACEStepPipeline:
419
  T_steps = infer_steps
420
  frame_length = src_latents.shape[-1]
421
  attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
422
-
423
- timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, device, timesteps=None)
 
 
424
 
425
  if do_classifier_free_guidance:
426
  attention_mask = torch.cat([attention_mask] * 2, dim=0)
427
-
428
- encoder_text_hidden_states = torch.cat([encoder_text_hidden_states, torch.zeros_like(encoder_text_hidden_states)], 0)
 
 
 
 
 
 
429
  text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0)
430
 
431
- target_encoder_text_hidden_states = torch.cat([target_encoder_text_hidden_states, torch.zeros_like(target_encoder_text_hidden_states)], 0)
432
- target_text_attention_mask = torch.cat([target_text_attention_mask] * 2, dim=0)
 
 
 
 
 
 
 
 
433
 
434
- speaker_embds = torch.cat([speaker_embds, torch.zeros_like(speaker_embds)], 0)
435
- target_speaker_embeds = torch.cat([target_speaker_embeds, torch.zeros_like(target_speaker_embeds)], 0)
 
 
 
 
436
 
437
- lyric_token_ids = torch.cat([lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0)
 
 
438
  lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0)
439
 
440
- target_lyric_token_ids = torch.cat([target_lyric_token_ids, torch.zeros_like(target_lyric_token_ids)], 0)
441
- target_lyric_mask = torch.cat([target_lyric_mask, torch.zeros_like(target_lyric_mask)], 0)
 
 
 
 
442
 
443
  momentum_buffer = MomentumBuffer()
444
  momentum_buffer_tar = MomentumBuffer()
@@ -455,10 +737,10 @@ class ACEStepPipeline:
455
  if i < n_min:
456
  continue
457
 
458
- t_i = t/1000
459
 
460
- if i+1 < len(timesteps):
461
- t_im1 = (timesteps[i+1])/1000
462
  else:
463
  t_im1 = torch.zeros_like(t_i).to(t_i.device)
464
 
@@ -466,7 +748,12 @@ class ACEStepPipeline:
466
  # Calculate the average of the V predictions
467
  V_delta_avg = torch.zeros_like(x_src)
468
  for k in range(n_avg):
469
- fwd_noise = randn_tensor(shape=x_src.shape, generator=random_generators, device=device, dtype=dtype)
 
 
 
 
 
470
 
471
  zt_src = (1 - t_i) * x_src + (t_i) * fwd_noise
472
 
@@ -490,22 +777,29 @@ class ACEStepPipeline:
490
  guidance_scale=guidance_scale,
491
  target_guidance_scale=target_guidance_scale,
492
  attention_mask=attention_mask,
493
- momentum_buffer=momentum_buffer
494
  )
495
- V_delta_avg += (1 / n_avg) * (Vt_tar - Vt_src) # - (hfg-1)*( x_src))
 
 
496
 
497
  # propagate direct ODE
498
  zt_edit = zt_edit.to(torch.float32)
499
  zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg
500
  zt_edit = zt_edit.to(V_delta_avg.dtype)
501
- else: # i >= T_steps-n_min # regular sampling for last n_min steps
502
  if i == n_max:
503
- fwd_noise = randn_tensor(shape=x_src.shape, generator=random_generators, device=device, dtype=dtype)
 
 
 
 
 
504
  scheduler._init_step_index(t)
505
  sigma = scheduler.sigmas[scheduler.step_index]
506
  xt_src = sigma * fwd_noise + (1.0 - sigma) * x_src
507
  xt_tar = zt_edit + xt_src - x_src
508
-
509
  _, Vt_tar = self.calc_v(
510
  zt_src=None,
511
  zt_tar=xt_tar,
@@ -527,13 +821,13 @@ class ACEStepPipeline:
527
  momentum_buffer_tar=momentum_buffer_tar,
528
  return_src_pred=False,
529
  )
530
-
531
  dtype = Vt_tar.dtype
532
  xt_tar = xt_tar.to(torch.float32)
533
  prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar
534
- prev_sample = prev_sample.to(dtype)
535
  xt_tar = prev_sample
536
-
537
  target_latents = zt_edit if xt_tar is None else xt_tar
538
  return target_latents
539
 
@@ -551,7 +845,12 @@ class ACEStepPipeline:
551
  timesteps = scheduler.timesteps.unsqueeze(1).to(gt_latents.dtype)
552
  indices = indices.to(timesteps.device).to(gt_latents.dtype).unsqueeze(1)
553
  nearest_idx = torch.argmin(torch.cdist(indices, timesteps), dim=1)
554
- sigma = scheduler.sigmas[nearest_idx].flatten().to(gt_latents.device).to(gt_latents.dtype)
 
 
 
 
 
555
  while len(sigma.shape) < gt_latents.ndim:
556
  sigma = sigma.unsqueeze(-1)
557
  noisy_image = sigma * noise + (1.0 - sigma) * gt_latents
@@ -595,15 +894,30 @@ class ACEStepPipeline:
595
  ref_latents=None,
596
  ):
597
 
598
- logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
 
 
 
 
599
  do_classifier_free_guidance = True
600
  if guidance_scale == 0.0 or guidance_scale == 1.0:
601
  do_classifier_free_guidance = False
602
-
603
  do_double_condition_guidance = False
604
- if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0:
 
 
 
 
 
605
  do_double_condition_guidance = True
606
- logger.info("do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format(do_double_condition_guidance, guidance_scale_text, guidance_scale_lyric))
 
 
 
 
 
 
607
 
608
  device = encoder_text_hidden_states.device
609
  dtype = encoder_text_hidden_states.dtype
@@ -619,7 +933,7 @@ class ACEStepPipeline:
619
  num_train_timesteps=1000,
620
  shift=3.0,
621
  )
622
-
623
  frame_length = int(duration * 44100 / 512 / 8)
624
  if src_latents is not None:
625
  frame_length = src_latents.shape[-1]
@@ -630,31 +944,60 @@ class ACEStepPipeline:
630
  if len(oss_steps) > 0:
631
  infer_steps = max(oss_steps)
632
  scheduler.set_timesteps
633
- timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=infer_steps, device=device, timesteps=None)
 
 
 
 
 
634
  new_timesteps = torch.zeros(len(oss_steps), dtype=dtype, device=device)
635
  for idx in range(len(oss_steps)):
636
- new_timesteps[idx] = timesteps[oss_steps[idx]-1]
637
  num_inference_steps = len(oss_steps)
638
  sigmas = (new_timesteps / 1000).float().cpu().numpy()
639
- timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=num_inference_steps, device=device, sigmas=sigmas)
640
- logger.info(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}")
 
 
 
 
 
 
 
641
  else:
642
- timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=infer_steps, device=device, timesteps=None)
643
-
644
- target_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=random_generators, device=device, dtype=dtype)
645
-
 
 
 
 
 
 
 
 
 
 
646
  is_repaint = False
647
- is_extend = False
648
  if add_retake_noise:
649
  n_min = int(infer_steps * (1 - retake_variance))
650
- retake_variance = torch.tensor(retake_variance * math.pi/2).to(device).to(dtype)
651
- retake_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=retake_random_generators, device=device, dtype=dtype)
 
 
 
 
 
 
 
652
  repaint_start_frame = int(repaint_start * 44100 / 512 / 8)
653
  repaint_end_frame = int(repaint_end * 44100 / 512 / 8)
654
  x0 = src_latents
655
  # retake
656
- is_repaint = (repaint_end_frame - repaint_start_frame != frame_length)
657
-
658
  is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length)
659
  if is_extend:
660
  is_repaint = True
@@ -662,13 +1005,23 @@ class ACEStepPipeline:
662
  # TODO: train a mask aware repainting controlnet
663
  # to make sure mean = 0, std = 1
664
  if not is_repaint:
665
- target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
 
 
 
666
  elif not is_extend:
667
- # if repaint_end_frame
668
- repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=device, dtype=dtype)
 
 
669
  repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0
670
- repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
671
- repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents)
 
 
 
 
 
672
  zt_edit = x0.clone()
673
  z0 = repaint_noise
674
  elif is_extend:
@@ -684,73 +1037,107 @@ class ACEStepPipeline:
684
  if repaint_start_frame < 0:
685
  left_pad_frame_length = abs(repaint_start_frame)
686
  frame_length = left_pad_frame_length + gt_latents.shape[-1]
687
- extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0)
 
 
688
  if frame_length > max_infer_fame_length:
689
  right_trim_length = frame_length - max_infer_fame_length
690
- extend_gt_latents = extend_gt_latents[:,:,:,:max_infer_fame_length]
691
- to_right_pad_gt_latents = extend_gt_latents[:,:,:,-right_trim_length:]
 
 
 
 
692
  frame_length = max_infer_fame_length
693
  repaint_start_frame = 0
694
  gt_latents = extend_gt_latents
695
-
696
  if repaint_end_frame > src_latents_length:
697
  right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1]
698
  frame_length = gt_latents.shape[-1] + right_pad_frame_length
699
- extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0)
 
 
700
  if frame_length > max_infer_fame_length:
701
  left_trim_length = frame_length - max_infer_fame_length
702
- extend_gt_latents = extend_gt_latents[:,:,:,-max_infer_fame_length:]
703
- to_left_pad_gt_latents = extend_gt_latents[:,:,:,:left_trim_length]
 
 
 
 
704
  frame_length = max_infer_fame_length
705
  repaint_end_frame = frame_length
706
  gt_latents = extend_gt_latents
707
 
708
- repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=device, dtype=dtype)
 
 
709
  if left_pad_frame_length > 0:
710
- repaint_mask[:,:,:,:left_pad_frame_length] = 1.0
711
  if right_pad_frame_length > 0:
712
- repaint_mask[:,:,:,-right_pad_frame_length:] = 1.0
713
  x0 = gt_latents
714
  padd_list = []
715
  if left_pad_frame_length > 0:
716
  padd_list.append(retake_latents[:, :, :, :left_pad_frame_length])
717
- padd_list.append(target_latents[:,:,:,left_trim_length:target_latents.shape[-1]-right_trim_length])
 
 
 
 
 
 
 
718
  if right_pad_frame_length > 0:
719
  padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:])
720
  target_latents = torch.cat(padd_list, dim=-1)
721
- assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}"
 
 
722
  zt_edit = x0.clone()
723
  z0 = target_latents
724
 
725
  init_timestep = 1000
726
  if audio2audio_enable and ref_latents is not None:
727
- target_latents, init_timestep = self.add_latents_noise(gt_latents=ref_latents, variance=(1-ref_audio_strength), noise=target_latents, scheduler=scheduler)
 
 
 
 
 
728
 
729
  attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
730
-
731
  # guidance interval
732
  start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
733
  end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
734
- logger.info(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}")
 
 
735
 
736
  momentum_buffer = MomentumBuffer()
737
 
738
  def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6):
739
  handlers = []
740
-
741
  def hook(module, input, output):
742
  output[:] *= tau
743
  return output
744
-
745
  for i in range(l_min, l_max):
746
- handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook)
 
 
747
  handlers.append(handler)
748
-
749
- encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs)
750
-
 
 
751
  for hook in handlers:
752
  hook.remove()
753
-
754
  return encoder_hidden_states
755
 
756
  # P(speaker, text, lyric)
@@ -767,12 +1154,16 @@ class ACEStepPipeline:
767
  encoder_hidden_states_null = forward_encoder_with_temperature(
768
  self,
769
  inputs={
770
- "encoder_text_hidden_states": encoder_text_hidden_states_null if encoder_text_hidden_states_null is not None else torch.zeros_like(encoder_text_hidden_states),
 
 
 
 
771
  "text_attention_mask": text_attention_mask,
772
  "speaker_embeds": torch.zeros_like(speaker_embds),
773
  "lyric_token_idx": lyric_token_ids,
774
  "lyric_mask": lyric_mask,
775
- }
776
  )
777
  else:
778
  # P(null_speaker, null_text, null_lyric)
@@ -783,7 +1174,7 @@ class ACEStepPipeline:
783
  torch.zeros_like(lyric_token_ids),
784
  lyric_mask,
785
  )
786
-
787
  encoder_hidden_states_no_lyric = None
788
  if do_double_condition_guidance:
789
  # P(null_speaker, text, lyric_weaker)
@@ -796,7 +1187,7 @@ class ACEStepPipeline:
796
  "speaker_embeds": torch.zeros_like(speaker_embds),
797
  "lyric_token_idx": lyric_token_ids,
798
  "lyric_mask": lyric_mask,
799
- }
800
  )
801
  # P(null_speaker, text, no_lyric)
802
  else:
@@ -808,26 +1199,34 @@ class ACEStepPipeline:
808
  lyric_mask,
809
  )
810
 
811
- def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20):
 
 
812
  handlers = []
813
-
814
  def hook(module, input, output):
815
  output[:] *= tau
816
  return output
817
-
818
  for i in range(l_min, l_max):
819
- handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook)
 
 
820
  handlers.append(handler)
821
- handler = self.ace_step_transformer.transformer_blocks[i].cross_attn.to_q.register_forward_hook(hook)
 
 
822
  handlers.append(handler)
823
 
824
- sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample
825
-
 
 
826
  for hook in handlers:
827
  hook.remove()
828
-
829
  return sample
830
-
831
  for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
832
 
833
  if t > init_timestep:
@@ -850,8 +1249,15 @@ class ACEStepPipeline:
850
  # compute current guidance scale
851
  if guidance_interval_decay > 0:
852
  # Linearly interpolate to calculate the current guidance scale
853
- progress = (i - start_idx) / (end_idx - start_idx - 1) # ๅฝ’ไธ€ๅŒ–ๅˆฐ[0,1]
854
- current_guidance_scale = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay
 
 
 
 
 
 
 
855
  else:
856
  current_guidance_scale = guidance_scale
857
 
@@ -869,7 +1275,10 @@ class ACEStepPipeline:
869
  ).sample
870
 
871
  noise_pred_with_only_text_cond = None
872
- if do_double_condition_guidance and encoder_hidden_states_no_lyric is not None:
 
 
 
873
  noise_pred_with_only_text_cond = self.ace_step_transformer.decode(
874
  hidden_states=latent_model_input,
875
  attention_mask=attention_mask,
@@ -901,7 +1310,10 @@ class ACEStepPipeline:
901
  timestep=timestep,
902
  ).sample
903
 
904
- if do_double_condition_guidance and noise_pred_with_only_text_cond is not None:
 
 
 
905
  noise_pred = cfg_double_condition_forward(
906
  cond_output=noise_pred_with_cond,
907
  uncond_output=noise_pred_uncond,
@@ -930,7 +1342,7 @@ class ACEStepPipeline:
930
  guidance_scale=current_guidance_scale,
931
  i=i,
932
  zero_steps=zero_steps,
933
- use_zero_init=use_zero_init
934
  )
935
  else:
936
  latent_model_input = latents
@@ -945,9 +1357,9 @@ class ACEStepPipeline:
945
  ).sample
946
 
947
  if is_repaint and i >= n_min:
948
- t_i = t/1000
949
- if i+1 < len(timesteps):
950
- t_im1 = (timesteps[i+1])/1000
951
  else:
952
  t_im1 = torch.zeros_like(t_i).to(t_i.device)
953
  dtype = noise_pred.dtype
@@ -956,18 +1368,37 @@ class ACEStepPipeline:
956
  prev_sample = prev_sample.to(dtype)
957
  target_latents = prev_sample
958
  zt_src = (1 - t_im1) * x0 + (t_im1) * z0
959
- target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src)
 
 
960
  else:
961
- target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]
 
 
 
 
 
 
962
 
963
  if is_extend:
964
  if to_right_pad_gt_latents is not None:
965
- target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1)
 
 
966
  if to_left_pad_gt_latents is not None:
967
- target_latents = torch.cat([to_right_pad_gt_latents, target_latents], dim=0)
 
 
968
  return target_latents
969
 
970
- def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="mp3"):
 
 
 
 
 
 
 
971
  output_audio_paths = []
972
  bs = latents.shape[0]
973
  audio_lengths = [target_wav_duration_second * sample_rate] * bs
@@ -976,11 +1407,15 @@ class ACEStepPipeline:
976
  _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
977
  pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
978
  for i in tqdm(range(bs)):
979
- output_audio_path = self.save_wav_file(pred_wavs[i], i, sample_rate=sample_rate)
 
 
980
  output_audio_paths.append(output_audio_path)
981
  return output_audio_paths
982
 
983
- def save_wav_file(self, target_wav, idx, save_path=None, sample_rate=48000, format="mp3"):
 
 
984
  if save_path is None:
985
  logger.warning("save_path is None, using default path ./outputs/")
986
  base_path = f"./outputs"
@@ -989,9 +1424,17 @@ class ACEStepPipeline:
989
  base_path = save_path
990
  ensure_directory_exists(base_path)
991
 
992
- output_path_flac = f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.{format}"
 
 
993
  target_wav = target_wav.float()
994
- torchaudio.save(output_path_flac, target_wav, sample_rate=sample_rate, format=format, compression=torio.io.CodecConfig(bit_rate=320000))
 
 
 
 
 
 
995
  return output_path_flac
996
 
997
  def infer_latents(self, input_audio_path):
@@ -1017,7 +1460,7 @@ class ACEStepPipeline:
1017
  omega_scale: int = 10.0,
1018
  manual_seeds: list = None,
1019
  guidance_interval: float = 0.5,
1020
- guidance_interval_decay: float = 0.,
1021
  min_guidance_scale: float = 3.0,
1022
  use_erg_tag: bool = True,
1023
  use_erg_lyric: bool = True,
@@ -1060,22 +1503,30 @@ class ACEStepPipeline:
1060
  start_time = time.time()
1061
 
1062
  random_generators, actual_seeds = self.set_seeds(batch_size, manual_seeds)
1063
- retake_random_generators, actual_retake_seeds = self.set_seeds(batch_size, retake_seeds)
 
 
1064
 
1065
  if isinstance(oss_steps, str) and len(oss_steps) > 0:
1066
  oss_steps = list(map(int, oss_steps.split(",")))
1067
  else:
1068
  oss_steps = []
1069
-
1070
  texts = [prompt]
1071
- encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(texts, self.device)
 
 
1072
  encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)
1073
  text_attention_mask = text_attention_mask.repeat(batch_size, 1)
1074
 
1075
  encoder_text_hidden_states_null = None
1076
  if use_erg_tag:
1077
- encoder_text_hidden_states_null = self.get_text_embeddings_null(texts, self.device)
1078
- encoder_text_hidden_states_null = encoder_text_hidden_states_null.repeat(batch_size, 1, 1)
 
 
 
 
1079
 
1080
  # not support for released checkpoint
1081
  speaker_embeds = torch.zeros(batch_size, 512).to(self.device).to(self.dtype)
@@ -1086,8 +1537,18 @@ class ACEStepPipeline:
1086
  if len(lyrics) > 0:
1087
  lyric_token_idx = self.tokenize_lyrics(lyrics, debug=debug)
1088
  lyric_mask = [1] * len(lyric_token_idx)
1089
- lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
1090
- lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)
 
 
 
 
 
 
 
 
 
 
1091
 
1092
  if audio_duration <= 0:
1093
  audio_duration = random.uniform(30.0, 240.0)
@@ -1102,16 +1563,24 @@ class ACEStepPipeline:
1102
  if task == "retake":
1103
  repaint_start = 0
1104
  repaint_end = audio_duration
1105
-
1106
  src_latents = None
1107
  if src_audio_path is not None:
1108
- assert src_audio_path is not None and task in ("repaint", "edit", "extend"), "src_audio_path is required for retake/repaint/extend task"
1109
- assert os.path.exists(src_audio_path), f"src_audio_path {src_audio_path} does not exist"
 
 
 
 
 
 
1110
  src_latents = self.infer_latents(src_audio_path)
1111
 
1112
  ref_latents = None
1113
  if ref_audio_input is not None and audio2audio_enable:
1114
- assert ref_audio_input is not None, "ref_audio_input is required for audio2audio task"
 
 
1115
  assert os.path.exists(
1116
  ref_audio_input
1117
  ), f"ref_audio_input {ref_audio_input} does not exist"
@@ -1119,17 +1588,39 @@ class ACEStepPipeline:
1119
 
1120
  if task == "edit":
1121
  texts = [edit_target_prompt]
1122
- target_encoder_text_hidden_states, target_text_attention_mask = self.get_text_embeddings(texts, self.device)
1123
- target_encoder_text_hidden_states = target_encoder_text_hidden_states.repeat(batch_size, 1, 1)
1124
- target_text_attention_mask = target_text_attention_mask.repeat(batch_size, 1)
 
 
 
 
 
 
1125
 
1126
- target_lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
1127
- target_lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
 
 
 
 
1128
  if len(edit_target_lyrics) > 0:
1129
- target_lyric_token_idx = self.tokenize_lyrics(edit_target_lyrics, debug=True)
 
 
1130
  target_lyric_mask = [1] * len(target_lyric_token_idx)
1131
- target_lyric_token_idx = torch.tensor(target_lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
1132
- target_lyric_mask = torch.tensor(target_lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)
 
 
 
 
 
 
 
 
 
 
1133
 
1134
  target_speaker_embeds = speaker_embeds.clone()
1135
 
@@ -1145,7 +1636,7 @@ class ACEStepPipeline:
1145
  target_lyric_token_ids=target_lyric_token_idx,
1146
  target_lyric_mask=target_lyric_mask,
1147
  src_latents=src_latents,
1148
- random_generators=retake_random_generators, # more diversity
1149
  infer_steps=infer_step,
1150
  guidance_scale=guidance_scale,
1151
  n_min=edit_n_min,
@@ -1233,7 +1724,7 @@ class ACEStepPipeline:
1233
  "repaint_end": repaint_end,
1234
  "edit_n_min": edit_n_min,
1235
  "edit_n_max": edit_n_max,
1236
- "edit_n_avg": edit_n_avg,
1237
  "src_audio_path": src_audio_path,
1238
  "edit_target_prompt": edit_target_prompt,
1239
  "edit_target_lyrics": edit_target_lyrics,
@@ -1243,7 +1734,9 @@ class ACEStepPipeline:
1243
  }
1244
  # save input_params_json
1245
  for output_audio_path in output_paths:
1246
- input_params_json_save_path = output_audio_path.replace(f".{format}", "_input_params.json")
 
 
1247
  input_params_json["audio_path"] = output_audio_path
1248
  with open(input_params_json_save_path, "w", encoding="utf-8") as f:
1249
  json.dump(input_params_json, f, indent=4, ensure_ascii=False)
 
12
  from huggingface_hub import hf_hub_download, snapshot_download
13
 
14
  # from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from schedulers.scheduling_flow_match_euler_discrete import (
16
+ FlowMatchEulerDiscreteScheduler,
17
+ )
18
+ from schedulers.scheduling_flow_match_heun_discrete import (
19
+ FlowMatchHeunDiscreteScheduler,
20
+ )
21
+ from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
22
+ retrieve_timesteps,
23
+ )
24
  from diffusers.utils.torch_utils import randn_tensor
25
  from transformers import UMT5EncoderModel, AutoTokenizer
26
 
 
28
  from music_dcae.music_dcae_pipeline import MusicDCAE
29
  from models.ace_step_transformer import ACEStepTransformer2DModel
30
  from models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer
31
+ from apg_guidance import (
32
+ apg_forward,
33
+ MomentumBuffer,
34
+ cfg_forward,
35
+ cfg_zero_star,
36
+ cfg_double_condition_forward,
37
+ )
38
  import torchaudio
39
  import torio
40
 
41
 
42
  torch.backends.cudnn.benchmark = False
43
+ torch.set_float32_matmul_precision("high")
44
  torch.backends.cudnn.deterministic = True
45
  torch.backends.cuda.matmul.allow_tf32 = True
46
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
47
 
48
 
49
  SUPPORT_LANGUAGES = {
50
+ "en": 259,
51
+ "de": 260,
52
+ "fr": 262,
53
+ "es": 284,
54
+ "it": 285,
55
+ "pt": 286,
56
+ "pl": 294,
57
+ "tr": 295,
58
+ "ru": 267,
59
+ "cs": 293,
60
+ "nl": 297,
61
+ "ar": 5022,
62
+ "zh": 5023,
63
+ "ja": 5412,
64
+ "hu": 5753,
65
+ "ko": 6152,
66
+ "hi": 6680,
67
  }
68
 
69
  structure_pattern = re.compile(r"\[.*?\]")
 
81
  # class ACEStepPipeline(DiffusionPipeline):
82
  class ACEStepPipeline:
83
 
84
+ def __init__(
85
+ self,
86
+ checkpoint_dir=None,
87
+ device_id=0,
88
+ dtype="bfloat16",
89
+ text_encoder_checkpoint_path=None,
90
+ persistent_storage_path=None,
91
+ torch_compile=False,
92
+ **kwargs,
93
+ ):
94
  if not checkpoint_dir:
95
  if persistent_storage_path is None:
96
  checkpoint_dir = os.path.join(os.path.dirname(__file__), "checkpoints")
 
98
  checkpoint_dir = os.path.join(persistent_storage_path, "checkpoints")
99
  ensure_directory_exists(checkpoint_dir)
100
  self.checkpoint_dir = checkpoint_dir
101
+ device = (
102
+ torch.device(f"cuda:{device_id}")
103
+ if torch.cuda.is_available()
104
+ else torch.device("cpu")
105
+ )
106
  if device.type == "cpu" and torch.backends.mps.is_available():
107
  device = torch.device("mps")
108
  self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
 
112
  self.loaded = False
113
  self.torch_compile = torch_compile
114
  self.lora_path = "none"
115
+
116
  def load_lora(self, lora_name_or_path):
117
  if lora_name_or_path != self.lora_path and lora_name_or_path != "none":
118
  if not os.path.exists(lora_name_or_path):
119
+ lora_download_path = snapshot_download(
120
+ lora_name_or_path, cache_dir=self.checkpoint_dir
121
+ )
122
  else:
123
  lora_download_path = lora_name_or_path
124
  if self.lora_path != "none":
125
  self.ace_step_transformer.unload_lora()
126
+ self.ace_step_transformer.load_lora_adapter(
127
+ os.path.join(lora_download_path, "pytorch_lora_weights.safetensors"),
128
+ adapter_name="zh_rap_lora",
129
+ with_alpha=True,
130
+ )
131
+ logger.info(
132
+ f"Loading lora weights from: {lora_name_or_path} download path is: {lora_download_path}"
133
+ )
134
  self.lora_path = lora_name_or_path
135
  elif self.lora_path != "none" and lora_name_or_path == "none":
136
  logger.info("No lora weights to load.")
 
145
  text_encoder_model_path = os.path.join(checkpoint_dir, "umt5-base")
146
 
147
  files_exist = (
148
+ os.path.exists(os.path.join(dcae_model_path, "config.json"))
149
+ and os.path.exists(
150
+ os.path.join(dcae_model_path, "diffusion_pytorch_model.safetensors")
151
+ )
152
+ and os.path.exists(os.path.join(vocoder_model_path, "config.json"))
153
+ and os.path.exists(
154
+ os.path.join(vocoder_model_path, "diffusion_pytorch_model.safetensors")
155
+ )
156
+ and os.path.exists(os.path.join(ace_step_model_path, "config.json"))
157
+ and os.path.exists(
158
+ os.path.join(ace_step_model_path, "diffusion_pytorch_model.safetensors")
159
+ )
160
+ and os.path.exists(os.path.join(text_encoder_model_path, "config.json"))
161
+ and os.path.exists(
162
+ os.path.join(text_encoder_model_path, "model.safetensors")
163
+ )
164
+ and os.path.exists(
165
+ os.path.join(text_encoder_model_path, "special_tokens_map.json")
166
+ )
167
+ and os.path.exists(
168
+ os.path.join(text_encoder_model_path, "tokenizer_config.json")
169
+ )
170
+ and os.path.exists(os.path.join(text_encoder_model_path, "tokenizer.json"))
171
  )
172
 
173
  if not files_exist:
174
+ logger.info(
175
+ f"Checkpoint directory {checkpoint_dir} is not complete, downloading from Hugging Face Hub"
176
+ )
177
 
178
  # download music dcae model
179
  os.makedirs(dcae_model_path, exist_ok=True)
180
+ hf_hub_download(
181
+ repo_id=REPO_ID,
182
+ subfolder="music_dcae_f8c8",
183
+ filename="config.json",
184
+ local_dir=checkpoint_dir,
185
+ local_dir_use_symlinks=False,
186
+ )
187
+ hf_hub_download(
188
+ repo_id=REPO_ID,
189
+ subfolder="music_dcae_f8c8",
190
+ filename="diffusion_pytorch_model.safetensors",
191
+ local_dir=checkpoint_dir,
192
+ local_dir_use_symlinks=False,
193
+ )
194
 
195
  # download vocoder model
196
  os.makedirs(vocoder_model_path, exist_ok=True)
197
+ hf_hub_download(
198
+ repo_id=REPO_ID,
199
+ subfolder="music_vocoder",
200
+ filename="config.json",
201
+ local_dir=checkpoint_dir,
202
+ local_dir_use_symlinks=False,
203
+ )
204
+ hf_hub_download(
205
+ repo_id=REPO_ID,
206
+ subfolder="music_vocoder",
207
+ filename="diffusion_pytorch_model.safetensors",
208
+ local_dir=checkpoint_dir,
209
+ local_dir_use_symlinks=False,
210
+ )
211
 
212
  # download ace_step transformer model
213
  os.makedirs(ace_step_model_path, exist_ok=True)
214
+ hf_hub_download(
215
+ repo_id=REPO_ID,
216
+ subfolder="ace_step_transformer",
217
+ filename="config.json",
218
+ local_dir=checkpoint_dir,
219
+ local_dir_use_symlinks=False,
220
+ )
221
+ hf_hub_download(
222
+ repo_id=REPO_ID,
223
+ subfolder="ace_step_transformer",
224
+ filename="diffusion_pytorch_model.safetensors",
225
+ local_dir=checkpoint_dir,
226
+ local_dir_use_symlinks=False,
227
+ )
228
 
229
  # download text encoder model
230
  os.makedirs(text_encoder_model_path, exist_ok=True)
231
+ hf_hub_download(
232
+ repo_id=REPO_ID,
233
+ subfolder="umt5-base",
234
+ filename="config.json",
235
+ local_dir=checkpoint_dir,
236
+ local_dir_use_symlinks=False,
237
+ )
238
+ hf_hub_download(
239
+ repo_id=REPO_ID,
240
+ subfolder="umt5-base",
241
+ filename="model.safetensors",
242
+ local_dir=checkpoint_dir,
243
+ local_dir_use_symlinks=False,
244
+ )
245
+ hf_hub_download(
246
+ repo_id=REPO_ID,
247
+ subfolder="umt5-base",
248
+ filename="special_tokens_map.json",
249
+ local_dir=checkpoint_dir,
250
+ local_dir_use_symlinks=False,
251
+ )
252
+ hf_hub_download(
253
+ repo_id=REPO_ID,
254
+ subfolder="umt5-base",
255
+ filename="tokenizer_config.json",
256
+ local_dir=checkpoint_dir,
257
+ local_dir_use_symlinks=False,
258
+ )
259
+ hf_hub_download(
260
+ repo_id=REPO_ID,
261
+ subfolder="umt5-base",
262
+ filename="tokenizer.json",
263
+ local_dir=checkpoint_dir,
264
+ local_dir_use_symlinks=False,
265
+ )
266
 
267
  logger.info("Models downloaded")
268
 
 
271
  ace_step_checkpoint_path = ace_step_model_path
272
  text_encoder_checkpoint_path = text_encoder_model_path
273
 
274
+ self.music_dcae = MusicDCAE(
275
+ dcae_checkpoint_path=dcae_checkpoint_path,
276
+ vocoder_checkpoint_path=vocoder_checkpoint_path,
277
+ )
278
  self.music_dcae.to(device).eval().to(self.dtype)
279
 
280
+ self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(
281
+ ace_step_checkpoint_path, torch_dtype=self.dtype
282
+ )
283
  self.ace_step_transformer.to(device).eval().to(self.dtype)
284
 
285
  lang_segment = LangSegment()
286
 
287
+ lang_segment.setfilters(
288
+ [
289
+ "af",
290
+ "am",
291
+ "an",
292
+ "ar",
293
+ "as",
294
+ "az",
295
+ "be",
296
+ "bg",
297
+ "bn",
298
+ "br",
299
+ "bs",
300
+ "ca",
301
+ "cs",
302
+ "cy",
303
+ "da",
304
+ "de",
305
+ "dz",
306
+ "el",
307
+ "en",
308
+ "eo",
309
+ "es",
310
+ "et",
311
+ "eu",
312
+ "fa",
313
+ "fi",
314
+ "fo",
315
+ "fr",
316
+ "ga",
317
+ "gl",
318
+ "gu",
319
+ "he",
320
+ "hi",
321
+ "hr",
322
+ "ht",
323
+ "hu",
324
+ "hy",
325
+ "id",
326
+ "is",
327
+ "it",
328
+ "ja",
329
+ "jv",
330
+ "ka",
331
+ "kk",
332
+ "km",
333
+ "kn",
334
+ "ko",
335
+ "ku",
336
+ "ky",
337
+ "la",
338
+ "lb",
339
+ "lo",
340
+ "lt",
341
+ "lv",
342
+ "mg",
343
+ "mk",
344
+ "ml",
345
+ "mn",
346
+ "mr",
347
+ "ms",
348
+ "mt",
349
+ "nb",
350
+ "ne",
351
+ "nl",
352
+ "nn",
353
+ "no",
354
+ "oc",
355
+ "or",
356
+ "pa",
357
+ "pl",
358
+ "ps",
359
+ "pt",
360
+ "qu",
361
+ "ro",
362
+ "ru",
363
+ "rw",
364
+ "se",
365
+ "si",
366
+ "sk",
367
+ "sl",
368
+ "sq",
369
+ "sr",
370
+ "sv",
371
+ "sw",
372
+ "ta",
373
+ "te",
374
+ "th",
375
+ "tl",
376
+ "tr",
377
+ "ug",
378
+ "uk",
379
+ "ur",
380
+ "vi",
381
+ "vo",
382
+ "wa",
383
+ "xh",
384
+ "zh",
385
+ "zu",
386
+ ]
387
+ )
388
  self.lang_segment = lang_segment
389
  self.lyric_tokenizer = VoiceBpeTokenizer()
390
+ text_encoder_model = UMT5EncoderModel.from_pretrained(
391
+ text_encoder_checkpoint_path, torch_dtype=self.dtype
392
+ ).eval()
393
  text_encoder_model = text_encoder_model.to(device).to(self.dtype)
394
  text_encoder_model.requires_grad_(False)
395
  self.text_encoder_model = text_encoder_model
396
+ self.text_tokenizer = AutoTokenizer.from_pretrained(
397
+ text_encoder_checkpoint_path
398
+ )
399
  self.loaded = True
400
 
401
  # compile
 
405
  self.text_encoder_model = torch.compile(self.text_encoder_model)
406
 
407
  def get_text_embeddings(self, texts, device, text_max_length=256):
408
+ inputs = self.text_tokenizer(
409
+ texts,
410
+ return_tensors="pt",
411
+ padding=True,
412
+ truncation=True,
413
+ max_length=text_max_length,
414
+ )
415
  inputs = {key: value.to(device) for key, value in inputs.items()}
416
  if self.text_encoder_model.device != device:
417
  self.text_encoder_model.to(device)
 
420
  last_hidden_states = outputs.last_hidden_state
421
  attention_mask = inputs["attention_mask"]
422
  return last_hidden_states, attention_mask
423
+
424
+ def get_text_embeddings_null(
425
+ self, texts, device, text_max_length=256, tau=0.01, l_min=8, l_max=10
426
+ ):
427
+ inputs = self.text_tokenizer(
428
+ texts,
429
+ return_tensors="pt",
430
+ padding=True,
431
+ truncation=True,
432
+ max_length=text_max_length,
433
+ )
434
  inputs = {key: value.to(device) for key, value in inputs.items()}
435
  if self.text_encoder_model.device != device:
436
  self.text_encoder_model.to(device)
437
+
438
  def forward_with_temperature(inputs, tau=0.01, l_min=8, l_max=10):
439
  handlers = []
440
+
441
  def hook(module, input, output):
442
  output[:] *= tau
443
  return output
444
+
445
  for i in range(l_min, l_max):
446
+ handler = (
447
+ self.text_encoder_model.encoder.block[i]
448
+ .layer[0]
449
+ .SelfAttention.q.register_forward_hook(hook)
450
+ )
451
  handlers.append(handler)
452
+
453
  with torch.no_grad():
454
  outputs = self.text_encoder_model(**inputs)
455
  last_hidden_states = outputs.last_hidden_state
456
+
457
  for hook in handlers:
458
  hook.remove()
459
+
460
  return last_hidden_states
461
+
462
  last_hidden_states = forward_with_temperature(inputs, tau, l_min, l_max)
463
  return last_hidden_states
464
 
465
  def set_seeds(self, batch_size, manual_seeds=None):
466
+ processed_input_seeds = None
467
  if manual_seeds is not None:
468
  if isinstance(manual_seeds, str):
469
  if "," in manual_seeds:
470
+ processed_input_seeds = list(map(int, manual_seeds.split(",")))
471
  elif manual_seeds.isdigit():
472
+ processed_input_seeds = int(manual_seeds)
473
+ elif isinstance(manual_seeds, list) and all(
474
+ isinstance(s, int) for s in manual_seeds
475
+ ):
476
+ if len(manual_seeds) > 0:
477
+ processed_input_seeds = list(manual_seeds)
478
+ elif isinstance(manual_seeds, int):
479
+ processed_input_seeds = manual_seeds
480
+ random_generators = [
481
+ torch.Generator(device=self.device) for _ in range(batch_size)
482
+ ]
483
  actual_seeds = []
484
  for i in range(batch_size):
485
+ current_seed_for_generator = None
486
+ if processed_input_seeds is None:
487
+ current_seed_for_generator = torch.randint(0, 2**32, (1,)).item()
488
+ elif isinstance(processed_input_seeds, int):
489
+ current_seed_for_generator = processed_input_seeds
490
+ elif isinstance(processed_input_seeds, list):
491
+ if i < len(processed_input_seeds):
492
+ current_seed_for_generator = processed_input_seeds[i]
493
+ else:
494
+ current_seed_for_generator = processed_input_seeds[-1]
495
+ if current_seed_for_generator is None:
496
+ current_seed_for_generator = torch.randint(0, 2**32, (1,)).item()
497
+ random_generators[i].manual_seed(current_seed_for_generator)
498
+ actual_seeds.append(current_seed_for_generator)
499
  return random_generators, actual_seeds
500
 
501
  def get_lang(self, text):
502
  language = "en"
503
+ try:
504
  _ = self.lang_segment.getTexts(text)
505
  langCounts = self.lang_segment.getCounts()
506
  language = langCounts[0][0]
 
534
  else:
535
  token_idx = self.lyric_tokenizer.encode(line, lang)
536
  if debug:
537
+ toks = self.lyric_tokenizer.batch_decode(
538
+ [[tok_id] for tok_id in token_idx]
539
+ )
540
  logger.info(f"debbug {line} --> {lang} --> {toks}")
541
  lyric_token_idx = lyric_token_idx + token_idx + [2]
542
  except Exception as e:
 
565
  attention_mask=None,
566
  momentum_buffer=None,
567
  momentum_buffer_tar=None,
568
+ return_src_pred=True,
569
  ):
570
  noise_pred_src = None
571
  if return_src_pred:
572
+ src_latent_model_input = (
573
+ torch.cat([zt_src, zt_src]) if do_classifier_free_guidance else zt_src
574
+ )
575
  timestep = t.expand(src_latent_model_input.shape[0])
576
  # source
577
  noise_pred_src = self.ace_step_transformer(
 
586
  ).sample
587
 
588
  if do_classifier_free_guidance:
589
+ noise_pred_with_cond_src, noise_pred_uncond_src = noise_pred_src.chunk(
590
+ 2
591
+ )
592
  if cfg_type == "apg":
593
  noise_pred_src = apg_forward(
594
  pred_cond=noise_pred_with_cond_src,
 
603
  cfg_strength=guidance_scale,
604
  )
605
 
606
+ tar_latent_model_input = (
607
+ torch.cat([zt_tar, zt_tar]) if do_classifier_free_guidance else zt_tar
608
+ )
609
  timestep = t.expand(tar_latent_model_input.shape[0])
610
  # target
611
  noise_pred_tar = self.ace_step_transformer(
 
675
  T_steps = infer_steps
676
  frame_length = src_latents.shape[-1]
677
  attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
678
+
679
+ timesteps, T_steps = retrieve_timesteps(
680
+ scheduler, T_steps, device, timesteps=None
681
+ )
682
 
683
  if do_classifier_free_guidance:
684
  attention_mask = torch.cat([attention_mask] * 2, dim=0)
685
+
686
+ encoder_text_hidden_states = torch.cat(
687
+ [
688
+ encoder_text_hidden_states,
689
+ torch.zeros_like(encoder_text_hidden_states),
690
+ ],
691
+ 0,
692
+ )
693
  text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0)
694
 
695
+ target_encoder_text_hidden_states = torch.cat(
696
+ [
697
+ target_encoder_text_hidden_states,
698
+ torch.zeros_like(target_encoder_text_hidden_states),
699
+ ],
700
+ 0,
701
+ )
702
+ target_text_attention_mask = torch.cat(
703
+ [target_text_attention_mask] * 2, dim=0
704
+ )
705
 
706
+ speaker_embds = torch.cat(
707
+ [speaker_embds, torch.zeros_like(speaker_embds)], 0
708
+ )
709
+ target_speaker_embeds = torch.cat(
710
+ [target_speaker_embeds, torch.zeros_like(target_speaker_embeds)], 0
711
+ )
712
 
713
+ lyric_token_ids = torch.cat(
714
+ [lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0
715
+ )
716
  lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0)
717
 
718
+ target_lyric_token_ids = torch.cat(
719
+ [target_lyric_token_ids, torch.zeros_like(target_lyric_token_ids)], 0
720
+ )
721
+ target_lyric_mask = torch.cat(
722
+ [target_lyric_mask, torch.zeros_like(target_lyric_mask)], 0
723
+ )
724
 
725
  momentum_buffer = MomentumBuffer()
726
  momentum_buffer_tar = MomentumBuffer()
 
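A sketch of the batching convention used in this hunk: conditional inputs are stacked on top of zeroed "null" copies so one forward pass yields both the conditional and unconditional predictions (shapes are illustrative):

import torch

text_hidden = torch.randn(2, 77, 768)    # (batch, tokens, dim), illustrative
speaker = torch.randn(2, 512)

text_hidden_cfg = torch.cat([text_hidden, torch.zeros_like(text_hidden)], dim=0)
speaker_cfg = torch.cat([speaker, torch.zeros_like(speaker)], dim=0)
# after the model call, the two halves are split again with .chunk(2)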
737
  if i < n_min:
738
  continue
739
 
740
+ t_i = t / 1000
741
 
742
+ if i + 1 < len(timesteps):
743
+ t_im1 = (timesteps[i + 1]) / 1000
744
  else:
745
  t_im1 = torch.zeros_like(t_i).to(t_i.device)
746
 
 
748
  # Calculate the average of the V predictions
749
  V_delta_avg = torch.zeros_like(x_src)
750
  for k in range(n_avg):
751
+ fwd_noise = randn_tensor(
752
+ shape=x_src.shape,
753
+ generator=random_generators,
754
+ device=device,
755
+ dtype=dtype,
756
+ )
757
 
758
  zt_src = (1 - t_i) * x_src + (t_i) * fwd_noise
759
 
 
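The line above is the usual flow-matching interpolation between data and noise; in isolation:

import torch

x_src = torch.randn(1, 8, 16, 64)    # clean source latent, illustrative shape
noise = torch.randn_like(x_src)
t = 0.7                               # timestep already rescaled to [0, 1]
zt = (1 - t) * x_src + t * noise      # t = 0 gives the clean latent, t = 1 pure noise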
777
  guidance_scale=guidance_scale,
778
  target_guidance_scale=target_guidance_scale,
779
  attention_mask=attention_mask,
780
+ momentum_buffer=momentum_buffer,
781
  )
782
+ V_delta_avg += (1 / n_avg) * (
783
+ Vt_tar - Vt_src
784
+ ) # - (hfg-1)*( x_src))
785
 
786
  # propagate direct ODE
787
  zt_edit = zt_edit.to(torch.float32)
788
  zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg
789
  zt_edit = zt_edit.to(V_delta_avg.dtype)
790
+ else: # i >= T_steps-n_min # regular sampling for last n_min steps
791
  if i == n_max:
792
+ fwd_noise = randn_tensor(
793
+ shape=x_src.shape,
794
+ generator=random_generators,
795
+ device=device,
796
+ dtype=dtype,
797
+ )
798
  scheduler._init_step_index(t)
799
  sigma = scheduler.sigmas[scheduler.step_index]
800
  xt_src = sigma * fwd_noise + (1.0 - sigma) * x_src
801
  xt_tar = zt_edit + xt_src - x_src
802
+
803
  _, Vt_tar = self.calc_v(
804
  zt_src=None,
805
  zt_tar=xt_tar,
 
821
  momentum_buffer_tar=momentum_buffer_tar,
822
  return_src_pred=False,
823
  )
824
+
825
  dtype = Vt_tar.dtype
826
  xt_tar = xt_tar.to(torch.float32)
827
  prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar
828
+ prev_sample = prev_sample.to(dtype)
829
  xt_tar = prev_sample
830
+
831
  target_latents = zt_edit if xt_tar is None else xt_tar
832
  return target_latents
833
 
 
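Both the edit branch and the plain-sampling branch above advance the latent with the same first-order (Euler) update on the predicted velocity; a minimal sketch:

import torch

def euler_step(x_t, v_pred, t_i, t_im1):
    # t_i and t_im1 are consecutive timesteps rescaled to [0, 1]; the difference is
    # negative, so each step moves from noise toward data
    return x_t + (t_im1 - t_i) * v_pred

x = torch.randn(1, 8, 16, 64)
v = torch.randn_like(x)
x_next = euler_step(x, v, t_i=0.5, t_im1=0.45)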
845
  timesteps = scheduler.timesteps.unsqueeze(1).to(gt_latents.dtype)
846
  indices = indices.to(timesteps.device).to(gt_latents.dtype).unsqueeze(1)
847
  nearest_idx = torch.argmin(torch.cdist(indices, timesteps), dim=1)
848
+ sigma = (
849
+ scheduler.sigmas[nearest_idx]
850
+ .flatten()
851
+ .to(gt_latents.device)
852
+ .to(gt_latents.dtype)
853
+ )
854
  while len(sigma.shape) < gt_latents.ndim:
855
  sigma = sigma.unsqueeze(-1)
856
  noisy_image = sigma * noise + (1.0 - sigma) * gt_latents
 
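A simplified sketch of the audio2audio noising above: look up the scheduler sigma nearest to the requested noise level and blend the reference latent with noise. Plain tensors stand in for the scheduler state here:

import torch

sigmas = torch.linspace(1.0, 0.0, steps=61)            # stand-in for scheduler.sigmas
timesteps = sigmas[:-1] * 1000                          # stand-in for scheduler.timesteps

target_t = torch.tensor([700.0])                        # requested strength as a timestep
nearest = torch.argmin((timesteps - target_t).abs())
sigma = sigmas[nearest]

gt = torch.randn(1, 8, 16, 64)                          # reference latent, illustrative
noise = torch.randn_like(gt)
noisy = sigma * noise + (1.0 - sigma) * gt              # same blend as above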
894
  ref_latents=None,
895
  ):
896
 
897
+ logger.info(
898
+ "cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(
899
+ cfg_type, guidance_scale, omega_scale
900
+ )
901
+ )
902
  do_classifier_free_guidance = True
903
  if guidance_scale == 0.0 or guidance_scale == 1.0:
904
  do_classifier_free_guidance = False
905
+
906
  do_double_condition_guidance = False
907
+ if (
908
+ guidance_scale_text is not None
909
+ and guidance_scale_text > 1.0
910
+ and guidance_scale_lyric is not None
911
+ and guidance_scale_lyric > 1.0
912
+ ):
913
  do_double_condition_guidance = True
914
+ logger.info(
915
+ "do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format(
916
+ do_double_condition_guidance,
917
+ guidance_scale_text,
918
+ guidance_scale_lyric,
919
+ )
920
+ )
921
 
922
  device = encoder_text_hidden_states.device
923
  dtype = encoder_text_hidden_states.dtype
 
933
  num_train_timesteps=1000,
934
  shift=3.0,
935
  )
936
+
937
  frame_length = int(duration * 44100 / 512 / 8)
938
  if src_latents is not None:
939
  frame_length = src_latents.shape[-1]
 
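The conversion above implies a latent frame rate of 44100 / (512 * 8), roughly 10.77 frames per second; a quick check with a 160-second target:

duration = 160
frame_length = int(duration * 44100 / 512 / 8)   # 512-sample hop, 8x latent downsampling
print(frame_length)                               # 1722 latent frames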
944
  if len(oss_steps) > 0:
945
  infer_steps = max(oss_steps)
946
  # scheduler timesteps are initialized via retrieve_timesteps below
947
+ timesteps, num_inference_steps = retrieve_timesteps(
948
+ scheduler,
949
+ num_inference_steps=infer_steps,
950
+ device=device,
951
+ timesteps=None,
952
+ )
953
  new_timesteps = torch.zeros(len(oss_steps), dtype=dtype, device=device)
954
  for idx in range(len(oss_steps)):
955
+ new_timesteps[idx] = timesteps[oss_steps[idx] - 1]
956
  num_inference_steps = len(oss_steps)
957
  sigmas = (new_timesteps / 1000).float().cpu().numpy()
958
+ timesteps, num_inference_steps = retrieve_timesteps(
959
+ scheduler,
960
+ num_inference_steps=num_inference_steps,
961
+ device=device,
962
+ sigmas=sigmas,
963
+ )
964
+ logger.info(
965
+ f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}"
966
+ )
967
  else:
968
+ timesteps, num_inference_steps = retrieve_timesteps(
969
+ scheduler,
970
+ num_inference_steps=infer_steps,
971
+ device=device,
972
+ timesteps=None,
973
+ )
974
+
975
+ target_latents = randn_tensor(
976
+ shape=(bsz, 8, 16, frame_length),
977
+ generator=random_generators,
978
+ device=device,
979
+ dtype=dtype,
980
+ )
981
+
982
  is_repaint = False
983
+ is_extend = False
984
  if add_retake_noise:
985
  n_min = int(infer_steps * (1 - retake_variance))
986
+ retake_variance = (
987
+ torch.tensor(retake_variance * math.pi / 2).to(device).to(dtype)
988
+ )
989
+ retake_latents = randn_tensor(
990
+ shape=(bsz, 8, 16, frame_length),
991
+ generator=retake_random_generators,
992
+ device=device,
993
+ dtype=dtype,
994
+ )
995
  repaint_start_frame = int(repaint_start * 44100 / 512 / 8)
996
  repaint_end_frame = int(repaint_end * 44100 / 512 / 8)
997
  x0 = src_latents
998
  # retake
999
+ is_repaint = repaint_end_frame - repaint_start_frame != frame_length
1000
+
1001
  is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length)
1002
  if is_extend:
1003
  is_repaint = True
 
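The repaint/extend decision in this hunk, in plain numbers (values are illustrative):

audio_duration = 160.0
frame_length = int(audio_duration * 44100 / 512 / 8)     # 1722

def to_frames(seconds):
    return int(seconds * 44100 / 512 / 8)

start_f, end_f = to_frames(30.0), to_frames(60.0)
is_repaint = (end_f - start_f) != frame_length    # True: only part of the clip is redrawn
is_extend = start_f < 0 or end_f > frame_length   # negative start or end past the clip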
1005
  # TODO: train a mask aware repainting controlnet
1006
  # to make sure mean = 0, std = 1
1007
  if not is_repaint:
1008
+ target_latents = (
1009
+ torch.cos(retake_variance) * target_latents
1010
+ + torch.sin(retake_variance) * retake_latents
1011
+ )
1012
  elif not is_extend:
1013
+ # repaint only within [repaint_start_frame, repaint_end_frame)
1014
+ repaint_mask = torch.zeros(
1015
+ (bsz, 8, 16, frame_length), device=device, dtype=dtype
1016
+ )
1017
  repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0
1018
+ repaint_noise = (
1019
+ torch.cos(retake_variance) * target_latents
1020
+ + torch.sin(retake_variance) * retake_latents
1021
+ )
1022
+ repaint_noise = torch.where(
1023
+ repaint_mask == 1.0, repaint_noise, target_latents
1024
+ )
1025
  zt_edit = x0.clone()
1026
  z0 = repaint_noise
1027
  elif is_extend:
 
1037
  if repaint_start_frame < 0:
1038
  left_pad_frame_length = abs(repaint_start_frame)
1039
  frame_length = left_pad_frame_length + gt_latents.shape[-1]
1040
+ extend_gt_latents = torch.nn.functional.pad(
1041
+ gt_latents, (left_pad_frame_length, 0), "constant", 0
1042
+ )
1043
  if frame_length > max_infer_fame_length:
1044
  right_trim_length = frame_length - max_infer_fame_length
1045
+ extend_gt_latents = extend_gt_latents[
1046
+ :, :, :, :max_infer_fame_length
1047
+ ]
1048
+ to_right_pad_gt_latents = extend_gt_latents[
1049
+ :, :, :, -right_trim_length:
1050
+ ]
1051
  frame_length = max_infer_fame_length
1052
  repaint_start_frame = 0
1053
  gt_latents = extend_gt_latents
1054
+
1055
  if repaint_end_frame > src_latents_length:
1056
  right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1]
1057
  frame_length = gt_latents.shape[-1] + right_pad_frame_length
1058
+ extend_gt_latents = torch.nn.functional.pad(
1059
+ gt_latents, (0, right_pad_frame_length), "constant", 0
1060
+ )
1061
  if frame_length > max_infer_fame_length:
1062
  left_trim_length = frame_length - max_infer_fame_length
1063
+ extend_gt_latents = extend_gt_latents[
1064
+ :, :, :, -max_infer_fame_length:
1065
+ ]
1066
+ to_left_pad_gt_latents = extend_gt_latents[
1067
+ :, :, :, :left_trim_length
1068
+ ]
1069
  frame_length = max_infer_fame_length
1070
  repaint_end_frame = frame_length
1071
  gt_latents = extend_gt_latents
1072
 
1073
+ repaint_mask = torch.zeros(
1074
+ (bsz, 8, 16, frame_length), device=device, dtype=dtype
1075
+ )
1076
  if left_pad_frame_length > 0:
1077
+ repaint_mask[:, :, :, :left_pad_frame_length] = 1.0
1078
  if right_pad_frame_length > 0:
1079
+ repaint_mask[:, :, :, -right_pad_frame_length:] = 1.0
1080
  x0 = gt_latents
1081
  padd_list = []
1082
  if left_pad_frame_length > 0:
1083
  padd_list.append(retake_latents[:, :, :, :left_pad_frame_length])
1084
+ padd_list.append(
1085
+ target_latents[
1086
+ :,
1087
+ :,
1088
+ :,
1089
+ left_trim_length : target_latents.shape[-1] - right_trim_length,
1090
+ ]
1091
+ )
1092
  if right_pad_frame_length > 0:
1093
  padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:])
1094
  target_latents = torch.cat(padd_list, dim=-1)
1095
+ assert (
1096
+ target_latents.shape[-1] == x0.shape[-1]
1097
+ ), f"{target_latents.shape=} {x0.shape=}"
1098
  zt_edit = x0.clone()
1099
  z0 = target_latents
1100
 
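A reduced sketch of the extend path: the source latent is zero-padded to the new length, the padded region is marked in the repaint mask, and the matching slice of the initial noise comes from the retake generator (shapes and the 100-frame pad are illustrative):

import torch
import torch.nn.functional as F

gt_latents = torch.randn(1, 8, 16, 1722)            # existing song latent
left_pad = 100                                       # extend 100 latent frames to the left

extended_gt = F.pad(gt_latents, (left_pad, 0), "constant", 0)
repaint_mask = torch.zeros_like(extended_gt)
repaint_mask[:, :, :, :left_pad] = 1.0               # only the new region gets generated

retake_noise = torch.randn(1, 8, 16, left_pad)
base_noise = torch.randn_like(gt_latents)
init_noise = torch.cat([retake_noise, base_noise], dim=-1)   # same length as extended_gt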
1101
  init_timestep = 1000
1102
  if audio2audio_enable and ref_latents is not None:
1103
+ target_latents, init_timestep = self.add_latents_noise(
1104
+ gt_latents=ref_latents,
1105
+ variance=(1 - ref_audio_strength),
1106
+ noise=target_latents,
1107
+ scheduler=scheduler,
1108
+ )
1109
 
1110
  attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
1111
+
1112
  # guidance interval
1113
  start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
1114
  end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
1115
+ logger.info(
1116
+ f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}"
1117
+ )
1118
 
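With, for instance, 60 inference steps and a guidance_interval of 0.5, the window works out as follows; classifier-free guidance is applied only between start_idx and end_idx:

num_inference_steps = 60
guidance_interval = 0.5
start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))   # 15
end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))     # 45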
1119
  momentum_buffer = MomentumBuffer()
1120
 
1121
  def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6):
1122
  handlers = []
1123
+
1124
  def hook(module, input, output):
1125
  output[:] *= tau
1126
  return output
1127
+
1128
  for i in range(l_min, l_max):
1129
+ handler = self.ace_step_transformer.lyric_encoder.encoders[
1130
+ i
1131
+ ].self_attn.linear_q.register_forward_hook(hook)
1132
  handlers.append(handler)
1133
+
1134
+ encoder_hidden_states, encoder_hidden_mask = (
1135
+ self.ace_step_transformer.encode(**inputs)
1136
+ )
1137
+
1138
  for hook in handlers:
1139
  hook.remove()
1140
+
1141
  return encoder_hidden_states
1142
 
1143
  # P(speaker, text, lyric)
 
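The ERG trick used by forward_encoder_with_temperature above (damping attention query projections through temporary forward hooks) in a minimal generic form; the toy linear layer stands in for a real query projection:

import torch
import torch.nn as nn

layer = nn.Linear(16, 16)                    # stand-in for a self-attention linear_q

def scale_hook(module, inputs, output, tau=0.01):
    return output * tau                      # returning a tensor overrides the layer output

handle = layer.register_forward_hook(lambda m, i, o: scale_hook(m, i, o, tau=0.01))
out = layer(torch.randn(2, 16))              # queries damped by tau while the hook is attached
handle.remove()                              # detach afterwards, as the pipeline does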
1154
  encoder_hidden_states_null = forward_encoder_with_temperature(
1155
  self,
1156
  inputs={
1157
+ "encoder_text_hidden_states": (
1158
+ encoder_text_hidden_states_null
1159
+ if encoder_text_hidden_states_null is not None
1160
+ else torch.zeros_like(encoder_text_hidden_states)
1161
+ ),
1162
  "text_attention_mask": text_attention_mask,
1163
  "speaker_embeds": torch.zeros_like(speaker_embds),
1164
  "lyric_token_idx": lyric_token_ids,
1165
  "lyric_mask": lyric_mask,
1166
+ },
1167
  )
1168
  else:
1169
  # P(null_speaker, null_text, null_lyric)
 
1174
  torch.zeros_like(lyric_token_ids),
1175
  lyric_mask,
1176
  )
1177
+
1178
  encoder_hidden_states_no_lyric = None
1179
  if do_double_condition_guidance:
1180
  # P(null_speaker, text, lyric_weaker)
 
1187
  "speaker_embeds": torch.zeros_like(speaker_embds),
1188
  "lyric_token_idx": lyric_token_ids,
1189
  "lyric_mask": lyric_mask,
1190
+ },
1191
  )
1192
  # P(null_speaker, text, no_lyric)
1193
  else:
 
1199
  lyric_mask,
1200
  )
1201
 
1202
+ def forward_diffusion_with_temperature(
1203
+ self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20
1204
+ ):
1205
  handlers = []
1206
+
1207
  def hook(module, input, output):
1208
  output[:] *= tau
1209
  return output
1210
+
1211
  for i in range(l_min, l_max):
1212
+ handler = self.ace_step_transformer.transformer_blocks[
1213
+ i
1214
+ ].attn.to_q.register_forward_hook(hook)
1215
  handlers.append(handler)
1216
+ handler = self.ace_step_transformer.transformer_blocks[
1217
+ i
1218
+ ].cross_attn.to_q.register_forward_hook(hook)
1219
  handlers.append(handler)
1220
 
1221
+ sample = self.ace_step_transformer.decode(
1222
+ hidden_states=hidden_states, timestep=timestep, **inputs
1223
+ ).sample
1224
+
1225
  for hook in handlers:
1226
  hook.remove()
1227
+
1228
  return sample
1229
+
1230
  for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
1231
 
1232
  if t > init_timestep:
 
1249
  # compute current guidance scale
1250
  if guidance_interval_decay > 0:
1251
  # Linearly interpolate to calculate the current guidance scale
1252
+ progress = (i - start_idx) / (
1253
+ end_idx - start_idx - 1
1254
+ ) # normalize to [0, 1]
1255
+ current_guidance_scale = (
1256
+ guidance_scale
1257
+ - (guidance_scale - min_guidance_scale)
1258
+ * progress
1259
+ * guidance_interval_decay
1260
+ )
1261
  else:
1262
  current_guidance_scale = guidance_scale
1263
 
 
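A worked example of the linear decay above, assuming guidance_scale 15, min_guidance_scale 3, a full-strength decay of 1.0 and the 15..45 window: the scale falls from 15.0 at the first guided step to 3.0 at the last.

guidance_scale, min_guidance_scale = 15.0, 3.0
guidance_interval_decay = 1.0
start_idx, end_idx = 15, 45

for i in (15, 30, 44):
    progress = (i - start_idx) / (end_idx - start_idx - 1)
    current = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay
    print(i, round(current, 2))   # 15 -> 15.0, 30 -> 8.79, 44 -> 3.0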
1275
  ).sample
1276
 
1277
  noise_pred_with_only_text_cond = None
1278
+ if (
1279
+ do_double_condition_guidance
1280
+ and encoder_hidden_states_no_lyric is not None
1281
+ ):
1282
  noise_pred_with_only_text_cond = self.ace_step_transformer.decode(
1283
  hidden_states=latent_model_input,
1284
  attention_mask=attention_mask,
 
1310
  timestep=timestep,
1311
  ).sample
1312
 
1313
+ if (
1314
+ do_double_condition_guidance
1315
+ and noise_pred_with_only_text_cond is not None
1316
+ ):
1317
  noise_pred = cfg_double_condition_forward(
1318
  cond_output=noise_pred_with_cond,
1319
  uncond_output=noise_pred_uncond,
 
1342
  guidance_scale=current_guidance_scale,
1343
  i=i,
1344
  zero_steps=zero_steps,
1345
+ use_zero_init=use_zero_init,
1346
  )
1347
  else:
1348
  latent_model_input = latents
 
1357
  ).sample
1358
 
1359
  if is_repaint and i >= n_min:
1360
+ t_i = t / 1000
1361
+ if i + 1 < len(timesteps):
1362
+ t_im1 = (timesteps[i + 1]) / 1000
1363
  else:
1364
  t_im1 = torch.zeros_like(t_i).to(t_i.device)
1365
  dtype = noise_pred.dtype
 
1368
  prev_sample = prev_sample.to(dtype)
1369
  target_latents = prev_sample
1370
  zt_src = (1 - t_im1) * x0 + (t_im1) * z0
1371
+ target_latents = torch.where(
1372
+ repaint_mask == 1.0, target_latents, zt_src
1373
+ )
1374
  else:
1375
+ target_latents = scheduler.step(
1376
+ model_output=noise_pred,
1377
+ timestep=t,
1378
+ sample=target_latents,
1379
+ return_dict=False,
1380
+ omega=omega_scale,
1381
+ )[0]
1382
 
1383
  if is_extend:
1384
  if to_right_pad_gt_latents is not None:
1385
+ target_latents = torch.cat(
1386
+ [target_latents, to_right_pad_gt_latents], dim=-1
1387
+ )
1388
  if to_left_pad_gt_latents is not None:
1389
+ target_latents = torch.cat(
1390
+ [to_left_pad_gt_latents, target_latents], dim=-1
1391
+ )
1392
  return target_latents
1393
 
1394
+ def latents2audio(
1395
+ self,
1396
+ latents,
1397
+ target_wav_duration_second=30,
1398
+ sample_rate=48000,
1399
+ save_path=None,
1400
+ format="mp3",
1401
+ ):
1402
  output_audio_paths = []
1403
  bs = latents.shape[0]
1404
  audio_lengths = [target_wav_duration_second * sample_rate] * bs
 
1407
  _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
1408
  pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
1409
  for i in tqdm(range(bs)):
1410
+ output_audio_path = self.save_wav_file(
1411
+ pred_wavs[i], i, sample_rate=sample_rate
1412
+ )
1413
  output_audio_paths.append(output_audio_path)
1414
  return output_audio_paths
1415
 
1416
+ def save_wav_file(
1417
+ self, target_wav, idx, save_path=None, sample_rate=48000, format="mp3"
1418
+ ):
1419
  if save_path is None:
1420
  logger.warning("save_path is None, using default path ./outputs/")
1421
  base_path = f"./outputs"
 
1424
  base_path = save_path
1425
  ensure_directory_exists(base_path)
1426
 
1427
+ output_path_flac = (
1428
+ f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.{format}"
1429
+ )
1430
  target_wav = target_wav.float()
1431
+ torchaudio.save(
1432
+ output_path_flac,
1433
+ target_wav,
1434
+ sample_rate=sample_rate,
1435
+ format=format,
1436
+ compression=torio.io.CodecConfig(bit_rate=320000),
1437
+ )
1438
  return output_path_flac
1439
 
1440
  def infer_latents(self, input_audio_path):
 
1460
  omega_scale: int = 10.0,
1461
  manual_seeds: list = None,
1462
  guidance_interval: float = 0.5,
1463
+ guidance_interval_decay: float = 0.0,
1464
  min_guidance_scale: float = 3.0,
1465
  use_erg_tag: bool = True,
1466
  use_erg_lyric: bool = True,
 
1503
  start_time = time.time()
1504
 
1505
  random_generators, actual_seeds = self.set_seeds(batch_size, manual_seeds)
1506
+ retake_random_generators, actual_retake_seeds = self.set_seeds(
1507
+ batch_size, retake_seeds
1508
+ )
1509
 
1510
  if isinstance(oss_steps, str) and len(oss_steps) > 0:
1511
  oss_steps = list(map(int, oss_steps.split(",")))
1512
  else:
1513
  oss_steps = []
1514
+
1515
  texts = [prompt]
1516
+ encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(
1517
+ texts, self.device
1518
+ )
1519
  encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)
1520
  text_attention_mask = text_attention_mask.repeat(batch_size, 1)
1521
 
1522
  encoder_text_hidden_states_null = None
1523
  if use_erg_tag:
1524
+ encoder_text_hidden_states_null = self.get_text_embeddings_null(
1525
+ texts, self.device
1526
+ )
1527
+ encoder_text_hidden_states_null = encoder_text_hidden_states_null.repeat(
1528
+ batch_size, 1, 1
1529
+ )
1530
 
1531
  # not support for released checkpoint
1532
  speaker_embeds = torch.zeros(batch_size, 512).to(self.device).to(self.dtype)
 
1537
  if len(lyrics) > 0:
1538
  lyric_token_idx = self.tokenize_lyrics(lyrics, debug=debug)
1539
  lyric_mask = [1] * len(lyric_token_idx)
1540
+ lyric_token_idx = (
1541
+ torch.tensor(lyric_token_idx)
1542
+ .unsqueeze(0)
1543
+ .to(self.device)
1544
+ .repeat(batch_size, 1)
1545
+ )
1546
+ lyric_mask = (
1547
+ torch.tensor(lyric_mask)
1548
+ .unsqueeze(0)
1549
+ .to(self.device)
1550
+ .repeat(batch_size, 1)
1551
+ )
1552
 
1553
  if audio_duration <= 0:
1554
  audio_duration = random.uniform(30.0, 240.0)
 
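How the tokenized lyrics become batched tensors in the hunk above, with made-up token ids (real ids come from the lyric tokenizer):

import torch

batch_size = 2
lyric_token_idx = [261, 4, 88, 2]      # illustrative ids; 2 is the per-line terminator used above
lyric_mask = [1] * len(lyric_token_idx)

lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).repeat(batch_size, 1)
lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).repeat(batch_size, 1)
# both tensors end up with shape (batch_size, num_tokens)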
1563
  if task == "retake":
1564
  repaint_start = 0
1565
  repaint_end = audio_duration
1566
+
1567
  src_latents = None
1568
  if src_audio_path is not None:
1569
+ assert src_audio_path is not None and task in (
1570
+ "repaint",
1571
+ "edit",
1572
+ "extend",
1573
+ ), "src_audio_path is required for repaint/edit/extend task"
1574
+ assert os.path.exists(
1575
+ src_audio_path
1576
+ ), f"src_audio_path {src_audio_path} does not exist"
1577
  src_latents = self.infer_latents(src_audio_path)
1578
 
1579
  ref_latents = None
1580
  if ref_audio_input is not None and audio2audio_enable:
1581
+ assert (
1582
+ ref_audio_input is not None
1583
+ ), "ref_audio_input is required for audio2audio task"
1584
  assert os.path.exists(
1585
  ref_audio_input
1586
  ), f"ref_audio_input {ref_audio_input} does not exist"
 
1588
 
1589
  if task == "edit":
1590
  texts = [edit_target_prompt]
1591
+ target_encoder_text_hidden_states, target_text_attention_mask = (
1592
+ self.get_text_embeddings(texts, self.device)
1593
+ )
1594
+ target_encoder_text_hidden_states = (
1595
+ target_encoder_text_hidden_states.repeat(batch_size, 1, 1)
1596
+ )
1597
+ target_text_attention_mask = target_text_attention_mask.repeat(
1598
+ batch_size, 1
1599
+ )
1600
 
1601
+ target_lyric_token_idx = (
1602
+ torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
1603
+ )
1604
+ target_lyric_mask = (
1605
+ torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
1606
+ )
1607
  if len(edit_target_lyrics) > 0:
1608
+ target_lyric_token_idx = self.tokenize_lyrics(
1609
+ edit_target_lyrics, debug=True
1610
+ )
1611
  target_lyric_mask = [1] * len(target_lyric_token_idx)
1612
+ target_lyric_token_idx = (
1613
+ torch.tensor(target_lyric_token_idx)
1614
+ .unsqueeze(0)
1615
+ .to(self.device)
1616
+ .repeat(batch_size, 1)
1617
+ )
1618
+ target_lyric_mask = (
1619
+ torch.tensor(target_lyric_mask)
1620
+ .unsqueeze(0)
1621
+ .to(self.device)
1622
+ .repeat(batch_size, 1)
1623
+ )
1624
 
1625
  target_speaker_embeds = speaker_embeds.clone()
1626
 
 
1636
  target_lyric_token_ids=target_lyric_token_idx,
1637
  target_lyric_mask=target_lyric_mask,
1638
  src_latents=src_latents,
1639
+ random_generators=retake_random_generators, # more diversity
1640
  infer_steps=infer_step,
1641
  guidance_scale=guidance_scale,
1642
  n_min=edit_n_min,
 
1724
  "repaint_end": repaint_end,
1725
  "edit_n_min": edit_n_min,
1726
  "edit_n_max": edit_n_max,
1727
+ "edit_n_avg": edit_n_avg,
1728
  "src_audio_path": src_audio_path,
1729
  "edit_target_prompt": edit_target_prompt,
1730
  "edit_target_lyrics": edit_target_lyrics,
 
1734
  }
1735
  # save input_params_json
1736
  for output_audio_path in output_paths:
1737
+ input_params_json_save_path = output_audio_path.replace(
1738
+ f".{format}", "_input_params.json"
1739
+ )
1740
  input_params_json["audio_path"] = output_audio_path
1741
  with open(input_params_json_save_path, "w", encoding="utf-8") as f:
1742
  json.dump(input_params_json, f, indent=4, ensure_ascii=False)
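The loop above is what writes the *_input_params.json sidecars next to each rendered file; a minimal sketch of the same convention with an invented path and a trimmed parameter dict:

import json

output_audio_path = "./outputs/output_20250101120000_0.mp3"   # illustrative path
fmt = "mp3"

sidecar_path = output_audio_path.replace(f".{fmt}", "_input_params.json")
params = {"prompt": "example tags", "infer_step": 60, "audio_path": output_audio_path}

with open(sidecar_path, "w", encoding="utf-8") as f:
    json.dump(params, f, indent=4, ensure_ascii=False)   # ensure_ascii=False keeps non-ASCII lyrics readable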