psynote123 committed
Commit e7d22c8 · verified · 1 Parent(s): 5e02d06

Update README.md

Files changed (1): README.md (+6 −6)
README.md CHANGED
@@ -49,6 +49,7 @@ To infer our MusicGen models, you primarily use the `elastic_models.transformers`
```python
import torch
import scipy.io.wavfile
+
from transformers import AutoProcessor
from elastic_models.transformers import MusicgenForConditionalGeneration

@@ -57,7 +58,7 @@ elastic_mode = "S"

prompt = "A groovy funk bassline with a tight drum beat"
output_wav_path = "generated_audio_elastic_S.wav"
- hf_token = "YOUR_HF_TOKEN"
+ hf_token = "YOUR_TOKEN"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = AutoProcessor.from_pretrained(model_name_hf, token=hf_token)
@@ -65,7 +66,7 @@ processor = AutoProcessor.from_pretrained(model_name_hf, token=hf_token)
model = MusicgenForConditionalGeneration.from_pretrained(
    model_name_hf,
    token=hf_token,
-     torch_dtype=torch.float16,
+     torch_dtype=torch.float16,  # Or float32, matching compilation
    mode=elastic_mode,
    device=device,
    __full_patch=True,
@@ -83,17 +84,16 @@ print(f"Generating audio for: {prompt}...")
generate_kwargs = {"do_sample": True, "guidance_scale": 3.0, "max_new_tokens": 256, "cache_implementation": "paged"}

audio_values = model.generate(**inputs, **generate_kwargs)
- audio_values_np = audio_values.cpu().numpy().squeeze()
+ audio_values_np = audio_values.to(torch.float32).cpu().numpy().squeeze()

sampling_rate = model.config.audio_encoder.sampling_rate
scipy.io.wavfile.write(output_wav_path, rate=sampling_rate, data=audio_values_np)
print(f"Audio saved to {output_wav_path}")
-
```

__System requirements:__
- * GPUs: NVIDIA H100, L40S (recommended for compiled models).
- * CPU: AMD, Intel (for running processor/tokenizer, inference on CPU is slow for MusicGen)
+ * GPUs: NVIDIA H100, L40S.
+ * CPU: AMD, Intel
* Python: 3.8-3.11 (check dependencies for specific versions)

To work with our elastic models and compilation tools, you'll need to install `elastic_models` and `qlip` libraries from TheStage:
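The `.to(torch.float32)` cast added in this commit is what keeps the `scipy.io.wavfile.write` call working when the model runs in half precision: WAV files store 32-bit (or 64-bit) IEEE float samples, so a float16 array has to be upcast before it is written. Below is a minimal standalone sketch of that step, not taken from the README; it uses a synthetic tone in place of real MusicGen output, and the 32 kHz rate mirrors MusicGen's audio encoder but is otherwise just an assumption for the example.

```python
import numpy as np
import scipy.io.wavfile

# One second of a 440 Hz sine, stored in half precision to mimic a float16 model output.
sampling_rate = 32000
t = np.arange(sampling_rate) / sampling_rate
tone_fp16 = np.sin(2 * np.pi * 440 * t).astype(np.float16)

# Upcast before writing: 32-bit float is a standard WAV sample format, 16-bit float is not,
# so passing the half-precision array directly would give an error or an unreadable file.
scipy.io.wavfile.write("tone_float32.wav", rate=sampling_rate, data=tone_fp16.astype(np.float32))
```

The same reasoning applies to the README snippet itself: `audio_values.float().cpu().numpy()` would be an equivalent way to express the cast on the PyTorch side.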