Update barks.py
barks.py CHANGED
@@ -1,4 +1,4 @@
-
+
 import os
 import torch
 import torchaudio
@@ -36,8 +36,8 @@ if device != "cuda":
     sys.exit(1)
 print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
 
-# Initialize accelerator
-accelerator = Accelerator(mixed_precision="fp16"
+# Initialize accelerator without cpu_offload
+accelerator = Accelerator(mixed_precision="fp16")
 
 # Pre-run memory cleanup
 def aggressive_memory_cleanup():
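The body of `aggressive_memory_cleanup()` is not shown in this diff; a minimal sketch of the usual pattern, assuming only `gc` and `torch`:

```python
import gc
import torch

def aggressive_memory_cleanup():
    """Drop Python-level references, then return cached CUDA blocks."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # release cached, unreferenced blocks
        torch.cuda.ipc_collect()   # reclaim memory held via CUDA IPC
```

Calling `gc.collect()` before `empty_cache()` matters: CUDA blocks can only be returned once Python has dropped its references to the tensors that own them.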
@@ -51,7 +51,7 @@ aggressive_memory_cleanup()
 
 # 2) LOAD MODELS
 try:
-    print("Loading MusicGen medium model into
+    print("Loading MusicGen medium model into system RAM...")
     local_model_path = "./models/musicgen-medium"
     if not os.path.exists(local_model_path):
         print(f"ERROR: Local model path {local_model_path} does not exist.")
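The loader call itself is not visible in this hunk; a sketch of loading MusicGen from the local path into system RAM, assuming the `transformers` MusicGen API (`torch_dtype` is an assumption, chosen to match the fp16 Accelerator setting):

```python
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

local_model_path = "./models/musicgen-medium"

# Load on CPU ("system RAM") first; move to GPU only around generation.
musicgen_processor = AutoProcessor.from_pretrained(local_model_path)
musicgen_model = MusicgenForConditionalGeneration.from_pretrained(
    local_model_path,
    torch_dtype=torch.float16,  # assumption: fp16 to match the Accelerator setting
).to("cpu")
```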
@@ -92,7 +92,7 @@ def check_vram_availability(required_gb=3.0):  # Lowered threshold
     available_vram = total_vram - allocated_vram
     if available_vram < required_gb:
         print(f"WARNING: Low VRAM available ({available_vram:.2f} GB < {required_gb:.2f} GB required).")
-        print("Reduce total_duration, chunk_duration, or
+        print("Reduce total_duration, chunk_duration, or skip vocals.")
         print(f"Total VRAM: {total_vram:.2f} GB, Available: {available_vram:.2f} GB")
     return available_vram >= required_gb
 
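Only a slice of `check_vram_availability()` appears in the diff; the numbers above are typically derived from `torch.cuda` like this (device index 0 is an assumption):

```python
import torch

def check_vram_availability(required_gb: float = 3.0) -> bool:
    total_vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
    allocated_vram = torch.cuda.memory_allocated(0) / 1024**3
    available_vram = total_vram - allocated_vram
    if available_vram < required_gb:
        print(f"WARNING: Low VRAM available ({available_vram:.2f} GB "
              f"< {required_gb:.2f} GB required).")
        print("Reduce total_duration, chunk_duration, or skip vocals.")
        print(f"Total VRAM: {total_vram:.2f} GB, Available: {available_vram:.2f} GB")
    return available_vram >= required_gb
```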
@@ -274,10 +274,10 @@ def generate_vocals(vocal_prompt: str, total_duration: int):
     try:
         print("Generating vocals with Bark...")
         # Move Bark model to GPU
-        bark_model =
+        bark_model = bark_model.to("cuda")
 
         # Process vocal prompt
-        inputs = bark_processor(vocal_prompt, return_tensors="pt").to(
+        inputs = bark_processor(vocal_prompt, return_tensors="pt").to("cuda")
 
         # Generate vocals with mixed precision
         with torch.no_grad(), autocast():
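A sketch of the Bark flow this hunk completes, assuming the `transformers` Bark API; the checkpoint ID and prompt below are placeholders, not taken from the script:

```python
import torch
from torch.cuda.amp import autocast
from transformers import AutoProcessor, BarkModel

# Placeholder checkpoint and prompt; the script's own paths are not shown here.
bark_processor = AutoProcessor.from_pretrained("suno/bark")
bark_model = BarkModel.from_pretrained("suno/bark", torch_dtype=torch.float16)
vocal_prompt = "An uplifting melodic vocal line"

bark_model = bark_model.to("cuda")
inputs = bark_processor(vocal_prompt, return_tensors="pt").to("cuda")
with torch.no_grad(), autocast():
    audio = bark_model.generate(**inputs)
bark_model.to("cpu")          # hand the VRAM back before MusicGen runs
torch.cuda.empty_cache()
```

Moving the model back to CPU after generation matches the script's strategy of keeping only one model on the GPU at a time.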
@@ -330,7 +330,7 @@ def generate_music(instrumental_prompt: str, vocal_prompt: str, cfg_scale: float
     np.random.seed(seed)
 
     # Move MusicGen to GPU
-    musicgen_model =
+    musicgen_model = musicgen_model.to("cuda")
 
     for i in range(num_chunks):
         chunk_prompt = instrumental_prompt
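A sketch of the chunked loop that follows, under the assumption that the script drives `musicgen_model.generate` with the visible `cfg_scale` and `chunk_duration` values; generating in fixed-length chunks keeps peak VRAM bounded:

```python
# Reuses musicgen_model / musicgen_processor from the loader sketch above;
# num_chunks, chunk_duration, cfg_scale and instrumental_prompt come from
# the script and are assumed here.
musicgen_model = musicgen_model.to("cuda")
chunks = []
for i in range(num_chunks):
    chunk_prompt = instrumental_prompt
    inputs = musicgen_processor(text=[chunk_prompt], return_tensors="pt").to("cuda")
    with torch.no_grad():
        audio = musicgen_model.generate(
            **inputs,
            max_new_tokens=int(chunk_duration * 50),  # MusicGen emits ~50 audio tokens/s
            guidance_scale=cfg_scale,                 # classifier-free guidance strength
        )
    chunks.append(audio.cpu())
    torch.cuda.empty_cache()  # trim cached blocks between chunks
```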
@@ -567,7 +567,7 @@ with gr.Blocks(css=css) as demo:
             maximum=1.0,
             value=0.9,
             step=0.05,
-
+            info="Keeps tokens with cumulative probability above p."
         )
         temperature = gr.Slider(
             label="Temperature 🔥",
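The completed slider, for context; the variable name and label are assumptions, since only the keyword arguments appear in the diff. In Gradio, `info` renders as help text beneath the label:

```python
import gradio as gr

top_p = gr.Slider(          # variable name and label are assumptions
    label="Top-p",
    minimum=0.0,
    maximum=1.0,
    value=0.9,
    step=0.05,
    info="Keeps tokens with cumulative probability above p.",
)
```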
@@ -693,4 +693,3 @@ try:
     fastapi_app.openapi_url = None
 except Exception:
     pass
-```