Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		hainazhu
		
	commited on
		
		
					Commit 
							
							·
						
						c4aaa82
	
1
								Parent(s):
							
							208580f
								
add separator.py
Browse files- separator.py +50 -0
    	
        separator.py
    ADDED
    
    | @@ -0,0 +1,50 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torchaudio
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from third_party.demucs.models.pretrained import get_model_from_yaml
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            class Separator(torch.nn.Module):
         | 
| 8 | 
            +
                def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
         | 
| 9 | 
            +
                    super().__init__()
         | 
| 10 | 
            +
                    if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
         | 
| 11 | 
            +
                        self.device = torch.device(f"cuda:{gpu_id}")
         | 
| 12 | 
            +
                    else:
         | 
| 13 | 
            +
                        self.device = torch.device("cpu")
         | 
| 14 | 
            +
                    self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                def init_demucs_model(self, model_path, config_path):
         | 
| 17 | 
            +
                    model = get_model_from_yaml(config_path, model_path)
         | 
| 18 | 
            +
                    model.to(self.device)
         | 
| 19 | 
            +
                    model.eval()
         | 
| 20 | 
            +
                    return model
         | 
| 21 | 
            +
                
         | 
| 22 | 
            +
                def load_audio(self, f):
         | 
| 23 | 
            +
                    a, fs = torchaudio.load(f)
         | 
| 24 | 
            +
                    if (fs != 48000):
         | 
| 25 | 
            +
                        a = torchaudio.functional.resample(a, fs, 48000)
         | 
| 26 | 
            +
                    if a.shape[-1] >= 48000*10:
         | 
| 27 | 
            +
                        a = a[..., :48000*10]
         | 
| 28 | 
            +
                    else:
         | 
| 29 | 
            +
                        a = torch.cat([a, a], -1)
         | 
| 30 | 
            +
                    return a[:, 0:48000*10]
         | 
| 31 | 
            +
                
         | 
| 32 | 
            +
                def run(self, audio_path, output_dir='tmp', ext=".flac"):
         | 
| 33 | 
            +
                    os.makedirs(output_dir, exist_ok=True)
         | 
| 34 | 
            +
                    name, _ = os.path.splitext(os.path.split(audio_path)[-1])
         | 
| 35 | 
            +
                    output_paths = []
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    for stem in self.demucs_model.sources:
         | 
| 38 | 
            +
                        output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
         | 
| 39 | 
            +
                        if os.path.exists(output_path):
         | 
| 40 | 
            +
                            output_paths.append(output_path)
         | 
| 41 | 
            +
                    if len(output_paths) == 1:  # 4
         | 
| 42 | 
            +
                        vocal_path = output_paths[0]
         | 
| 43 | 
            +
                    else:
         | 
| 44 | 
            +
                        drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
         | 
| 45 | 
            +
                        for path in [drums_path, bass_path, other_path]:
         | 
| 46 | 
            +
                            os.remove(path)
         | 
| 47 | 
            +
                    full_audio = self.load_audio(audio_path)
         | 
| 48 | 
            +
                    vocal_audio = self.load_audio(vocal_path)
         | 
| 49 | 
            +
                    bgm_audio = full_audio - vocal_audio
         | 
| 50 | 
            +
                    return full_audio, vocal_audio, bgm_audio
         | 
 
			
