Update README.md
README.md (changed)
@@ -49,6 +49,7 @@ To infer our MusicGen models, you primarily use the `elastic_models.transformers
 ```python
 import torch
 import scipy.io.wavfile
+
 from transformers import AutoProcessor
 from elastic_models.transformers import MusicgenForConditionalGeneration
 
@@ -57,7 +58,7 @@ elastic_mode = "S"
 
 prompt = "A groovy funk bassline with a tight drum beat"
 output_wav_path = "generated_audio_elastic_S.wav"
-hf_token = "
+hf_token = "YOUR_TOKEN"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 processor = AutoProcessor.from_pretrained(model_name_hf, token=hf_token)
@@ -65,7 +66,7 @@ processor = AutoProcessor.from_pretrained(model_name_hf, token=hf_token)
 model = MusicgenForConditionalGeneration.from_pretrained(
     model_name_hf,
     token=hf_token,
-    torch_dtype=torch.float16,
+    torch_dtype=torch.float16, # Or float32, matching compilation
     mode=elastic_mode,
     device=device,
     __full_patch=True,
@@ -83,17 +84,16 @@ print(f"Generating audio for: {prompt}...")
 generate_kwargs = {"do_sample": True, "guidance_scale": 3.0, "max_new_tokens": 256, "cache_implementation": "paged"}
 
 audio_values = model.generate(**inputs, **generate_kwargs)
-audio_values_np = audio_values.cpu().numpy().squeeze()
+audio_values_np = audio_values.to(torch.float32).cpu().numpy().squeeze()
 
 sampling_rate = model.config.audio_encoder.sampling_rate
 scipy.io.wavfile.write(output_wav_path, rate=sampling_rate, data=audio_values_np)
 print(f"Audio saved to {output_wav_path}")
-
 ```
 
 __System requirements:__
-* GPUs: NVIDIA H100, L40S
-* CPU: AMD, Intel
+* GPUs: NVIDIA H100, L40S.
+* CPU: AMD, Intel
 * Python: 3.8-3.11 (check dependencies for specific versions)
 
 To work with our elastic models and compilation tools, you'll need to install `elastic_models` and `qlip` libraries from TheStage:
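A quick aside on the `hf_token = "YOUR_TOKEN"` placeholder above: in real use the access token is usually read from the environment rather than hardcoded in a script. A minimal sketch, assuming the conventional `HF_TOKEN` variable name (an assumption, not something this README specifies):

```python
import os

# Read the Hugging Face access token from the environment instead of
# hardcoding it in the script; "HF_TOKEN" is a conventional name.
hf_token = os.environ.get("HF_TOKEN")
if hf_token is None:
    raise RuntimeError("Set the HF_TOKEN environment variable to your Hugging Face access token")
```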
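The `.to(torch.float32)` cast before `.numpy()` matters because the model runs in `torch.float16`, and `scipy.io.wavfile.write` does not accept float16 arrays. A self-contained sketch of that save path, with a random tensor standing in for the `model.generate` output (the tensor shape and the 32000 Hz rate are illustrative assumptions, not values taken from the model config):

```python
import numpy as np
import scipy.io.wavfile
import torch

# Stand-in for model.generate() output: (batch, channels, samples) in
# float16. Shape and sampling rate are assumptions for illustration.
audio_values = torch.randn(1, 1, 32000).to(torch.float16)

# scipy.io.wavfile.write handles float32 data but rejects float16, so
# upcast before moving to NumPy and squeezing the singleton dimensions.
audio_np = audio_values.to(torch.float32).cpu().numpy().squeeze()

# Float WAV samples are conventionally in [-1.0, 1.0]; clip to be safe.
audio_np = np.clip(audio_np, -1.0, 1.0)

scipy.io.wavfile.write("example.wav", rate=32000, data=audio_np)
```

Writing float32 data produces an IEEE-float WAV; if a consumer needs 16-bit PCM instead, converting with `(audio_np * 32767).astype(np.int16)` before writing is the usual alternative.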