ghostai1 commited on
Commit
0b63ff0
·
verified ·
1 Parent(s): 6641398

Update barks.py

Browse files
Files changed (1) hide show
  1. barks.py +9 -10
barks.py CHANGED
@@ -1,4 +1,4 @@
1
- ```python
2
  import os
3
  import torch
4
  import torchaudio
@@ -36,8 +36,8 @@ if device != "cuda":
36
  sys.exit(1)
37
  print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
38
 
39
- # Initialize accelerator with enhanced CPU offloading
40
- accelerator = Accelerator(mixed_precision="fp16", cpu_offload=True)
41
 
42
  # Pre-run memory cleanup
43
  def aggressive_memory_cleanup():
@@ -51,7 +51,7 @@ aggressive_memory_cleanup()
51
 
52
  # 2) LOAD MODELS
53
  try:
54
- print("Loading MusicGen medium model into VRAM...")
55
  local_model_path = "./models/musicgen-medium"
56
  if not os.path.exists(local_model_path):
57
  print(f"ERROR: Local model path {local_model_path} does not exist.")
@@ -92,7 +92,7 @@ def check_vram_availability(required_gb=3.0): # Lowered threshold
92
  available_vram = total_vram - allocated_vram
93
  if available_vram < required_gb:
94
  print(f"WARNING: Low VRAM available ({available_vram:.2f} GB < {required_gb:.2f} GB required).")
95
- print("Reduce total_duration, chunk_duration, or enable more CPU offloading.")
96
  print(f"Total VRAM: {total_vram:.2f} GB, Available: {available_vram:.2f} GB")
97
  return available_vram >= required_gb
98
 
@@ -274,10 +274,10 @@ def generate_vocals(vocal_prompt: str, total_duration: int):
274
  try:
275
  print("Generating vocals with Bark...")
276
  # Move Bark model to GPU
277
- bark_model = accelerator.prepare(bark_model)
278
 
279
  # Process vocal prompt
280
- inputs = bark_processor(vocal_prompt, return_tensors="pt").to(accelerator.device)
281
 
282
  # Generate vocals with mixed precision
283
  with torch.no_grad(), autocast():
@@ -330,7 +330,7 @@ def generate_music(instrumental_prompt: str, vocal_prompt: str, cfg_scale: float
330
  np.random.seed(seed)
331
 
332
  # Move MusicGen to GPU
333
- musicgen_model = accelerator.prepare(musicgen_model)
334
 
335
  for i in range(num_chunks):
336
  chunk_prompt = instrumental_prompt
@@ -567,7 +567,7 @@ with gr.Blocks(css=css) as demo:
567
  maximum=1.0,
568
  value=0.9,
569
  step=0.05,
570
- pair_with="Keeps tokens with cumulative probability above p."
571
  )
572
  temperature = gr.Slider(
573
  label="Temperature 🔥",
@@ -693,4 +693,3 @@ try:
693
  fastapi_app.openapi_url = None
694
  except Exception:
695
  pass
696
- ```
 
1
+
2
  import os
3
  import torch
4
  import torchaudio
 
36
  sys.exit(1)
37
  print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
38
 
39
+ # Initialize accelerator without cpu_offload
40
+ accelerator = Accelerator(mixed_precision="fp16")
41
 
42
  # Pre-run memory cleanup
43
  def aggressive_memory_cleanup():
 
51
 
52
  # 2) LOAD MODELS
53
  try:
54
+ print("Loading MusicGen medium model into system RAM...")
55
  local_model_path = "./models/musicgen-medium"
56
  if not os.path.exists(local_model_path):
57
  print(f"ERROR: Local model path {local_model_path} does not exist.")
 
92
  available_vram = total_vram - allocated_vram
93
  if available_vram < required_gb:
94
  print(f"WARNING: Low VRAM available ({available_vram:.2f} GB < {required_gb:.2f} GB required).")
95
+ print("Reduce total_duration, chunk_duration, or skip vocals.")
96
  print(f"Total VRAM: {total_vram:.2f} GB, Available: {available_vram:.2f} GB")
97
  return available_vram >= required_gb
98
 
 
274
  try:
275
  print("Generating vocals with Bark...")
276
  # Move Bark model to GPU
277
+ bark_model = bark_model.to("cuda")
278
 
279
  # Process vocal prompt
280
+ inputs = bark_processor(vocal_prompt, return_tensors="pt").to("cuda")
281
 
282
  # Generate vocals with mixed precision
283
  with torch.no_grad(), autocast():
 
330
  np.random.seed(seed)
331
 
332
  # Move MusicGen to GPU
333
+ musicgen_model = musicgen_model.to("cuda")
334
 
335
  for i in range(num_chunks):
336
  chunk_prompt = instrumental_prompt
 
567
  maximum=1.0,
568
  value=0.9,
569
  step=0.05,
570
+ info="Keeps tokens with cumulative probability above p."
571
  )
572
  temperature = gr.Slider(
573
  label="Temperature 🔥",
 
693
  fastapi_app.openapi_url = None
694
  except Exception:
695
  pass