ASesYusuf1 committed on
Commit
f092faf
·
verified ·
1 Parent(s): f9b565a

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +129 -78
inference.py CHANGED
@@ -3,17 +3,21 @@ __author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
3
 
4
  import argparse
5
  import time
 
6
  import librosa
7
- from tqdm.auto import tqdm
8
  import sys
9
  import os
10
  import glob
11
  import torch
12
- import soundfile as sf
13
  import torch.nn as nn
14
  import numpy as np
15
- from assets.i18n.i18n import I18nAuto
16
  import spaces
 
 
 
 
 
17
 
18
 # Colab check
19
  try:
@@ -22,19 +26,22 @@ try:
22
  except ImportError:
23
  IS_COLAB = False
24
 
 
 
 
 
 
 
 
25
  i18n = I18nAuto()
26
 
27
  current_dir = os.path.dirname(os.path.abspath(__file__))
28
  sys.path.append(current_dir)
29
 
30
  from utils import demix, get_model_from_config, normalize_audio, denormalize_audio
31
- from utils import prefer_target_instrument, apply_tta, load_start_checkpoint, load_lora_weights
32
-
33
- import warnings
34
- warnings.filterwarnings("ignore")
35
 
36
  def shorten_filename(filename, max_length=30):
37
- """Dosya adını belirtilen maksimum uzunluğa kısaltır."""
38
  base, ext = os.path.splitext(filename)
39
  if len(base) <= max_length:
40
  return filename
@@ -42,16 +49,22 @@ def shorten_filename(filename, max_length=30):
42
  return shortened
43
 
44
  def get_soundfile_subtype(pcm_type, is_float=False):
45
- """PCM türüne göre uygun soundfile alt türünü belirler."""
46
- if is_float:
47
  return 'FLOAT'
48
- subtype_map = {
49
- 'PCM_16': 'PCM_16',
50
- 'PCM_24': 'PCM_24',
51
- 'FLOAT': 'FLOAT'
52
- }
53
  return subtype_map.get(pcm_type, 'FLOAT')
54
 
 
 
 
 
 
 
 
 
 
 
 
55
  def run_folder(model, args, config, device, verbose: bool = False, progress=None):
56
  start_time = time.time()
57
  model.eval()
@@ -60,7 +73,7 @@ def run_folder(model, args, config, device, verbose: bool = False, progress=None
60
  sample_rate = getattr(config.audio, 'sample_rate', 44100)
61
 
62
  logging.info(f"Total files found: {len(mixture_paths)} with sample rate: {sample_rate}")
63
- print(f"Total files found: {len(mixture_paths)} with sample rate: {sample_rate}")
64
 
65
  instruments = prefer_target_instrument(config)[:]
66
  store_dir = args.store_dir
@@ -68,49 +81,65 @@ def run_folder(model, args, config, device, verbose: bool = False, progress=None
68
 
69
  total_files = len(mixture_paths)
70
  processed_files = 0
 
71
 
72
  for path in mixture_paths:
73
  try:
74
  mix, sr = librosa.load(path, sr=sample_rate, mono=False)
75
  logging.info(f"Loaded audio: {path}, shape: {mix.shape}")
76
- print(f"Loaded audio: {path}, shape: {mix.shape}")
77
 
78
- # Dosya ilerlemesi için başlangıç güncellemesi
79
  processed_files += 1
80
- base_progress = round(((processed_files - 1) / total_files) * 100) # Önceki dosyalar
81
  if progress is not None and callable(getattr(progress, '__call__', None)):
82
- progress(base_progress / 100, desc=f"Processing file {processed_files}/{total_files}")
83
- update_progress_html(f"Processing file {processed_files}/{total_files}", base_progress)
84
 
85
  mix_orig = mix.copy()
86
- if 'normalize' in config.inference and config.inference['normalize']:
87
  mix, norm_params = normalize_audio(mix)
88
 
89
- # demix fonksiyonuna progress nesnesini ilet
90
- waveforms_orig = demix(config, model, mix, device, model_type=args.model_type, pbar=False, progress=progress)
 
 
91
 
92
  if args.use_tta:
93
- # apply_tta fonksiyonuna progress nesnesini ilet
94
- waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type, progress=progress)
 
 
95
 
96
  if args.demud_phaseremix_inst:
97
  logging.info(f"Demudding track: {path}")
98
- print(f"Demudding track: {path}")
99
  instr = 'vocals' if 'vocals' in instruments else instruments[0]
100
  instruments.append('instrumental_phaseremix')
101
  if 'instrumental' not in instruments and 'Instrumental' not in instruments:
102
  mix_modified = mix_orig - 2 * waveforms_orig[instr]
103
  mix_modified_ = mix_modified.copy()
104
- waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type, pbar=False, progress=progress)
 
 
 
105
  if args.use_tta:
106
- waveforms_modified = apply_tta(config, model, mix_modified, waveforms_modified, device, args.model_type, progress=progress)
 
 
 
107
  waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]
108
  else:
109
  mix_modified = 2 * waveforms_orig[instr] - mix_orig
110
  mix_modified_ = mix_modified.copy()
111
- waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type, pbar=False, progress=progress)
 
 
 
112
  if args.use_tta:
113
- waveforms_modified = apply_tta(config, model, mix_modified, waveforms_orig, device, args.model_type, progress=progress)
 
 
 
114
  waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]
115
 
116
  if args.extract_instrumental:
@@ -119,96 +148,118 @@ def run_folder(model, args, config, device, verbose: bool = False, progress=None
119
  if 'instrumental' not in instruments:
120
  instruments.append('instrumental')
121
 
122
- for instr in instruments:
123
  estimates = waveforms_orig[instr]
124
- if 'normalize' in config.inference and config.inference['normalize']:
125
  estimates = denormalize_audio(estimates, norm_params)
126
 
127
  is_float = getattr(args, 'export_format', '').startswith('wav FLOAT')
128
  codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
129
- subtype = get_soundfile_subtype(args.pcm_type, is_float) if codec == 'flac' else get_soundfile_subtype('FLOAT', is_float)
130
 
131
  shortened_filename = shorten_filename(os.path.basename(path))
132
  output_filename = f"{shortened_filename}_{instr}.{codec}"
133
  output_path = os.path.join(store_dir, output_filename)
134
  sf.write(output_path, estimates.T, sr, subtype=subtype)
135
 
136
- # Dosya tamamlandı, ilerleme güncellemesi
137
- file_progress = round((processed_files / total_files) * 100)
 
 
 
 
138
  if progress is not None and callable(getattr(progress, '__call__', None)):
139
- progress(file_progress / 100, desc=f"Completed file {processed_files}/{total_files}")
140
- update_progress_html(f"Completed file {processed_files}/{total_files}", file_progress)
141
 
142
  except Exception as e:
143
  logging.error(f"Cannot read track: {path}. Error: {str(e)}")
144
- print(f"Cannot read track: {path}. Error: {str(e)}")
 
145
  continue
146
 
147
  elapsed_time = time.time() - start_time
148
- logging.info(f"Elapsed time: {elapsed_time:.2f} seconds")
149
- print(f"Elapsed time: {elapsed_time:.2f} seconds")
150
 
151
- # Tüm işlem tamamlandı
152
  if progress is not None and callable(getattr(progress, '__call__', None)):
153
- progress(1.0, desc="Processing complete")
154
- update_progress_html("Processing complete", 100)
155
 
156
  @spaces.GPU
157
- def proc_folder(args):
158
- parser = argparse.ArgumentParser(description=i18n("proc_folder_description"))
159
- parser.add_argument("--model_type", type=str, default='mdx23c', help=i18n("model_type_help"))
160
- parser.add_argument("--config_path", type=str, help=i18n("config_path_help"))
161
- parser.add_argument("--demud_phaseremix_inst", action='store_true', help=i18n("demud_phaseremix_help"))
162
- parser.add_argument("--start_check_point", type=str, default='', help=i18n("start_checkpoint_help"))
163
- parser.add_argument("--input_folder", type=str, help=i18n("input_folder_help"))
164
- parser.add_argument("--audio_path", type=str, help=i18n("audio_path_help"))
165
- parser.add_argument("--store_dir", type=str, default="", help=i18n("store_dir_help"))
166
- parser.add_argument("--device_ids", nargs='+', type=int, default=0, help=i18n("device_ids_help"))
167
- parser.add_argument("--extract_instrumental", action='store_true', help=i18n("extract_instrumental_help"))
168
- parser.add_argument("--disable_detailed_pbar", action='store_true', help=i18n("disable_detailed_pbar_help"))
169
- parser.add_argument("--force_cpu", action='store_true', help=i18n("force_cpu_help"))
170
- parser.add_argument("--flac_file", action='store_true', help=i18n("flac_file_help"))
171
- parser.add_argument("--export_format", type=str, choices=['wav FLOAT', 'flac PCM_16', 'flac PCM_24'], default='flac PCM_24', help=i18n("export_format_help"))
172
- parser.add_argument("--pcm_type", type=str, choices=['PCM_16', 'PCM_24'], default='PCM_24', help=i18n("pcm_type_help"))
173
- parser.add_argument("--use_tta", action='store_true', help=i18n("use_tta_help"))
174
- parser.add_argument("--lora_checkpoint", type=str, default='', help=i18n("lora_checkpoint_help"))
175
- parser.add_argument("--chunk_size", type=int, default=1000000, help="Inference chunk size")
176
- parser.add_argument("--overlap", type=int, default=4, help="Inference overlap factor")
177
-
178
- if args is None:
179
- args = parser.parse_args()
180
- else:
181
- args = parser.parse_args(args)
182
 
183
  device = "cpu"
184
  if args.force_cpu:
185
- device = "cpu"
186
  elif torch.cuda.is_available():
 
187
  print(i18n("cuda_available"))
188
- device = f'cuda:{args.device_ids[0]}' if type(args.device_ids) == list else f'cuda:{args.device_ids}'
189
  elif torch.backends.mps.is_available():
190
  device = "mps"
191
 
 
192
  print(i18n("using_device").format(device))
193
 
194
  model_load_start_time = time.time()
195
  torch.backends.cudnn.benchmark = True
196
 
197
- model, config = get_model_from_config(args.model_type, args.config_path)
 
 
 
 
198
 
199
- if args.start_check_point != '':
200
- load_start_checkpoint(args, model, type_='inference')
 
 
 
 
201
 
 
202
  print(i18n("instruments_print").format(config.training.instruments))
203
 
204
- if type(args.device_ids) == list and len(args.device_ids) > 1 and not args.force_cpu:
205
  model = nn.DataParallel(model, device_ids=args.device_ids)
 
206
 
207
  model = model.to(device)
208
 
209
- print(i18n("model_load_time").format(time.time() - model_load_start_time))
 
 
210
 
211
- run_folder(model, args, config, device, verbose=False)
 
212
 
213
  if __name__ == "__main__":
214
- proc_folder(None)
 
 
 
 
 
3
 
4
  import argparse
5
  import time
6
+ import logging
7
  import librosa
 
8
  import sys
9
  import os
10
  import glob
11
  import torch
 
12
  import torch.nn as nn
13
  import numpy as np
14
+ import soundfile as sf
15
  import spaces
16
+ import warnings
17
+ warnings.filterwarnings("ignore")
18
+
19
+ # Logging configuration
20
+ logging.basicConfig(level=logging.DEBUG, filename='utils.log', format='%(asctime)s - %(levelname)s - %(message)s')
21
 
22
 # Colab check
23
  try:
 
26
  except ImportError:
27
  IS_COLAB = False
28
 
29
+ # i18n placeholder
30
class I18nAuto:
    """Minimal stand-in for a translation helper.

    Calling an instance with a message key returns the English template
    registered for that key, or the key itself when none is registered.
    Callers in this file then apply ``str.format`` to the returned text, so a
    template must contain as many ``{}`` fields as the caller supplies.
    """

    # Fallback English templates for the keys this file formats with runtime
    # values. Without them the bare key would be printed and the formatted
    # arguments (paths, device names, timings) would be dropped silently.
    _MESSAGES = {
        "total_files_found": "Total files found: {} with sample rate: {}",
        "loaded_audio": "Loaded audio: {}, shape: {}",
        "processing_file": "Processing file {}/{}",
        "demudding_track": "Demudding track: {}",
        "saving_output": "Saving {} for file {}/{}",
        "completed_file": "Completed file {}/{}",
        "cannot_read_track": "Cannot read track: {}",
        "error_message": "Error: {}",
        "elapsed_time": "Elapsed time: {:.2f} seconds",
        "processing_complete": "Processing complete",
        "cuda_available": "CUDA available",
        "using_device": "Using device: {}",
        "instruments_print": "Instruments: {}",
        "model_load_time": "Model load time: {:.2f} seconds",
    }

    def __call__(self, message):
        """Return the template for ``message``; unknown keys pass through unchanged."""
        return self._MESSAGES.get(message, message)

    def format(self, message, *args):
        """Resolve ``message`` like ``__call__`` and fill it with ``args``."""
        return self(message).format(*args)
35
+
36
  i18n = I18nAuto()
37
 
38
  current_dir = os.path.dirname(os.path.abspath(__file__))
39
  sys.path.append(current_dir)
40
 
41
  from utils import demix, get_model_from_config, normalize_audio, denormalize_audio
42
+ from utils import prefer_target_instrument, apply_tta, load_start_checkpoint
 
 
 
43
 
44
  def shorten_filename(filename, max_length=30):
 
45
  base, ext = os.path.splitext(filename)
46
  if len(base) <= max_length:
47
  return filename
 
49
  return shortened
50
 
51
def get_soundfile_subtype(pcm_type, is_float=False):
    """Choose the subtype string handed to ``sf.write``.

    Returns ``'FLOAT'`` when ``is_float`` is truthy or ``pcm_type`` is
    ``'FLOAT'``. Otherwise returns ``'PCM_16'`` or ``'PCM_24'`` when
    ``pcm_type`` names one of them, and ``'FLOAT'`` for anything else.
    """
    if not is_float and pcm_type != 'FLOAT':
        # Only the two integer PCM depths are recognised; the rest is float.
        return {'PCM_16': 'PCM_16', 'PCM_24': 'PCM_24'}.get(pcm_type, 'FLOAT')
    return 'FLOAT'
56
 
57
def update_progress_html(progress_label, progress_percent):
    """Render the custom progress bar as an HTML snippet.

    ``progress_percent`` is rounded to an integer and clamped to the range
    0..100 before being used as the bar width; ``progress_label`` is inserted
    as the caption text exactly as given.

    Returns the HTML as a string.
    """
    width = round(progress_percent)
    if width < 0:
        width = 0
    elif width > 100:
        width = 100
    progress_percent = width
    return f"""
    <div id="custom-progress" style="margin-top: 10px;">
        <div style="font-size: 1rem; color: #C0C0C0; margin-bottom: 5px;" id="progress-label">{progress_label}</div>
        <div style="width: 100%; background-color: #444; border-radius: 5px; overflow: hidden;">
            <div id="progress-bar" style="width: {progress_percent}%; height: 20px; background-color: #6e8efb; transition: width 0.3s; max-width: 100%;"></div>
        </div>
    </div>
    """
67
+
68
  def run_folder(model, args, config, device, verbose: bool = False, progress=None):
69
  start_time = time.time()
70
  model.eval()
 
73
  sample_rate = getattr(config.audio, 'sample_rate', 44100)
74
 
75
  logging.info(f"Total files found: {len(mixture_paths)} with sample rate: {sample_rate}")
76
+ print(i18n("total_files_found").format(len(mixture_paths), sample_rate))
77
 
78
  instruments = prefer_target_instrument(config)[:]
79
  store_dir = args.store_dir
 
81
 
82
  total_files = len(mixture_paths)
83
  processed_files = 0
84
+ base_progress_per_file = 100 / total_files if total_files > 0 else 100
85
 
86
  for path in mixture_paths:
87
  try:
88
  mix, sr = librosa.load(path, sr=sample_rate, mono=False)
89
  logging.info(f"Loaded audio: {path}, shape: {mix.shape}")
90
+ print(i18n("loaded_audio").format(path, mix.shape))
91
 
 
92
  processed_files += 1
93
+ base_progress = round((processed_files - 1) * base_progress_per_file)
94
  if progress is not None and callable(getattr(progress, '__call__', None)):
95
+ progress(base_progress / 100, desc=i18n("processing_file").format(processed_files, total_files))
96
+ update_progress_html(i18n("processing_file").format(processed_files, total_files), base_progress)
97
 
98
  mix_orig = mix.copy()
99
+ if 'normalize' in config.inference and config.inference.get('normalize', False):
100
  mix, norm_params = normalize_audio(mix)
101
 
102
+ waveforms_orig = demix(
103
+ config, model, mix, device, model_type=args.model_type, pbar=False,
104
+ progress=lambda p, desc: progress((base_progress + p * 50) / 100, desc=desc) if progress else None
105
+ )
106
 
107
  if args.use_tta:
108
+ waveforms_orig = apply_tta(
109
+ config, model, mix, waveforms_orig, device, args.model_type,
110
+ progress=lambda p, desc: progress((base_progress + 50 + p * 20) / 100, desc=desc) if progress else None
111
+ )
112
 
113
  if args.demud_phaseremix_inst:
114
  logging.info(f"Demudding track: {path}")
115
+ print(i18n("demudding_track").format(path))
116
  instr = 'vocals' if 'vocals' in instruments else instruments[0]
117
  instruments.append('instrumental_phaseremix')
118
  if 'instrumental' not in instruments and 'Instrumental' not in instruments:
119
  mix_modified = mix_orig - 2 * waveforms_orig[instr]
120
  mix_modified_ = mix_modified.copy()
121
+ waveforms_modified = demix(
122
+ config, model, mix_modified, device, model_type=args.model_type, pbar=False,
123
+ progress=lambda p, desc: progress((base_progress + 70 + p * 15) / 100, desc=desc) if progress else None
124
+ )
125
  if args.use_tta:
126
+ waveforms_modified = apply_tta(
127
+ config, model, mix_modified, waveforms_modified, device, args.model_type,
128
+ progress=lambda p, desc: progress((base_progress + 85 + p * 10) / 100, desc=desc) if progress else None
129
+ )
130
  waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]
131
  else:
132
  mix_modified = 2 * waveforms_orig[instr] - mix_orig
133
  mix_modified_ = mix_modified.copy()
134
+ waveforms_modified = demix(
135
+ config, model, mix_modified, device, model_type=args.model_type, pbar=False,
136
+ progress=lambda p, desc: progress((base_progress + 70 + p * 15) / 100, desc=desc) if progress else None
137
+ )
138
  if args.use_tta:
139
+ waveforms_modified = apply_tta(
140
+ config, model, mix_modified, waveforms_orig, device, args.model_type,
141
+ progress=lambda p, desc: progress((base_progress + 85 + p * 10) / 100, desc=desc) if progress else None
142
+ )
143
  waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]
144
 
145
  if args.extract_instrumental:
 
148
  if 'instrumental' not in instruments:
149
  instruments.append('instrumental')
150
 
151
+ for i, instr in enumerate(instruments):
152
  estimates = waveforms_orig[instr]
153
+ if 'normalize' in config.inference and config.inference.get('normalize', False):
154
  estimates = denormalize_audio(estimates, norm_params)
155
 
156
  is_float = getattr(args, 'export_format', '').startswith('wav FLOAT')
157
  codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
158
+ subtype = get_soundfile_subtype(args.pcm_type, is_float=is_float)
159
 
160
  shortened_filename = shorten_filename(os.path.basename(path))
161
  output_filename = f"{shortened_filename}_{instr}.{codec}"
162
  output_path = os.path.join(store_dir, output_filename)
163
  sf.write(output_path, estimates.T, sr, subtype=subtype)
164
 
165
+ save_progress = round(base_progress + 95 + (i / len(instruments)) * 5)
166
+ if progress is not None and callable(getattr('progress', '__call__', None)):
167
+ progress(save_progress / 100, desc=i18n("saving_output").format(instr, processed_files, total_files))
168
+ update_progress_html(i18n("saving_output").format(instr, processed_files, total_files), save_progress)
169
+
170
+ file_progress = round(processed_files * base_progress_per_file)
171
  if progress is not None and callable(getattr(progress, '__call__', None)):
172
+ progress(file_progress / 100, desc=i18n("completed_file").format(processed_files, total_files))
173
+ update_progress_html(i18n("completed_file").format(processed_files, total_files), file_progress)
174
 
175
  except Exception as e:
176
  logging.error(f"Cannot read track: {path}. Error: {str(e)}")
177
+ print(i18n("cannot_read_track").format(path))
178
+ print(i18n("error_message").format(str(e)))
179
  continue
180
 
181
  elapsed_time = time.time() - start_time
182
+ logging.info(f"Processing time: {elapsed_time:.2f} seconds")
183
+ print(i18n("elapsed_time").format(elapsed_time))
184
 
 
185
  if progress is not None and callable(getattr(progress, '__call__', None)):
186
+ progress(1.0, desc=i18n("processing_complete"))
187
+ update_progress_html(i18n("processing_complete"), 100)
188
 
189
def _build_arg_parser():
    """Create the argument parser that declares every inference option."""
    parser = argparse.ArgumentParser(description=i18n("proc_folder_description"))
    # NOTE(review): 'melod_band_roformer' looks like a typo for
    # 'mel_band_roformer' -- confirm against the model types utils accepts.
    parser.add_argument("--model_type", type=str, default='melod_band_roformer', help=i18n("model_type_help"))
    parser.add_argument("--config_path", type=str, required=True, help=i18n("config_path_help"))
    parser.add_argument("--start_check_point", type=str, required=True, help=i18n("start_checkpoint_help"))
    parser.add_argument("--input_folder", type=str, required=True, help=i18n("input_folder_help"))
    parser.add_argument("--store_dir", type=str, required=True, help=i18n("store_dir_help"))
    parser.add_argument("--chunk_size", type=int, default=352800, help=i18n("chunk_size_help"))
    parser.add_argument("--overlap", type=int, default=2, help=i18n("overlap_help"))
    parser.add_argument("--export_format", type=str, default='wav FLOAT', choices=['wav FLOAT', 'flac PCM_16', 'flac PCM_24'], help=i18n("export_format_help"))
    parser.add_argument("--demud_phaseremix_inst", action='store_true', help=i18n("demud_phaseremix_help"))
    parser.add_argument("--extract_instrumental", action='store_true', help=i18n("extract_instrumental_help"))
    parser.add_argument("--use_tta", action='store_true', help=i18n("use_tta_help"))
    parser.add_argument("--flac_file", action='store_true', help=i18n("flac_file_help"))
    parser.add_argument("--pcm_type", type=str, choices=['PCM_16', 'PCM_24'], default='PCM_24', help=i18n("pcm_type_help"))
    parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help=i18n("device_ids_help"))
    parser.add_argument("--force_cpu", action='store_true', help=i18n("force_cpu_help"))
    parser.add_argument("--lora_checkpoint", type=str, default='', help=i18n("lora_checkpoint_help"))
    return parser


@spaces.GPU
def proc_folder(args=None, progress=None):
    """Parse the options, load the model and run inference over the input folder.

    Args:
        args: Sequence of command-line style strings, or ``None`` to let
            argparse read ``sys.argv[1:]`` (the script entry point).
        progress: Optional progress callback, forwarded unchanged to
            ``run_folder``.

    Returns:
        The string ``"Processing completed"``.

    Raises:
        ValueError: If an explicitly passed ``args`` sequence is rejected by
            argparse, or if building/parsing the options fails otherwise.
        SystemExit: From argparse when reading ``sys.argv`` (usage errors,
            ``--help``), as for any command-line program.
        Exception: Errors from loading the model or checkpoint are logged and
            re-raised unchanged.
    """
    called_from_cli = args is None
    try:
        parser = _build_arg_parser()
        # Passing None makes argparse fall back to sys.argv[1:]; substituting
        # an empty list here would make every required option "missing".
        args = parser.parse_args(args)
    except SystemExit as exc:
        # argparse signals invalid input by exiting, which a plain
        # `except Exception` never sees. Keep that behaviour for real CLI use
        # and for successful exits such as --help, but give programmatic
        # callers an exception they can handle.
        if called_from_cli or exc.code in (0, None):
            raise
        logging.error(f"Argument parsing failed: argparse exited with code {exc.code}")
        raise ValueError(f"Invalid command-line arguments: argparse exited with code {exc.code}") from exc
    except Exception as e:
        logging.error(f"Argument parsing failed: {str(e)}")
        raise ValueError(f"Invalid command-line arguments: {str(e)}") from e

    # Device preference: explicit CPU override, then CUDA, then MPS, else CPU.
    device = "cpu"
    if args.force_cpu:
        logging.info("Forced to use CPU")
    elif torch.cuda.is_available():
        logging.info("CUDA available")
        print(i18n("cuda_available"))
        device = f'cuda:{args.device_ids[0]}'
    elif torch.backends.mps.is_available():
        device = "mps"

    logging.info(f"Using device: {device}")
    print(i18n("using_device").format(device))

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    try:
        model, config = get_model_from_config(args.model_type, args.config_path)
    except Exception as e:
        logging.error(f"Failed to load model: {str(e)}")
        raise

    if args.start_check_point:
        try:
            load_start_checkpoint(args, model, type_='inference')
        except Exception as e:
            logging.error(f"Failed to load checkpoint: {str(e)}")
            raise

    logging.info(f"Instruments: {config.training.instruments}")
    print(i18n("instruments_print").format(config.training.instruments))

    # Wrap for multi-GPU use only when several device ids were requested.
    if len(args.device_ids) > 1 and not args.force_cpu:
        model = nn.DataParallel(model, device_ids=args.device_ids)
        logging.info(f"Using DataParallel with devices: {args.device_ids}")

    model = model.to(device)

    elapsed_time = time.time() - model_load_start_time
    logging.info(f"Model load time: {elapsed_time:.2f} seconds")
    print(i18n("model_load_time").format(elapsed_time))

    run_folder(model, args, config, device, verbose=False, progress=progress)
    return "Processing completed"
259
 
260
if __name__ == "__main__":
    # Script entry point: run with no explicit argument list, record any
    # failure in the log file, then let the exception propagate.
    try:
        proc_folder(None)
    except Exception as exc:
        logging.error(f"Main execution failed: {str(exc)}")
        raise