speed committed on
Commit d41e8a4 · verified · 1 Parent(s): 7e91f03

Update README.md
Files changed (1): README.md (+175 -1)

# Llama-Mimi-1.3B

[📃Paper](https://arxiv.org/abs/2509.14882) | [🧑‍💻Code](https://github.com/llm-jp/llama-mimi) | [🗣️Demo](https://speed1313.github.io/llama-mimi/)

<img src="https://speed1313.github.io/llama-mimi/data/llama-mimi.svg" width="50%"/>

## Introduction

Llama-Mimi is a speech language model that uses a unified tokenizer (Mimi) and a single Transformer decoder (Llama) to jointly model sequences of interleaved semantic and acoustic tokens.
Trained on ~240k hours of English audio, Llama-Mimi achieves state-of-the-art performance in acoustic consistency on [SALMon](https://arxiv.org/abs/2409.07437) and effectively preserves speaker identity.
Visit our [demo site](https://speed1313.github.io/llama-mimi/) to hear generated speech samples.
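
For reference, the usage example below serializes the Mimi codes of each frame into text tokens of the form `<value_codebookIndex>` (the first Mimi codebook carries the semantic token, the remaining codebooks are acoustic), wrapped in `<audio>...</audio>` tags. Here is a minimal sketch of that serialization; four quantizers per frame is an assumption for illustration, while the actual example reads the value from `model.config.num_quantizers`:

```python
# Illustration only: how per-frame Mimi codes map to Llama-Mimi's token string.
# Mirrors audio_array_to_text in the usage example below.
num_quantizers = 4  # assumed; the usage example reads model.config.num_quantizers
frames = [           # Mimi codes for two 12.5 Hz frames (values are made up)
    [512, 33, 901, 7],
    [480, 12, 755, 64],
]
assert all(len(frame) == num_quantizers for frame in frames)
tokens = "".join(
    f"<{code}_{level}>" for frame in frames for level, code in enumerate(frame)
)
print(f"<audio>{tokens}</audio>")
# <audio><512_0><33_1><901_2><7_3><480_0><12_1><755_2><64_3></audio>
```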

## Models

- [Llama-Mimi-1.3B](https://huggingface.co/llm-jp/Llama-Mimi-1.3b)
- [Llama-Mimi-8B](https://huggingface.co/llm-jp/Llama-Mimi-8b)

## How to Use

The example below generates an audio continuation from a given audio prompt: the prompt waveform is encoded into Mimi tokens, Llama-Mimi continues the interleaved token sequence, and the generated tokens are decoded back into a waveform.

```python
import random
import re

import numpy as np
import soundfile as sf
import torch
import torchaudio
from transformers import (
    AutoFeatureExtractor,
    AutoModelForCausalLM,
    AutoTokenizer,
    MimiModel,
    StoppingCriteria,
)


def text_to_audio_values(
    text: str,
    num_quantizers: int,
    output_file: str,
    audio_tokenizer,
    feature_extractor,
):
    """Decode a string of <value_codebook> tokens back into a waveform file."""
    # Extract (value, codebook index) pairs from the <val_idx> format in the text.
    matches = re.findall(r"<(\d+)_(\d+)>", text)
    vals = []

    # Keep only complete frames whose codebook indices appear in order 0..Q-1.
    for i in range(0, len(matches), num_quantizers):
        chunk = matches[i : i + num_quantizers]
        if len(chunk) < num_quantizers:
            break
        indices = [int(idx) for _, idx in chunk]
        if indices == list(range(num_quantizers)):
            vals.extend(int(val) for val, _ in chunk)
        else:
            break

    vals = vals[: len(vals) - len(vals) % num_quantizers]
    tensor_btq = torch.tensor(vals).reshape(1, -1, num_quantizers)  # (B, T, Q)
    tensor_bqt = tensor_btq.transpose(1, 2)  # (B, Q, T), as expected by Mimi

    audio_values = audio_tokenizer.decode(tensor_bqt)[0]

    sf.write(
        output_file,
        audio_values[0][0].detach().cpu().numpy(),
        feature_extractor.sampling_rate,
    )


def audio_array_to_text(
    audio_array: np.ndarray,
    audio_tokenizer,
    feature_extractor,
    num_quantizers: int,
    max_seconds: int = 20,
) -> str:
    """Encode a waveform into a flattened <value_codebook> token string."""
    # Truncate the audio to at most max_seconds.
    if audio_array.shape[-1] > max_seconds * feature_extractor.sampling_rate:
        audio_array = audio_array[: max_seconds * feature_extractor.sampling_rate]

    inputs = feature_extractor(
        raw_audio=audio_array,
        sampling_rate=feature_extractor.sampling_rate,
        return_tensors="pt",
    ).to(audio_tokenizer.device)
    with torch.no_grad():
        encoder_outputs = audio_tokenizer.encode(
            inputs["input_values"],
            inputs["padding_mask"],
            num_quantizers=num_quantizers,
        )
    # (B, Q, T) -> (B, T, Q) -> flat sequence of codes, frame by frame.
    flatten_audio_codes = encoder_outputs.audio_codes.transpose(1, 2).reshape(-1)
    assert flatten_audio_codes.numel() % num_quantizers == 0

    parts = [
        f"<{flatten_audio_codes[i + j].item()}_{j}>"
        for i in range(0, flatten_audio_codes.numel(), num_quantizers)
        for j in range(num_quantizers)
    ]
    text = "".join(parts)

    del inputs, encoder_outputs, flatten_audio_codes
    torch.cuda.empty_cache()
    return f"<audio>{text}</audio>"


def set_determinism(seed: int = 42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


class StopOnAudioEnd(StoppingCriteria):
    """Stop generation once the </audio> token sequence has been produced."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.target_text = "</audio>"
        self.target_ids = tokenizer(
            self.target_text, add_special_tokens=False
        ).input_ids

    def __call__(self, input_ids, scores, **kwargs):
        if len(input_ids[0]) < len(self.target_ids):
            return False
        return input_ids[0][-len(self.target_ids) :].tolist() == self.target_ids


set_determinism()

# Sampling configuration.
temperature = 0.8
top_k = 30
do_sample = True
max_length = 1024
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the speech language model and the Mimi audio tokenizer.
model_id = "llm-jp/Llama-Mimi-1.3B"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).eval().to(device)
num_quantizers = model.config.num_quantizers
tokenizer = AutoTokenizer.from_pretrained(model_id)
audio_tokenizer = MimiModel.from_pretrained("kyutai/mimi")
feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")
stopping_criteria = StopOnAudioEnd(tokenizer)

# Load the audio prompt and resample it to Mimi's sampling rate (24 kHz).
audio_file = "assets/great_day_gt.wav"
waveform, sample_rate = torchaudio.load(audio_file)
if sample_rate != feature_extractor.sampling_rate:
    waveform = torchaudio.transforms.Resample(sample_rate, feature_extractor.sampling_rate)(waveform)
    sample_rate = feature_extractor.sampling_rate
prompt_array = waveform.squeeze().cpu().numpy()

# Tokenize the prompt audio and drop the closing tag so the model continues it.
text = audio_array_to_text(
    prompt_array, audio_tokenizer, feature_extractor, num_quantizers
)
text = text.replace("</audio>", "")
inputs = tokenizer(text, return_tensors="pt").to(device)

# Generate the audio-token continuation.
with torch.no_grad():
    generated = model.generate(
        **inputs,
        max_length=max_length,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        stopping_criteria=[stopping_criteria],
    )

generated_text = tokenizer.decode(generated[0])

# Decode prompt + continuation tokens back into a waveform.
text_to_audio_values(
    generated_text,
    num_quantizers=num_quantizers,
    output_file="output.wav",
    audio_tokenizer=audio_tokenizer,
    feature_extractor=feature_extractor,
)
```
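
If you want to sample speech without an audio prompt, one option is to start generation from a bare `<audio>` tag and reuse the objects defined above. This is a minimal sketch, not an official recipe; whether the model generates well from only the opening tag is an assumption.

```python
# Hypothetical unconditional sampling: reuses model, tokenizer, stopping_criteria,
# audio_tokenizer, feature_extractor, and the sampling settings defined above.
# Starting from a bare "<audio>" tag is an assumption, not a documented recipe.
inputs = tokenizer("<audio>", return_tensors="pt").to(device)
with torch.no_grad():
    generated = model.generate(
        **inputs,
        max_length=max_length,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        stopping_criteria=[stopping_criteria],
    )
text_to_audio_values(
    tokenizer.decode(generated[0]),
    num_quantizers=num_quantizers,
    output_file="unconditional.wav",
    audio_tokenizer=audio_tokenizer,
    feature_extractor=feature_extractor,
)
```

Note that `max_length` caps the number of generated tokens and therefore the duration of the output audio.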

## Citation

```bibtex
@misc{sugiura2025llamamimispeechlanguagemodels,
      title={Llama-Mimi: Speech Language Models with Interleaved Semantic and Acoustic Tokens},
      author={Issa Sugiura and Shuhei Kurita and Yusuke Oda and Ryuichiro Higashinaka},
      year={2025},
      eprint={2509.14882},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2509.14882},
}
```