Spaces:
Running
on
Zero
Running
on
Zero
hkchen
committed on
Commit
·
34ded66
1
Parent(s):
7d15a50
update v1.2 full web code
Browse files
- app.py +54 -43
- diffrhythm/infer/infer.py +4 -1
- diffrhythm/infer/infer_utils.py +20 -9
- diffrhythm/model/cfm.py +4 -2
- diffrhythm/model/dit.py +4 -1
app.py
CHANGED
@@ -27,19 +27,22 @@ from diffrhythm.infer.infer import inference
|
|
27 |
|
28 |
MAX_SEED = np.iinfo(np.int32).max
|
29 |
device='cuda'
|
30 |
-
cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(device)
|
31 |
cfm = torch.compile(cfm)
|
32 |
|
33 |
@spaces.GPU
|
34 |
-
def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42,
|
35 |
-
|
|
|
|
|
|
|
36 |
sway_sampling_coef = -1 if steps < 32 else None
|
37 |
if randomize_seed:
|
38 |
seed = random.randint(0, MAX_SEED)
|
39 |
torch.manual_seed(seed)
|
40 |
vocal_flag = False
|
41 |
try:
|
42 |
-
lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
|
43 |
if current_prompt_type == 'audio':
|
44 |
style_prompt, vocal_flag = get_audio_style_prompt(muq, ref_audio_path)
|
45 |
else:
|
@@ -59,7 +62,7 @@ def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42,
|
|
59 |
eval_muq=eval_muq,
|
60 |
cond=latent_prompt,
|
61 |
text=lrc_prompt,
|
62 |
-
duration=
|
63 |
style_prompt=style_prompt,
|
64 |
negative_style_prompt=negative_style_prompt,
|
65 |
steps=steps,
|
@@ -71,6 +74,7 @@ def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42,
|
|
71 |
odeint_method=odeint_method,
|
72 |
pred_frames=pred_frames,
|
73 |
batch_infer_num=batch_infer_num,
|
|
|
74 |
)
|
75 |
return generated_song
|
76 |
|
@@ -194,7 +198,7 @@ with gr.Blocks(css=css) as demo:
|
|
194 |
lines=12,
|
195 |
max_lines=50,
|
196 |
elem_classes="lyrics-scroll-box",
|
197 |
-
value="""[00:04.
|
198 |
)
|
199 |
|
200 |
current_prompt_type = gr.State(value="audio")
|
@@ -217,35 +221,39 @@ with gr.Blocks(css=css) as demo:
|
|
217 |
with gr.Column():
|
218 |
with gr.Accordion("Best Practices Guide", open=True):
|
219 |
gr.Markdown("""
|
220 |
-
1. **Lyrics Format Requirements**
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
2. **
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
3. **
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
6. **Others**
|
245 |
-
- If loading audio result is slow, you can select Output Format as mp3 in Advanced Settings.
|
246 |
-
|
247 |
""")
|
248 |
# Music_Duration = gr.Radio(["95s", "285s"], label="Music Duration", value="95s")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
preference_infer = gr.Radio(["quality first", "speed first"], label="Preference", value="quality first")
|
250 |
lyrics_btn = gr.Button("Generate", variant="primary")
|
251 |
audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
|
@@ -277,11 +285,13 @@ with gr.Blocks(css=css) as demo:
|
|
277 |
interactive=True,
|
278 |
elem_id="step_slider"
|
279 |
)
|
280 |
-
edit = gr.Checkbox(label="edit", value=False)
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
|
|
|
|
285 |
|
286 |
odeint_method = gr.Radio(["euler", "midpoint", "rk4","implicit_adams"], label="ODE Solver", value="euler")
|
287 |
file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="mp3")
|
@@ -324,14 +334,15 @@ with gr.Blocks(css=css) as demo:
|
|
324 |
|
325 |
gr.Examples(
|
326 |
examples=[
|
327 |
-
["""[00:04.
|
328 |
-
["""[00:00.
|
329 |
-
["""[00:00.27]只因你太美 baby 只因你太美 baby\n[00:08.95]只因你实在是太美 baby\n[00:13.99]只因你太美 baby\n[00:18.89]迎面走来的你让我如此蠢蠢欲动\n[00:20.88]这种感觉我从未有\n[00:21.79]Cause I got a crush on you who you\n[00:25.74]你是我的我是你的谁\n[00:28.09]再多一眼看一眼就会爆炸\n[00:30.31]再近一点靠近点快被融化\n[00:32.49]想要把你占为己有 baby bae\n[00:34.60]不管走到哪里\n[00:35.44]都会想起的人是你 you you\n[00:38.12]我应该拿你怎样\n[00:39.61]Uh 所有人都在看着你\n[00:42.36]我的心总是不安\n[00:44.18]Oh 我现在已病入膏肓\n[00:46.63]Eh oh\n[00:47.84]难道真的因你而疯狂吗\n[00:51.57]我本来不是这种人\n[00:53.59]因你变成奇怪的人\n[00:55.77]第一次呀变成这样的我\n[01:01.23]不管我怎么去否认\n[01:03.21]只因你太美 baby 只因你太美 baby\n[01:11.46]只因你实在是太美 baby\n[01:16.75]只因你太美 baby\n[01:21.09]Oh eh oh\n[01:22.82]现在确认地告诉我\n[01:25.26]Oh eh oh\n[01:27.31]你到底属于谁\n[01:29.98]Oh eh oh\n[01:31.70]现在确认地告诉我\n[01:34.45]Oh eh oh\n"""]
|
|
|
330 |
],
|
331 |
|
332 |
inputs=[lrc],
|
333 |
label="Lrc Examples",
|
334 |
-
examples_per_page=
|
335 |
elem_id="lrc-examples-container",
|
336 |
)
|
337 |
|
@@ -426,7 +437,7 @@ with gr.Blocks(css=css) as demo:
|
|
426 |
|
427 |
lyrics_btn.click(
|
428 |
fn=infer_music,
|
429 |
-
inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, cfg_strength, file_type, odeint_method, preference_infer,
|
430 |
outputs=audio_output
|
431 |
)
|
432 |
|
|
|
27 |
|
28 |
MAX_SEED = np.iinfo(np.int32).max
|
29 |
device='cuda'
|
30 |
+
cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(max_frames=6144, device=device)
|
31 |
cfm = torch.compile(cfm)
|
32 |
|
33 |
@spaces.GPU
|
34 |
+
def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42,
|
35 |
+
randomize_seed=False, steps=32, cfg_strength=4.0, file_type='wav',
|
36 |
+
odeint_method='euler', preference_infer="quality first", Music_Duration=285, edit=False,
|
37 |
+
edit_segments=None, device='cuda'):
|
38 |
+
max_frames = 2048 if Music_Duration == 95 else 6144
|
39 |
sway_sampling_coef = -1 if steps < 32 else None
|
40 |
if randomize_seed:
|
41 |
seed = random.randint(0, MAX_SEED)
|
42 |
torch.manual_seed(seed)
|
43 |
vocal_flag = False
|
44 |
try:
|
45 |
+
lrc_prompt, start_time, end_frame, song_duration = get_lrc_token(max_frames, lrc, tokenizer, Music_Duration, device)
|
46 |
if current_prompt_type == 'audio':
|
47 |
style_prompt, vocal_flag = get_audio_style_prompt(muq, ref_audio_path)
|
48 |
else:
|
|
|
62 |
eval_muq=eval_muq,
|
63 |
cond=latent_prompt,
|
64 |
text=lrc_prompt,
|
65 |
+
duration=end_frame,
|
66 |
style_prompt=style_prompt,
|
67 |
negative_style_prompt=negative_style_prompt,
|
68 |
steps=steps,
|
|
|
74 |
odeint_method=odeint_method,
|
75 |
pred_frames=pred_frames,
|
76 |
batch_infer_num=batch_infer_num,
|
77 |
+
song_duration=song_duration
|
78 |
)
|
79 |
return generated_song
|
80 |
|
|
|
198 |
lines=12,
|
199 |
max_lines=50,
|
200 |
elem_classes="lyrics-scroll-box",
|
201 |
+
value="""[00:04.074] Tell me that I'm special\n[00:06.226] Tell me I look pretty\n[00:08.175] Tell me I'm a little angel\n[00:10.175] Sweetheart of your city\n[00:13.307] Say what I'm dying to hear\n[00:17.058] 'Cause I'm dying to hear you\n[00:20.523] Tell me I'm that new thing\n[00:22.571] Tell me that I'm relevant\n[00:24.723] Tell me that I got a big heart\n[00:26.723] Then back it up with evidence\n[00:29.408] I need it and I don't know why\n[00:33.907] This late at night\n[00:36.139] Isn't it lonely\n[00:38.991] I'd do anything to make you want me\n[00:43.222] I'd give it all up if you told me\n[00:47.339] That I'd be\n[00:49.157] The number one girl in your eyes\n[00:52.506] Your one and only\n[00:55.437] So what's it gon' take for you to want me\n[00:59.589] I'd give it all up if you told me\n[01:03.674] That I'd be\n[01:05.823] The number one girl in your eyes\n[01:10.841] Tell me I'm going real big places\n[01:14.055] Down to earth, so friendly\n[01:16.105] And even through all the phases\n[01:18.256] Tell me you accept me\n[01:21.155] Well, that's all I'm dying to hear\n[01:24.937] Yeah, I'm dying to hear you\n[01:28.521] Tell me that you need me\n[01:30.437] Tell me that I'm loved\n[01:32.740] Tell me that I'm worth it\n[01:34.605] And that I'm enough\n[01:37.571] I need it and I don't know why\n[01:41.889] This late at night\n[01:43.805] Isn't it lonely\n[01:46.871] I'd do anything to make you want me\n[01:51.004] I'd give it all up if you told me\n[01:55.237] That I'd be\n[01:57.089] The number one girl in your eyes\n[02:00.325] Your one and only\n[02:03.305] So what's it gon' take for you to want me\n[02:07.355] I'd give it all up if you told me\n[02:11.589] That I'd be\n[02:13.623] The number one girl in your eyes\n[02:16.804] The girl in your eyes\n[02:20.823] The girl in your eyes\n[02:26.055] Tell me I'm the number one girl\n[02:28.355] I'm the number one girl in your eyes\n[02:33.172] The girl in your eyes\n[02:37.321] The girl in your 
eyes\n[02:42.472] Tell me I'm the number one girl\n[02:44.539] I'm the number one girl in your eyes\n[02:49.605] Well isn't it lonely\n[02:52.488] I'd do anything to make you want me\n[02:56.637] I'd give it all up if you told me\n[03:00.888] That I'd be\n[03:03.172] The number one girl in your eyes\n[03:06.272] Your one and only\n[03:09.160] So what's it gon' take for you to want me\n[03:13.056] I'd give it all up if you told me\n[03:17.305] That I'd be\n[03:19.488] The number one girl in your eyes\n[03:25.420] The number one girl in your eyes\n"""
|
202 |
)
|
203 |
|
204 |
current_prompt_type = gr.State(value="audio")
|
|
|
221 |
with gr.Column():
|
222 |
with gr.Accordion("Best Practices Guide", open=True):
|
223 |
gr.Markdown("""
|
224 |
+
1. **Lyrics Format Requirements**
|
225 |
+
- Each line must follow: `[mm:ss.xx]Lyric content`
|
226 |
+
- Example of valid format:
|
227 |
+
```
|
228 |
+
[00:10.00]Moonlight spills through broken blinds
|
229 |
+
[00:13.20]Your shadow dances on the dashboard shrine
|
230 |
+
```
|
231 |
+
|
232 |
+
2. **Generation Duration Limits**
|
233 |
+
- The generated music must be between **95 seconds (1:35)** and **285 seconds (4:45)** in length
|
234 |
+
- The latest valid lyric timestamp cannot exceed **04:45.00 (285s)**
|
235 |
+
|
236 |
+
3. **Audio Prompt Requirements**
|
237 |
+
- Reference audio should be ≥ 1 second, Audio >10 seconds will be randomly clipped into 10 seconds
|
238 |
+
- For optimal results, the 10-second clips should be carefully selected
|
239 |
+
- Shorter clips may lead to incoherent generation
|
240 |
+
|
241 |
+
4. **Supported Languages**
|
242 |
+
- Chinese and English
|
243 |
+
- More languages comming soon
|
244 |
+
|
245 |
+
5. **Others**
|
246 |
+
- If loading audio result is slow, you can select Output Format as mp3 in Advanced Settings.
|
|
|
|
|
|
|
|
|
247 |
""")
|
248 |
# Music_Duration = gr.Radio(["95s", "285s"], label="Music Duration", value="95s")
|
249 |
+
Music_Duration = gr.Slider(
|
250 |
+
minimum=95,
|
251 |
+
maximum=285,
|
252 |
+
step=1,
|
253 |
+
value=95,
|
254 |
+
label="Music Duration (s)",
|
255 |
+
interactive=True
|
256 |
+
)
|
257 |
preference_infer = gr.Radio(["quality first", "speed first"], label="Preference", value="quality first")
|
258 |
lyrics_btn = gr.Button("Generate", variant="primary")
|
259 |
audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
|
|
|
285 |
interactive=True,
|
286 |
elem_id="step_slider"
|
287 |
)
|
288 |
+
# edit = gr.Checkbox(label="edit", value=False)
|
289 |
+
# edit = False
|
290 |
+
# preference_infer = gr.Radio(["quality first", "speed first"], label="Preference", value="quality first")
|
291 |
+
# edit_segments = gr.Textbox(
|
292 |
+
# label="Edit Segments",
|
293 |
+
# placeholder="Time segments to edit (in seconds). Format: `[[start1,end1],...]",
|
294 |
+
# )
|
295 |
|
296 |
odeint_method = gr.Radio(["euler", "midpoint", "rk4","implicit_adams"], label="ODE Solver", value="euler")
|
297 |
file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="mp3")
|
|
|
334 |
|
335 |
gr.Examples(
|
336 |
examples=[
|
337 |
+
["""[00:04.074] Tell me that I'm special\n[00:06.226] Tell me I look pretty\n[00:08.175] Tell me I'm a little angel\n[00:10.175] Sweetheart of your city\n[00:13.307] Say what I'm dying to hear\n[00:17.058] 'Cause I'm dying to hear you\n[00:20.523] Tell me I'm that new thing\n[00:22.571] Tell me that I'm relevant\n[00:24.723] Tell me that I got a big heart\n[00:26.723] Then back it up with evidence\n[00:29.408] I need it and I don't know why\n[00:33.907] This late at night\n[00:36.139] Isn't it lonely\n[00:38.991] I'd do anything to make you want me\n[00:43.222] I'd give it all up if you told me\n[00:47.339] That I'd be\n[00:49.157] The number one girl in your eyes\n[00:52.506] Your one and only\n[00:55.437] So what's it gon' take for you to want me\n[00:59.589] I'd give it all up if you told me\n[01:03.674] That I'd be\n[01:05.823] The number one girl in your eyes\n[01:10.841] Tell me I'm going real big places\n[01:14.055] Down to earth, so friendly\n[01:16.105] And even through all the phases\n[01:18.256] Tell me you accept me\n[01:21.155] Well, that's all I'm dying to hear\n[01:24.937] Yeah, I'm dying to hear you\n[01:28.521] Tell me that you need me\n[01:30.437] Tell me that I'm loved\n[01:32.740] Tell me that I'm worth it\n[01:34.605] And that I'm enough\n[01:37.571] I need it and I don't know why\n[01:41.889] This late at night\n[01:43.805] Isn't it lonely\n[01:46.871] I'd do anything to make you want me\n[01:51.004] I'd give it all up if you told me\n[01:55.237] That I'd be\n[01:57.089] The number one girl in your eyes\n[02:00.325] Your one and only\n[02:03.305] So what's it gon' take for you to want me\n[02:07.355] I'd give it all up if you told me\n[02:11.589] That I'd be\n[02:13.623] The number one girl in your eyes\n[02:16.804] The girl in your eyes\n[02:20.823] The girl in your eyes\n[02:26.055] Tell me I'm the number one girl\n[02:28.355] I'm the number one girl in your eyes\n[02:33.172] The girl in your eyes\n[02:37.321] The girl in your 
eyes\n[02:42.472] Tell me I'm the number one girl\n[02:44.539] I'm the number one girl in your eyes\n[02:49.605] Well isn't it lonely\n[02:52.488] I'd do anything to make you want me\n[02:56.637] I'd give it all up if you told me\n[03:00.888] That I'd be\n[03:03.172] The number one girl in your eyes\n[03:06.272] Your one and only\n[03:09.160] So what's it gon' take for you to want me\n[03:13.056] I'd give it all up if you told me\n[03:17.305] That I'd be\n[03:19.488] The number one girl in your eyes\n[03:25.420] The number one girl in your eyes\n"""],
|
338 |
+
["""[00:00.133]Abracadabra, abracadabra\n[00:03.985]Abracadabra, abracadabra\n[00:15.358]Pay the toll to the angels\n[00:18.694]Drawin' circles in the clouds\n[00:22.966]Keep your mind on the distance\n[00:26.321]When the devil turns around\n[00:30.540]Hold me in your heart tonight\n[00:33.751]In the magic of the dark moonlight\n[00:38.189]Save me from this empty fight\n[00:43.521]In the game of life\n[00:45.409]Like a poem said by a lady in red\n[00:49.088]You hear the last few words of your life\n[00:53.013]With a haunting dance, now you're both in a trance\n[00:56.687]It's time to cast your spell on the night\n[01:01.131]Abracadabra amor-oo-na-na\n[01:04.394]Abracadabra morta-oo-gaga\n[01:08.778]Abracadabra, abra-ooh-na-na\n[01:12.063]In her tongue, she's sayin'\n[01:14.367]Death or love tonight\n[01:18.249]Abracadabra, abracadabra\n[01:22.136]Abracadabra, abracadabra\n[01:25.859]Feel the beat under your feet\n[01:27.554]The floor's on fire\n[01:29.714]Abracadabra, abracadabra\n[01:33.464]Choose the road on the west side\n[01:36.818]As the dust flies, watch it burn\n[01:41.057]Don't waste time on feeling\n[01:44.419]Use your passion no return\n[01:48.724]Hold me in your heart tonight\n[01:51.886]In the magic of the dark moonlight\n[01:56.270]Save me from this empty fight\n[02:01.599]In the game of life\n[02:03.524]Like a poem said by a lady in red\n[02:07.192]You hear the last few words of your life\n[02:11.055]With a haunting dance, now you're both in a trance\n[02:14.786]It's time to cast your spell on the night\n[02:19.225]Abracadabra amor-oo-na-na\n[02:22.553]Abracadabra morta-oo-gaga\n[02:26.852]Abracadabra, abra-ooh-na-na\n[02:30.110]In her tongue, she's sayin'\n[02:32.494]Death or love tonight\n[02:36.244]Abracadabra, abracadabra\n[02:40.161]Abracadabra, abracadabra\n[02:43.966]Feel the beat under your feet\n[02:45.630]The floor's on fire\n[02:47.812]Abracadabra, abracadabra\n[02:50.537]Phantom of the dancefloor, come to me\n[02:58.169]Sing for me a 
sinful melody\n[03:05.833]Ah-ah-ah-ah-ah, ah-ah, ah-ah\n[03:13.453]Ah-ah-ah-ah-ah, ah-ah, ah-ah\n[03:22.025]Abracadabra amor-oo-na-na\n[03:25.423]Abracadabra morta-oo-gaga\n[03:29.674]Abracadabra, abra-ooh-na-na\n[03:33.013]In her tongue, she's sayin'\n[03:35.401]Death or love tonight\n"""],
|
339 |
+
["""[00:00.27]只因你太美 baby 只因你太美 baby\n[00:08.95]只因你实在是太美 baby\n[00:13.99]只因你太美 baby\n[00:18.89]迎面走来的你让我如此蠢蠢欲动\n[00:20.88]这种感觉我从未有\n[00:21.79]Cause I got a crush on you who you\n[00:25.74]你是我的我是你的谁\n[00:28.09]再多一眼看一眼就会爆炸\n[00:30.31]再近一点靠近点快被融化\n[00:32.49]想要把你占为己有 baby bae\n[00:34.60]不管走到哪里\n[00:35.44]都会想起的人是你 you you\n[00:38.12]我应该拿你怎样\n[00:39.61]Uh 所有人都在看着你\n[00:42.36]我的心总是不安\n[00:44.18]Oh 我现在已病入膏肓\n[00:46.63]Eh oh\n[00:47.84]难道真的因你而疯狂吗\n[00:51.57]我本来不是这种人\n[00:53.59]因你变成奇怪的人\n[00:55.77]第一次呀变成这样的我\n[01:01.23]不管我怎么去否认\n[01:03.21]只因你太美 baby 只因你太美 baby\n[01:11.46]只因你实在是太美 baby\n[01:16.75]只因你太美 baby\n[01:21.09]Oh eh oh\n[01:22.82]现在确认地告诉我\n[01:25.26]Oh eh oh\n[01:27.31]你到底属于谁\n[01:29.98]Oh eh oh\n[01:31.70]现在确认地告诉我\n[01:34.45]Oh eh oh\n"""],
|
340 |
+
["""[00:19.50]也想不到要怎么问你别来无恙\n[00:25.71]世界乱的一塌糊涂可是能怎样\n[00:31.85]偶尔抬起头来还好有颗月亮可赏\n[00:38.96]太多回忆要我怎么摆进行李箱\n[00:45.22]一直没哭一直走路走灰多少太阳\n[00:51.70]因为往事没有办法悬赏 隐形在那大街小巷\n[01:00.22]剪断了它还嚣张\n[01:03.85]我的嘴在说谎 说的那么漂亮\n[01:10.07]说我早就忘了你像月一样的俏脸庞\n[01:16.51]最怕一边忙呀忙一边回想那旧时光\n[01:22.87]剪不掉的是你 带泪的脸 还真是烦\n[01:43.32]多少原因将我绑在半夜屋顶上\n[01:49.23]一直没再爱一个人如今就是这样\n[01:55.75]因为故事跟你说了一半 于是搁在所谓云端\n[02:04.21]谁忘不了谁孤单\n[02:07.79]我的心在说谎 说下去会疯狂\n[02:14.02]如果没有月亮那些日子都无妨\n[02:20.43]最怕一边忙呀忙一边想那旧时光\n[02:26.91]剪不掉的是你 带笑的苦 还真烦\n[02:33.81]我的嘴又说了谎 说的那么漂亮\n[02:39.68]以为已经忘了你的那些美像月光它剪不断\n[02:47.15]因为爱早就钻进心脏 心一跳泪就会烫\n[02:52.22]那些带泪的脸 带笑的苦 还真烦\n[02:59.28]月亮是个凶手 想你的我 是通缉犯\n[03:08.03]我有时候真的很怕望见那月光中的你\n"""]
|
341 |
],
|
342 |
|
343 |
inputs=[lrc],
|
344 |
label="Lrc Examples",
|
345 |
+
examples_per_page=4,
|
346 |
elem_id="lrc-examples-container",
|
347 |
)
|
348 |
|
|
|
437 |
|
438 |
lyrics_btn.click(
|
439 |
fn=infer_music,
|
440 |
+
inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, cfg_strength, file_type, odeint_method, preference_infer, Music_Duration],
|
441 |
outputs=audio_output
|
442 |
)
|
443 |
|
diffrhythm/infer/infer.py
CHANGED
@@ -43,6 +43,7 @@ def inference(
|
|
43 |
odeint_method,
|
44 |
pred_frames,
|
45 |
batch_infer_num,
|
|
|
46 |
chunked=True,
|
47 |
):
|
48 |
with torch.inference_mode():
|
@@ -50,6 +51,7 @@ def inference(
|
|
50 |
cond=cond,
|
51 |
text=text,
|
52 |
duration=duration,
|
|
|
53 |
style_prompt=style_prompt,
|
54 |
negative_style_prompt=negative_style_prompt,
|
55 |
steps=steps,
|
@@ -59,7 +61,8 @@ def inference(
|
|
59 |
vocal_flag=vocal_flag,
|
60 |
odeint_method=odeint_method,
|
61 |
latent_pred_segments=pred_frames,
|
62 |
-
batch_infer_num=batch_infer_num
|
|
|
63 |
)
|
64 |
|
65 |
outputs = []
|
|
|
43 |
odeint_method,
|
44 |
pred_frames,
|
45 |
batch_infer_num,
|
46 |
+
song_duration,
|
47 |
chunked=True,
|
48 |
):
|
49 |
with torch.inference_mode():
|
|
|
51 |
cond=cond,
|
52 |
text=text,
|
53 |
duration=duration,
|
54 |
+
max_duration=duration,
|
55 |
style_prompt=style_prompt,
|
56 |
negative_style_prompt=negative_style_prompt,
|
57 |
steps=steps,
|
|
|
61 |
vocal_flag=vocal_flag,
|
62 |
odeint_method=odeint_method,
|
63 |
latent_pred_segments=pred_frames,
|
64 |
+
batch_infer_num=batch_infer_num,
|
65 |
+
song_duration=song_duration
|
66 |
)
|
67 |
|
68 |
outputs = []
|
diffrhythm/infer/infer_utils.py
CHANGED
@@ -194,17 +194,21 @@ def encode_audio(audio, vae_model, chunked=False, overlap=32, chunk_size=128):
|
|
194 |
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
195 |
return y_final
|
196 |
|
197 |
-
def prepare_model(device):
|
198 |
# prepare cfm model
|
199 |
-
|
200 |
-
|
|
|
|
|
|
|
201 |
dit_config_path = "./diffrhythm/config/config.json"
|
202 |
with open(dit_config_path) as f:
|
203 |
model_config = json.load(f)
|
204 |
dit_model_cls = DiT
|
205 |
cfm = CFM(
|
206 |
-
transformer=dit_model_cls(**model_config["model"], max_frames=
|
207 |
num_channels=model_config["model"]['mel_dim'],
|
|
|
208 |
)
|
209 |
cfm = cfm.to(device)
|
210 |
cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)
|
@@ -410,12 +414,11 @@ class CNENTokenizer:
|
|
410 |
return "|".join([self.id2phone[x - 1] for x in token])
|
411 |
|
412 |
|
413 |
-
def get_lrc_token(max_frames, text, tokenizer, device):
|
414 |
|
415 |
lyrics_shift = 0
|
416 |
sampling_rate = 44100
|
417 |
downsample_rate = 2048
|
418 |
-
max_secs = max_frames / (sampling_rate / downsample_rate)
|
419 |
|
420 |
comma_token_id = 1
|
421 |
period_token_id = 2
|
@@ -436,10 +439,15 @@ def get_lrc_token(max_frames, text, tokenizer, device):
|
|
436 |
]
|
437 |
if max_frames == 2048:
|
438 |
lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time
|
|
|
|
|
|
|
439 |
|
440 |
normalized_start_time = 0.0
|
|
|
|
|
441 |
|
442 |
-
lrc = torch.zeros((
|
443 |
|
444 |
tokens_count = 0
|
445 |
last_end_pos = 0
|
@@ -455,7 +463,7 @@ def get_lrc_token(max_frames, text, tokenizer, device):
|
|
455 |
frame_shift = random.randint(int(-lyrics_shift), int(lyrics_shift))
|
456 |
|
457 |
frame_start = max(gt_frame_start - frame_shift, last_end_pos)
|
458 |
-
frame_len = min(num_tokens,
|
459 |
|
460 |
lrc[frame_start : frame_start + frame_len] = tokens[:frame_len]
|
461 |
|
@@ -466,8 +474,11 @@ def get_lrc_token(max_frames, text, tokenizer, device):
|
|
466 |
|
467 |
normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
|
468 |
normalized_start_time = normalized_start_time.half()
|
|
|
|
|
|
|
469 |
|
470 |
-
return lrc_emb, normalized_start_time
|
471 |
|
472 |
|
473 |
def load_checkpoint(model, ckpt_path, device, use_ema=True):
|
|
|
194 |
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
195 |
return y_final
|
196 |
|
197 |
+
def prepare_model(max_frames, device):
|
198 |
# prepare cfm model
|
199 |
+
if max_frames == 2048:
|
200 |
+
dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-1_2", filename="cfm_model.pt")
|
201 |
+
else:
|
202 |
+
dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-1_2-full", filename="cfm_model.pt")
|
203 |
+
|
204 |
dit_config_path = "./diffrhythm/config/config.json"
|
205 |
with open(dit_config_path) as f:
|
206 |
model_config = json.load(f)
|
207 |
dit_model_cls = DiT
|
208 |
cfm = CFM(
|
209 |
+
transformer=dit_model_cls(**model_config["model"], max_frames=max_frames),
|
210 |
num_channels=model_config["model"]['mel_dim'],
|
211 |
+
max_frames=max_frames
|
212 |
)
|
213 |
cfm = cfm.to(device)
|
214 |
cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)
|
|
|
414 |
return "|".join([self.id2phone[x - 1] for x in token])
|
415 |
|
416 |
|
417 |
+
def get_lrc_token(max_frames, text, tokenizer, max_secs, device):
|
418 |
|
419 |
lyrics_shift = 0
|
420 |
sampling_rate = 44100
|
421 |
downsample_rate = 2048
|
|
|
422 |
|
423 |
comma_token_id = 1
|
424 |
period_token_id = 2
|
|
|
439 |
]
|
440 |
if max_frames == 2048:
|
441 |
lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time
|
442 |
+
|
443 |
+
end_frame = max_frames if max_frames == 2048 else int(max_secs * (sampling_rate / downsample_rate))
|
444 |
+
end_frame = min(end_frame, max_frames)
|
445 |
|
446 |
normalized_start_time = 0.0
|
447 |
+
|
448 |
+
normalized_duration = end_frame / max_frames
|
449 |
|
450 |
+
lrc = torch.zeros((end_frame,), dtype=torch.long)
|
451 |
|
452 |
tokens_count = 0
|
453 |
last_end_pos = 0
|
|
|
463 |
frame_shift = random.randint(int(-lyrics_shift), int(lyrics_shift))
|
464 |
|
465 |
frame_start = max(gt_frame_start - frame_shift, last_end_pos)
|
466 |
+
frame_len = min(num_tokens, end_frame - frame_start)
|
467 |
|
468 |
lrc[frame_start : frame_start + frame_len] = tokens[:frame_len]
|
469 |
|
|
|
474 |
|
475 |
normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
|
476 |
normalized_start_time = normalized_start_time.half()
|
477 |
+
|
478 |
+
normalized_duration = torch.tensor(normalized_duration).unsqueeze(0).to(device)
|
479 |
+
normalized_duration = normalized_duration.half()
|
480 |
|
481 |
+
return lrc_emb, normalized_start_time, end_frame, normalized_duration
|
482 |
|
483 |
|
484 |
def load_checkpoint(model, ckpt_path, device, use_ema=True):
|
diffrhythm/model/cfm.py
CHANGED
@@ -138,6 +138,7 @@ class CFM(nn.Module):
|
|
138 |
latent_pred_segments=None,
|
139 |
vocal_flag=False,
|
140 |
odeint_method="euler",
|
|
|
141 |
batch_infer_num=5
|
142 |
):
|
143 |
self.eval()
|
@@ -208,19 +209,20 @@ class CFM(nn.Module):
|
|
208 |
negative_style_prompt = negative_style_prompt.repeat(batch_infer_num, 1)
|
209 |
start_time = start_time.repeat(batch_infer_num)
|
210 |
fixed_span_mask = fixed_span_mask.repeat(batch_infer_num, 1, 1)
|
|
|
211 |
|
212 |
def fn(t, x):
|
213 |
# predict flow
|
214 |
pred = self.transformer(
|
215 |
x=x, cond=step_cond, text=text, time=t, drop_audio_cond=False, drop_text=False, drop_prompt=False,
|
216 |
-
style_prompt=style_prompt, start_time=start_time
|
217 |
)
|
218 |
if cfg_strength < 1e-5:
|
219 |
return pred
|
220 |
|
221 |
null_pred = self.transformer(
|
222 |
x=x, cond=step_cond, text=text, time=t, drop_audio_cond=True, drop_text=True, drop_prompt=False,
|
223 |
-
style_prompt=negative_style_prompt, start_time=start_time
|
224 |
)
|
225 |
return pred + (pred - null_pred) * cfg_strength
|
226 |
|
|
|
138 |
latent_pred_segments=None,
|
139 |
vocal_flag=False,
|
140 |
odeint_method="euler",
|
141 |
+
song_duration=None,
|
142 |
batch_infer_num=5
|
143 |
):
|
144 |
self.eval()
|
|
|
209 |
negative_style_prompt = negative_style_prompt.repeat(batch_infer_num, 1)
|
210 |
start_time = start_time.repeat(batch_infer_num)
|
211 |
fixed_span_mask = fixed_span_mask.repeat(batch_infer_num, 1, 1)
|
212 |
+
song_duration = song_duration.repeat(batch_infer_num)
|
213 |
|
214 |
def fn(t, x):
|
215 |
# predict flow
|
216 |
pred = self.transformer(
|
217 |
x=x, cond=step_cond, text=text, time=t, drop_audio_cond=False, drop_text=False, drop_prompt=False,
|
218 |
+
style_prompt=style_prompt, start_time=start_time, duration=song_duration
|
219 |
)
|
220 |
if cfg_strength < 1e-5:
|
221 |
return pred
|
222 |
|
223 |
null_pred = self.transformer(
|
224 |
x=x, cond=step_cond, text=text, time=t, drop_audio_cond=True, drop_text=True, drop_prompt=False,
|
225 |
+
style_prompt=negative_style_prompt, start_time=start_time, duration=song_duration
|
226 |
)
|
227 |
return pred + (pred - null_pred) * cfg_strength
|
228 |
|
diffrhythm/model/dit.py
CHANGED
@@ -118,6 +118,7 @@ class DiT(nn.Module):
|
|
118 |
cond_dim = 512
|
119 |
self.time_embed = TimestepEmbedding(cond_dim)
|
120 |
self.start_time_embed = TimestepEmbedding(cond_dim)
|
|
|
121 |
if text_dim is None:
|
122 |
text_dim = mel_dim
|
123 |
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers, max_pos=self.max_frames)
|
@@ -170,6 +171,7 @@ class DiT(nn.Module):
|
|
170 |
drop_prompt=False,
|
171 |
style_prompt=None, # [b d t]
|
172 |
start_time=None,
|
|
|
173 |
):
|
174 |
|
175 |
batch, seq_len = x.shape[0], x.shape[1]
|
@@ -179,7 +181,8 @@ class DiT(nn.Module):
|
|
179 |
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
180 |
t = self.time_embed(time)
|
181 |
s_t = self.start_time_embed(start_time)
|
182 |
-
|
|
|
183 |
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
184 |
|
185 |
if drop_prompt:
|
|
|
118 |
cond_dim = 512
|
119 |
self.time_embed = TimestepEmbedding(cond_dim)
|
120 |
self.start_time_embed = TimestepEmbedding(cond_dim)
|
121 |
+
self.duration_time_embed = TimestepEmbedding(cond_dim) if self.max_frames == 6144 else None
|
122 |
if text_dim is None:
|
123 |
text_dim = mel_dim
|
124 |
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers, max_pos=self.max_frames)
|
|
|
171 |
drop_prompt=False,
|
172 |
style_prompt=None, # [b d t]
|
173 |
start_time=None,
|
174 |
+
duration=None
|
175 |
):
|
176 |
|
177 |
batch, seq_len = x.shape[0], x.shape[1]
|
|
|
181 |
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
182 |
t = self.time_embed(time)
|
183 |
s_t = self.start_time_embed(start_time)
|
184 |
+
d_t = self.duration_time_embed(duration) if self.max_frames == 6144 else torch.zeros_like(s_t)
|
185 |
+
c = t + s_t + d_t
|
186 |
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
187 |
|
188 |
if drop_prompt:
|