Spaces:
Running
Running
jhj0517
commited on
Commit
·
da960ac
1
Parent(s):
b14146c
Add `offload()`
Browse files
modules/uvr/music_separator.py
CHANGED
|
@@ -4,6 +4,7 @@ import torchaudio
|
|
| 4 |
import soundfile as sf
|
| 5 |
import os
|
| 6 |
import torch
|
|
|
|
| 7 |
|
| 8 |
from uvr.models import MDX, Demucs, VrNetwork, MDXC
|
| 9 |
|
|
@@ -30,6 +31,14 @@ class MusicSeparator:
|
|
| 30 |
model_name: str = "UVR-MDX-NET-Inst_1",
|
| 31 |
device: Optional[str] = None,
|
| 32 |
segment_size: int = 256):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
if device is None:
|
| 34 |
device = self.device
|
| 35 |
|
|
@@ -61,7 +70,10 @@ class MusicSeparator:
|
|
| 61 |
"split": True
|
| 62 |
}
|
| 63 |
|
| 64 |
-
if self.model is None or
|
|
|
|
|
|
|
|
|
|
| 65 |
self.update_model(
|
| 66 |
model_name=model_name,
|
| 67 |
device=device,
|
|
@@ -84,4 +96,13 @@ class MusicSeparator:
|
|
| 84 |
|
| 85 |
@staticmethod
|
| 86 |
def get_device():
|
| 87 |
-
return "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import soundfile as sf
|
| 5 |
import os
|
| 6 |
import torch
|
| 7 |
+
import gc
|
| 8 |
|
| 9 |
from uvr.models import MDX, Demucs, VrNetwork, MDXC
|
| 10 |
|
|
|
|
| 31 |
model_name: str = "UVR-MDX-NET-Inst_1",
|
| 32 |
device: Optional[str] = None,
|
| 33 |
segment_size: int = 256):
|
| 34 |
+
"""
|
| 35 |
+
Update model with the given model name
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
model_name (str): Model name.
|
| 39 |
+
device (str): Device to use for the model.
|
| 40 |
+
segment_size (int): Segment size for the prediction.
|
| 41 |
+
"""
|
| 42 |
if device is None:
|
| 43 |
device = self.device
|
| 44 |
|
|
|
|
| 70 |
"split": True
|
| 71 |
}
|
| 72 |
|
| 73 |
+
if (self.model is None or
|
| 74 |
+
self.current_model_size != model_name or
|
| 75 |
+
self.model_config != model_config or
|
| 76 |
+
self.audio_info.sample_rate != sample_rate):
|
| 77 |
self.update_model(
|
| 78 |
model_name=model_name,
|
| 79 |
device=device,
|
|
|
|
| 96 |
|
| 97 |
@staticmethod
|
| 98 |
def get_device():
|
| 99 |
+
return "cuda" if torch.cuda.is_available() else "cpu"
|
| 100 |
+
|
| 101 |
+
def offload(self):
|
| 102 |
+
if self.model is not None:
|
| 103 |
+
del self.model
|
| 104 |
+
self.model = None
|
| 105 |
+
if self.device == "cuda":
|
| 106 |
+
torch.cuda.empty_cache()
|
| 107 |
+
gc.collect()
|
| 108 |
+
self.audio_info = None
|