WeixuanYuan committed on
Commit
c0c17c4
1 Parent(s): ae1bdf7

Upload build_instrument.py

webUI/natural_language_guided/build_instrument.py CHANGED
@@ -4,13 +4,42 @@ import torch
 import gradio as gr
 import mido
 from io import BytesIO
-import pyrubberband as pyrb
 
 from model.DiffSynthSampler import DiffSynthSampler
 from tools import adsr_envelope, adjust_audio_length
 from webUI.natural_language_guided.track_maker import DiffSynth
 from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT, phase_to_Gradio_image, \
     spectrogram_to_Gradio_image
+import torchaudio.transforms as transforms
+
+
+def time_stretch_audio(waveform, sample_rate, stretch_factor):
+    # If the input is a numpy array, convert it to a torch.Tensor
+    if isinstance(waveform, np.ndarray):
+        waveform = torch.from_numpy(waveform)
+
+    # Make sure the waveform is torch.float32
+    waveform = waveform.to(torch.float32)
+
+    # Set up the STFT parameters
+    n_fft = 2048  # STFT window size
+    hop_length = n_fft // 4  # hop length set to a quarter of n_fft
+
+    # Compute the short-time Fourier transform (STFT)
+    stft = torch.stft(waveform, n_fft=n_fft, hop_length=hop_length, return_complex=True)
+
+    # Create the TimeStretch transform
+    time_stretch = transforms.TimeStretch(hop_length=hop_length, n_freq=1025, fixed_rate=False)
+
+    print(stft.shape)
+    # Apply the time stretch
+    stretched_stft = time_stretch(stft, stretch_factor)
+
+    # Convert the stretched STFT back to a time-domain waveform
+    stretched_waveform = torch.istft(stretched_stft, n_fft=n_fft, hop_length=hop_length)
+
+    # Return the processed waveform as a numpy array
+    return stretched_waveform.detach().numpy()
 
 
 def get_build_instrument_module(gradioWebUI, virtual_instruments_state):
@@ -154,7 +183,8 @@ def get_build_instrument_module(gradioWebUI, virtual_instruments_state):
    sample_rate, signal = virtual_instrument["signal"]
 
    s = 3 / duration
-   applied_signal = pyrb.time_stretch(signal, sample_rate, s)
+   # applied_signal = pyrb.time_stretch(signal, sample_rate, s)
+   applied_signal = time_stretch_audio(signal, sample_rate, s)
    applied_signal = adjust_audio_length(applied_signal, int((duration+1) * sample_rate), sample_rate, sample_rate)
 
    D = librosa.stft(applied_signal, n_fft=1024, hop_length=256, win_length=1024)[1:, :]
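A minimal standalone sketch of the same STFT -> TimeStretch -> iSTFT round trip that the new time_stretch_audio performs; the 16 kHz sine input and the 1.5 stretch rate are illustrative values only, not part of the commit:

import numpy as np
import torch
import torchaudio.transforms as transforms

# Illustrative input: one second of a 440 Hz sine at 16 kHz.
sample_rate = 16000
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
waveform = torch.from_numpy(0.5 * np.sin(2 * np.pi * 440 * t)).to(torch.float32)

# Same pipeline as time_stretch_audio above.
n_fft = 2048
hop_length = n_fft // 4
stft = torch.stft(waveform, n_fft=n_fft, hop_length=hop_length, return_complex=True)

stretcher = transforms.TimeStretch(hop_length=hop_length, n_freq=n_fft // 2 + 1)
# A rate > 1 shortens the audio, a rate < 1 lengthens it.
stretched_stft = stretcher(stft, 1.5)

stretched = torch.istft(stretched_stft, n_fft=n_fft, hop_length=hop_length)
print(waveform.shape, stretched.shape)  # output is roughly 1/1.5 of the original length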