Spaces:
Running
on
Zero
Running
on
Zero
Update inference.py
Browse files- inference.py +129 -78
inference.py
CHANGED
@@ -3,17 +3,21 @@ __author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
|
|
3 |
|
4 |
import argparse
|
5 |
import time
|
|
|
6 |
import librosa
|
7 |
-
from tqdm.auto import tqdm
|
8 |
import sys
|
9 |
import os
|
10 |
import glob
|
11 |
import torch
|
12 |
-
import soundfile as sf
|
13 |
import torch.nn as nn
|
14 |
import numpy as np
|
15 |
-
|
16 |
import spaces
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
# Colab check
|
19 |
try:
|
@@ -22,19 +26,22 @@ try:
|
|
22 |
except ImportError:
|
23 |
IS_COLAB = False
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
i18n = I18nAuto()
|
26 |
|
27 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
28 |
sys.path.append(current_dir)
|
29 |
|
30 |
from utils import demix, get_model_from_config, normalize_audio, denormalize_audio
|
31 |
-
from utils import prefer_target_instrument, apply_tta, load_start_checkpoint
|
32 |
-
|
33 |
-
import warnings
|
34 |
-
warnings.filterwarnings("ignore")
|
35 |
|
36 |
def shorten_filename(filename, max_length=30):
|
37 |
-
"""Dosya adını belirtilen maksimum uzunluğa kısaltır."""
|
38 |
base, ext = os.path.splitext(filename)
|
39 |
if len(base) <= max_length:
|
40 |
return filename
|
@@ -42,16 +49,22 @@ def shorten_filename(filename, max_length=30):
|
|
42 |
return shortened
|
43 |
|
44 |
def get_soundfile_subtype(pcm_type, is_float=False):
|
45 |
-
|
46 |
-
if is_float:
|
47 |
return 'FLOAT'
|
48 |
-
subtype_map = {
|
49 |
-
'PCM_16': 'PCM_16',
|
50 |
-
'PCM_24': 'PCM_24',
|
51 |
-
'FLOAT': 'FLOAT'
|
52 |
-
}
|
53 |
return subtype_map.get(pcm_type, 'FLOAT')
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
def run_folder(model, args, config, device, verbose: bool = False, progress=None):
|
56 |
start_time = time.time()
|
57 |
model.eval()
|
@@ -60,7 +73,7 @@ def run_folder(model, args, config, device, verbose: bool = False, progress=None
|
|
60 |
sample_rate = getattr(config.audio, 'sample_rate', 44100)
|
61 |
|
62 |
logging.info(f"Total files found: {len(mixture_paths)} with sample rate: {sample_rate}")
|
63 |
-
print(
|
64 |
|
65 |
instruments = prefer_target_instrument(config)[:]
|
66 |
store_dir = args.store_dir
|
@@ -68,49 +81,65 @@ def run_folder(model, args, config, device, verbose: bool = False, progress=None
|
|
68 |
|
69 |
total_files = len(mixture_paths)
|
70 |
processed_files = 0
|
|
|
71 |
|
72 |
for path in mixture_paths:
|
73 |
try:
|
74 |
mix, sr = librosa.load(path, sr=sample_rate, mono=False)
|
75 |
logging.info(f"Loaded audio: {path}, shape: {mix.shape}")
|
76 |
-
print(
|
77 |
|
78 |
-
# Dosya ilerlemesi için başlangıç güncellemesi
|
79 |
processed_files += 1
|
80 |
-
base_progress = round((
|
81 |
if progress is not None and callable(getattr(progress, '__call__', None)):
|
82 |
-
progress(base_progress / 100, desc=
|
83 |
-
update_progress_html(
|
84 |
|
85 |
mix_orig = mix.copy()
|
86 |
-
if 'normalize' in config.inference and config.inference
|
87 |
mix, norm_params = normalize_audio(mix)
|
88 |
|
89 |
-
|
90 |
-
|
|
|
|
|
91 |
|
92 |
if args.use_tta:
|
93 |
-
|
94 |
-
|
|
|
|
|
95 |
|
96 |
if args.demud_phaseremix_inst:
|
97 |
logging.info(f"Demudding track: {path}")
|
98 |
-
print(
|
99 |
instr = 'vocals' if 'vocals' in instruments else instruments[0]
|
100 |
instruments.append('instrumental_phaseremix')
|
101 |
if 'instrumental' not in instruments and 'Instrumental' not in instruments:
|
102 |
mix_modified = mix_orig - 2 * waveforms_orig[instr]
|
103 |
mix_modified_ = mix_modified.copy()
|
104 |
-
waveforms_modified = demix(
|
|
|
|
|
|
|
105 |
if args.use_tta:
|
106 |
-
waveforms_modified = apply_tta(
|
|
|
|
|
|
|
107 |
waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]
|
108 |
else:
|
109 |
mix_modified = 2 * waveforms_orig[instr] - mix_orig
|
110 |
mix_modified_ = mix_modified.copy()
|
111 |
-
waveforms_modified = demix(
|
|
|
|
|
|
|
112 |
if args.use_tta:
|
113 |
-
waveforms_modified = apply_tta(
|
|
|
|
|
|
|
114 |
waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]
|
115 |
|
116 |
if args.extract_instrumental:
|
@@ -119,96 +148,118 @@ def run_folder(model, args, config, device, verbose: bool = False, progress=None
|
|
119 |
if 'instrumental' not in instruments:
|
120 |
instruments.append('instrumental')
|
121 |
|
122 |
-
for instr in instruments:
|
123 |
estimates = waveforms_orig[instr]
|
124 |
-
if 'normalize' in config.inference and config.inference
|
125 |
estimates = denormalize_audio(estimates, norm_params)
|
126 |
|
127 |
is_float = getattr(args, 'export_format', '').startswith('wav FLOAT')
|
128 |
codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
|
129 |
-
subtype = get_soundfile_subtype(args.pcm_type, is_float
|
130 |
|
131 |
shortened_filename = shorten_filename(os.path.basename(path))
|
132 |
output_filename = f"{shortened_filename}_{instr}.{codec}"
|
133 |
output_path = os.path.join(store_dir, output_filename)
|
134 |
sf.write(output_path, estimates.T, sr, subtype=subtype)
|
135 |
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
138 |
if progress is not None and callable(getattr(progress, '__call__', None)):
|
139 |
-
progress(file_progress / 100, desc=
|
140 |
-
update_progress_html(
|
141 |
|
142 |
except Exception as e:
|
143 |
logging.error(f"Cannot read track: {path}. Error: {str(e)}")
|
144 |
-
print(
|
|
|
145 |
continue
|
146 |
|
147 |
elapsed_time = time.time() - start_time
|
148 |
-
logging.info(f"
|
149 |
-
print(
|
150 |
|
151 |
-
# Tüm işlem tamamlandı
|
152 |
if progress is not None and callable(getattr(progress, '__call__', None)):
|
153 |
-
progress(1.0, desc="
|
154 |
-
update_progress_html("
|
155 |
|
156 |
@spaces.GPU
|
157 |
-
def proc_folder(args):
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
args = parser.parse_args(args)
|
182 |
|
183 |
device = "cpu"
|
184 |
if args.force_cpu:
|
185 |
-
|
186 |
elif torch.cuda.is_available():
|
|
|
187 |
print(i18n("cuda_available"))
|
188 |
-
device = f'cuda:{args.device_ids[0]}'
|
189 |
elif torch.backends.mps.is_available():
|
190 |
device = "mps"
|
191 |
|
|
|
192 |
print(i18n("using_device").format(device))
|
193 |
|
194 |
model_load_start_time = time.time()
|
195 |
torch.backends.cudnn.benchmark = True
|
196 |
|
197 |
-
|
|
|
|
|
|
|
|
|
198 |
|
199 |
-
if args.start_check_point
|
200 |
-
|
|
|
|
|
|
|
|
|
201 |
|
|
|
202 |
print(i18n("instruments_print").format(config.training.instruments))
|
203 |
|
204 |
-
if
|
205 |
model = nn.DataParallel(model, device_ids=args.device_ids)
|
|
|
206 |
|
207 |
model = model.to(device)
|
208 |
|
209 |
-
|
|
|
|
|
210 |
|
211 |
-
run_folder(model, args, config, device, verbose=False)
|
|
|
212 |
|
213 |
if __name__ == "__main__":
|
214 |
-
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
import argparse
|
5 |
import time
|
6 |
+
import logging
|
7 |
import librosa
|
|
|
8 |
import sys
|
9 |
import os
|
10 |
import glob
|
11 |
import torch
|
|
|
12 |
import torch.nn as nn
|
13 |
import numpy as np
|
14 |
+
import soundfile as sf
|
15 |
import spaces
|
16 |
+
import warnings
|
17 |
+
warnings.filterwarnings("ignore")
|
18 |
+
|
19 |
+
# Logging configuration
|
20 |
+
logging.basicConfig(level=logging.DEBUG, filename='utils.log', format='%(asctime)s - %(levelname)s - %(message)s')
|
21 |
|
22 |
# Colab check
|
23 |
try:
|
|
|
26 |
except ImportError:
|
27 |
IS_COLAB = False
|
28 |
|
29 |
+
# i18n placeholder
|
30 |
+
class I18nAuto:
    """Pass-through stand-in for a translation lookup object.

    Calling an instance returns the message key it was given, unchanged.
    Call sites written as ``i18n("key").format(...)`` therefore format the
    key string itself; no translation table is consulted.
    """

    def __call__(self, message):
        """Return ``message`` unchanged."""
        return message

    def format(self, message, *args):
        """Return ``message`` formatted with the positional ``args``."""
        return message.format(*args)
|
35 |
+
|
36 |
i18n = I18nAuto()
|
37 |
|
38 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
39 |
sys.path.append(current_dir)
|
40 |
|
41 |
from utils import demix, get_model_from_config, normalize_audio, denormalize_audio
|
42 |
+
from utils import prefer_target_instrument, apply_tta, load_start_checkpoint
|
|
|
|
|
|
|
43 |
|
44 |
def shorten_filename(filename, max_length=30):
|
|
|
45 |
base, ext = os.path.splitext(filename)
|
46 |
if len(base) <= max_length:
|
47 |
return filename
|
|
|
49 |
return shortened
|
50 |
|
51 |
def get_soundfile_subtype(pcm_type, is_float=False):
    """Choose the subtype string passed to ``sf.write``.

    Args:
        pcm_type: Requested subtype name, e.g. ``'PCM_16'`` or ``'PCM_24'``.
        is_float: When true, ``'FLOAT'`` is returned regardless of
            ``pcm_type``.

    Returns:
        ``'PCM_16'`` or ``'PCM_24'`` when that value was requested and
        ``is_float`` is false; ``'FLOAT'`` in every other case, including
        unrecognised or ``None`` values of ``pcm_type``.
    """
    if is_float or pcm_type == 'FLOAT':
        return 'FLOAT'
    # Only the two integer PCM subtypes pass through; anything else falls
    # back to FLOAT.
    if pcm_type in ('PCM_16', 'PCM_24'):
        return pcm_type
    return 'FLOAT'
|
56 |
|
57 |
+
def update_progress_html(progress_label, progress_percent):
    """Build the HTML snippet for the custom progress bar.

    Args:
        progress_label: Text shown above the bar. It is interpolated into
            the markup as-is (no HTML escaping is applied).
        progress_percent: Progress value; it is rounded to the nearest
            integer and clamped to the range 0..100 before being used as the
            bar width.

    Returns:
        A string of HTML containing the label and a bar whose width is the
        clamped percentage.
    """
    # Clamp so out-of-range values can never produce an invalid bar width.
    progress_percent = min(max(round(progress_percent), 0), 100)
    return f"""
    <div id="custom-progress" style="margin-top: 10px;">
        <div style="font-size: 1rem; color: #C0C0C0; margin-bottom: 5px;" id="progress-label">{progress_label}</div>
        <div style="width: 100%; background-color: #444; border-radius: 5px; overflow: hidden;">
            <div id="progress-bar" style="width: {progress_percent}%; height: 20px; background-color: #6e8efb; transition: width 0.3s; max-width: 100%;"></div>
        </div>
    </div>
    """
|
67 |
+
|
68 |
def run_folder(model, args, config, device, verbose: bool = False, progress=None):
|
69 |
start_time = time.time()
|
70 |
model.eval()
|
|
|
73 |
sample_rate = getattr(config.audio, 'sample_rate', 44100)
|
74 |
|
75 |
logging.info(f"Total files found: {len(mixture_paths)} with sample rate: {sample_rate}")
|
76 |
+
print(i18n("total_files_found").format(len(mixture_paths), sample_rate))
|
77 |
|
78 |
instruments = prefer_target_instrument(config)[:]
|
79 |
store_dir = args.store_dir
|
|
|
81 |
|
82 |
total_files = len(mixture_paths)
|
83 |
processed_files = 0
|
84 |
+
base_progress_per_file = 100 / total_files if total_files > 0 else 100
|
85 |
|
86 |
for path in mixture_paths:
|
87 |
try:
|
88 |
mix, sr = librosa.load(path, sr=sample_rate, mono=False)
|
89 |
logging.info(f"Loaded audio: {path}, shape: {mix.shape}")
|
90 |
+
print(i18n("loaded_audio").format(path, mix.shape))
|
91 |
|
|
|
92 |
processed_files += 1
|
93 |
+
base_progress = round((processed_files - 1) * base_progress_per_file)
|
94 |
if progress is not None and callable(getattr(progress, '__call__', None)):
|
95 |
+
progress(base_progress / 100, desc=i18n("processing_file").format(processed_files, total_files))
|
96 |
+
update_progress_html(i18n("processing_file").format(processed_files, total_files), base_progress)
|
97 |
|
98 |
mix_orig = mix.copy()
|
99 |
+
if 'normalize' in config.inference and config.inference.get('normalize', False):
|
100 |
mix, norm_params = normalize_audio(mix)
|
101 |
|
102 |
+
waveforms_orig = demix(
|
103 |
+
config, model, mix, device, model_type=args.model_type, pbar=False,
|
104 |
+
progress=lambda p, desc: progress((base_progress + p * 50) / 100, desc=desc) if progress else None
|
105 |
+
)
|
106 |
|
107 |
if args.use_tta:
|
108 |
+
waveforms_orig = apply_tta(
|
109 |
+
config, model, mix, waveforms_orig, device, args.model_type,
|
110 |
+
progress=lambda p, desc: progress((base_progress + 50 + p * 20) / 100, desc=desc) if progress else None
|
111 |
+
)
|
112 |
|
113 |
if args.demud_phaseremix_inst:
|
114 |
logging.info(f"Demudding track: {path}")
|
115 |
+
print(i18n("demudding_track").format(path))
|
116 |
instr = 'vocals' if 'vocals' in instruments else instruments[0]
|
117 |
instruments.append('instrumental_phaseremix')
|
118 |
if 'instrumental' not in instruments and 'Instrumental' not in instruments:
|
119 |
mix_modified = mix_orig - 2 * waveforms_orig[instr]
|
120 |
mix_modified_ = mix_modified.copy()
|
121 |
+
waveforms_modified = demix(
|
122 |
+
config, model, mix_modified, device, model_type=args.model_type, pbar=False,
|
123 |
+
progress=lambda p, desc: progress((base_progress + 70 + p * 15) / 100, desc=desc) if progress else None
|
124 |
+
)
|
125 |
if args.use_tta:
|
126 |
+
waveforms_modified = apply_tta(
|
127 |
+
config, model, mix_modified, waveforms_modified, device, args.model_type,
|
128 |
+
progress=lambda p, desc: progress((base_progress + 85 + p * 10) / 100, desc=desc) if progress else None
|
129 |
+
)
|
130 |
waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]
|
131 |
else:
|
132 |
mix_modified = 2 * waveforms_orig[instr] - mix_orig
|
133 |
mix_modified_ = mix_modified.copy()
|
134 |
+
waveforms_modified = demix(
|
135 |
+
config, model, mix_modified, device, model_type=args.model_type, pbar=False,
|
136 |
+
progress=lambda p, desc: progress((base_progress + 70 + p * 15) / 100, desc=desc) if progress else None
|
137 |
+
)
|
138 |
if args.use_tta:
|
139 |
+
waveforms_modified = apply_tta(
|
140 |
+
config, model, mix_modified, waveforms_orig, device, args.model_type,
|
141 |
+
progress=lambda p, desc: progress((base_progress + 85 + p * 10) / 100, desc=desc) if progress else None
|
142 |
+
)
|
143 |
waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]
|
144 |
|
145 |
if args.extract_instrumental:
|
|
|
148 |
if 'instrumental' not in instruments:
|
149 |
instruments.append('instrumental')
|
150 |
|
151 |
+
for i, instr in enumerate(instruments):
|
152 |
estimates = waveforms_orig[instr]
|
153 |
+
if 'normalize' in config.inference and config.inference.get('normalize', False):
|
154 |
estimates = denormalize_audio(estimates, norm_params)
|
155 |
|
156 |
is_float = getattr(args, 'export_format', '').startswith('wav FLOAT')
|
157 |
codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
|
158 |
+
subtype = get_soundfile_subtype(args.pcm_type, is_float=is_float)
|
159 |
|
160 |
shortened_filename = shorten_filename(os.path.basename(path))
|
161 |
output_filename = f"{shortened_filename}_{instr}.{codec}"
|
162 |
output_path = os.path.join(store_dir, output_filename)
|
163 |
sf.write(output_path, estimates.T, sr, subtype=subtype)
|
164 |
|
165 |
+
save_progress = round(base_progress + 95 + (i / len(instruments)) * 5)
|
166 |
+
if progress is not None and callable(getattr('progress', '__call__', None)):
|
167 |
+
progress(save_progress / 100, desc=i18n("saving_output").format(instr, processed_files, total_files))
|
168 |
+
update_progress_html(i18n("saving_output").format(instr, processed_files, total_files), save_progress)
|
169 |
+
|
170 |
+
file_progress = round(processed_files * base_progress_per_file)
|
171 |
if progress is not None and callable(getattr(progress, '__call__', None)):
|
172 |
+
progress(file_progress / 100, desc=i18n("completed_file").format(processed_files, total_files))
|
173 |
+
update_progress_html(i18n("completed_file").format(processed_files, total_files), file_progress)
|
174 |
|
175 |
except Exception as e:
|
176 |
logging.error(f"Cannot read track: {path}. Error: {str(e)}")
|
177 |
+
print(i18n("cannot_read_track").format(path))
|
178 |
+
print(i18n("error_message").format(str(e)))
|
179 |
continue
|
180 |
|
181 |
elapsed_time = time.time() - start_time
|
182 |
+
logging.info(f"Processing time: {elapsed_time:.2f} seconds")
|
183 |
+
print(i18n("elapsed_time").format(elapsed_time))
|
184 |
|
|
|
185 |
if progress is not None and callable(getattr(progress, '__call__', None)):
|
186 |
+
progress(1.0, desc=i18n("processing_complete"))
|
187 |
+
update_progress_html(i18n("processing_complete"), 100)
|
188 |
|
189 |
@spaces.GPU
def proc_folder(args=None, progress=None):
    """Parse options, load the model and run separation over a folder.

    Args:
        args: List of command-line style tokens to parse. When ``None``,
            argparse reads the process arguments (``sys.argv[1:]``), which is
            what the ``__main__`` entry point relies on.
        progress: Optional progress callback; it is forwarded unchanged to
            ``run_folder``.

    Returns:
        The string ``"Processing completed"`` once ``run_folder`` returns.

    Raises:
        ValueError: If building the parser or parsing raises an ``Exception``.
        Exception: Errors from ``get_model_from_config`` or
            ``load_start_checkpoint`` are logged and re-raised unchanged.
    """
    try:
        parser = argparse.ArgumentParser(description=i18n("proc_folder_description"))
        # NOTE(review): 'melod_band_roformer' looks like it may be a typo of
        # another model type name — confirm against get_model_from_config.
        parser.add_argument("--model_type", type=str, default='melod_band_roformer', help=i18n("model_type_help"))
        parser.add_argument("--config_path", type=str, required=True, help=i18n("config_path_help"))
        parser.add_argument("--start_check_point", type=str, required=True, help=i18n("start_checkpoint_help"))
        parser.add_argument("--input_folder", type=str, required=True, help=i18n("input_folder_help"))
        parser.add_argument("--store_dir", type=str, required=True, help=i18n("store_dir_help"))
        parser.add_argument("--chunk_size", type=int, default=352800, help=i18n("chunk_size_help"))
        parser.add_argument("--overlap", type=int, default=2, help=i18n("overlap_help"))
        parser.add_argument("--export_format", type=str, default='wav FLOAT', choices=['wav FLOAT', 'flac PCM_16', 'flac PCM_24'], help=i18n("export_format_help"))
        parser.add_argument("--demud_phaseremix_inst", action='store_true', help=i18n("demud_phaseremix_help"))
        parser.add_argument("--extract_instrumental", action='store_true', help=i18n("extract_instrumental_help"))
        parser.add_argument("--use_tta", action='store_true', help=i18n("use_tta_help"))
        parser.add_argument("--flac_file", action='store_true', help=i18n("flac_file_help"))
        parser.add_argument("--pcm_type", type=str, choices=['PCM_16', 'PCM_24'], default='PCM_24', help=i18n("pcm_type_help"))
        parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help=i18n("device_ids_help"))
        parser.add_argument("--force_cpu", action='store_true', help=i18n("force_cpu_help"))
        parser.add_argument("--lora_checkpoint", type=str, default='', help=i18n("lora_checkpoint_help"))

        # Pass None straight through so argparse falls back to sys.argv[1:];
        # substituting [] would make the required options always fail when
        # the module is run as a script.
        args = parser.parse_args(args)
    except Exception as e:
        # NOTE(review): argparse reports invalid arguments by raising
        # SystemExit, which is not an Exception subclass, so this handler
        # does not intercept ordinary usage errors.
        logging.error(f"Argument parsing failed: {str(e)}")
        raise ValueError(f"Invalid command-line arguments: {str(e)}") from e

    # Device selection: CPU by default, then CUDA, then MPS.
    device = "cpu"
    if args.force_cpu:
        logging.info("Forced to use CPU")
    elif torch.cuda.is_available():
        logging.info("CUDA available")
        print(i18n("cuda_available"))
        device = f'cuda:{args.device_ids[0]}'
    elif torch.backends.mps.is_available():
        device = "mps"

    logging.info(f"Using device: {device}")
    print(i18n("using_device").format(device))

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    try:
        model, config = get_model_from_config(args.model_type, args.config_path)
    except Exception as e:
        logging.error(f"Failed to load model: {str(e)}")
        raise

    if args.start_check_point:
        try:
            load_start_checkpoint(args, model, type_='inference')
        except Exception as e:
            logging.error(f"Failed to load checkpoint: {str(e)}")
            raise

    logging.info(f"Instruments: {config.training.instruments}")
    print(i18n("instruments_print").format(config.training.instruments))

    # Wrap for multi-GPU only when several device ids were requested and the
    # CPU was not forced.
    if len(args.device_ids) > 1 and not args.force_cpu:
        model = nn.DataParallel(model, device_ids=args.device_ids)
        logging.info(f"Using DataParallel with devices: {args.device_ids}")

    model = model.to(device)

    elapsed_time = time.time() - model_load_start_time
    logging.info(f"Model load time: {elapsed_time:.2f} seconds")
    print(i18n("model_load_time").format(elapsed_time))

    run_folder(model, args, config, device, verbose=False, progress=progress)
    return "Processing completed"
|
259 |
|
260 |
# Script entry point: run the folder pipeline using the process arguments,
# logging any failure before letting it propagate.
if __name__ == "__main__":
    try:
        proc_folder(None)
    except Exception as e:
        logging.error(f"Main execution failed: {str(e)}")
        raise
|