---
datasets:
- LibriLight
language:
- en
library_name: transformers
license: apache-2.0
pipeline_tag: audio-to-audio
tags:
- audio
- speech
- autoregressive
- transformers
- custom_code
pretty_name: AuriStream1B
---

# AuriStream-1B

[📚 Paper](https://huggingface.co/papers/2508.11598) - [🌐 Project Page](https://tukoresearch.github.io/auristream-speech/)

**AuriStream** is a biologically inspired, GPT-style autoregressive Transformer trained to predict tokens derived from the speech stream, termed **cochlear tokens**. These cochlear tokens are discrete codes produced by a companion “WavCoch” tokenizer, a model trained to predict the time-frequency cochleagram from a waveform, with an LFQ bottleneck for token read-out. AuriStream uses a long context window (~20 s, ~4096 tokens) and was trained on **LibriLight (~60k hours)** for **500k steps**. It learns meaningful representations of, e.g., phoneme and word identity, and it can predict future tokens to generate **speech continuations**. Inputs are cochlear **token IDs**; pair it with a WavCoch tokenizer to map audio to tokens.

---

## Installation

```bash
pip install -U torch torchaudio transformers
```

This model uses custom code; when loading it from the Hugging Face Hub, pass `trust_remote_code=True`.
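
For example, a minimal load sketch (using the AuriStream repo ID from this card):

```python
from transformers import AutoModel

# The repo ships custom modeling code, so trust_remote_code=True is required
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_librilight_ckpt500k", trust_remote_code=True
)
```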

---

## Use case 1) Get hidden-state embeddings for an audio waveform

```python
import torch, torchaudio
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Load the WavCoch tokenizer (audio -> token IDs)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# 2) Load the AuriStream LM (tokens -> hidden states / next-token prediction)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_librilight_ckpt500k", trust_remote_code=True
).to(device).eval()

# 3) Read an audio file (mono, 16 kHz recommended)
wav, sr = torchaudio.load("sample.wav")
if wav.size(0) > 1:  # stereo -> mono
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000

# 4) Quantize the audio to obtain cochlear token IDs
with torch.no_grad():
    # The quantizer forward method expects (B, 1, T); returns (B, L)
    token_ids = quantizer(wav.unsqueeze(0).to(device))['input_ids']  # (1, L)

# 5) Forward pass to obtain hidden states
with torch.no_grad():
    out = lm(token_ids, output_hidden_states=True)
    last_layer = out["hidden_states"][-1]     # (1, T, D)
    last_layer_mean = last_layer.mean(dim=1)  # time mean-pool -> (1, D)

print("Mean-pooled embedding shape:", last_layer_mean.shape)
```

**Notes**

* `output_hidden_states=True` returns the hidden states of all layers.
* For phoneme- or word-level embeddings, slice the time axis to the segment of interest before pooling; see the sketch below.
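
A minimal sketch of segment-level pooling, continuing from the example above (the segment boundaries in seconds are illustrative; the token rate is estimated from the tokenized clip itself):

```python
# Reuses `token_ids`, `last_layer`, `wav`, and `sr` from the snippet above
audio_seconds = wav.size(-1) / sr
tokens_per_sec = token_ids.size(1) / audio_seconds  # empirical token rate

seg_start, seg_end = 0.40, 0.75  # hypothetical segment boundaries, in seconds
i0 = int(seg_start * tokens_per_sec)
i1 = max(i0 + 1, int(seg_end * tokens_per_sec))

seg_embedding = last_layer[:, i0:i1].mean(dim=1)  # (1, D)
print("Segment embedding shape:", seg_embedding.shape)
```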

---

## Use case 2) Generate a speech continuation (cochlear token prediction)

```python
import torch, torchaudio
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# WavCoch tokenizer (audio -> tokens)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# AuriStream LM (tokens -> next tokens)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_librilight_ckpt500k", trust_remote_code=True
).to(device).eval()

# Load and prep a short prompt (e.g., 3 s of audio at 16 kHz)
prompt_seconds = 3
wav, sr = torchaudio.load("prompt.wav")
if wav.size(0) > 1:
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000
# Slice using an integer number of samples
n_samples = int(round(sr * prompt_seconds))
wav = wav[:, :n_samples]

# Quantize the prompt audio to get token IDs
with torch.no_grad():
    prompt_tokens = quantizer(wav.unsqueeze(0).to(device))['input_ids']  # (1, L)

# Decide how many future tokens to generate ("roll-out")
tokens_per_sec = prompt_tokens.size(1) / float(prompt_seconds)
rollout_seconds = 2
rollout_steps = int(round(tokens_per_sec * rollout_seconds))  # K

# Generate future tokens
with torch.no_grad():
    # returns (pred_tokens, pred_logits); temperature/top_k/top_p/seed optional
    pred_tokens, _ = lm.generate(
        prompt_tokens, rollout_steps, temp=0.7, top_k=50, top_p=0.95, seed=0
    )
    full_tokens = torch.cat([prompt_tokens, pred_tokens], dim=1)  # (1, L+K)
```
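
As a quick sanity check, the empirical token rate from the snippet above converts the roll-out length back into an approximate duration:

```python
# Approximate duration of the generated continuation (reuses tokens_per_sec)
gen_seconds = pred_tokens.size(1) / tokens_per_sec
print(f"Generated {pred_tokens.size(1)} tokens (~{gen_seconds:.2f} s of speech)")
```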

## Architecture overview

<p align="center">
  <img src="Fig1_arch_v3.jpg" alt="Schematic of the WavCoch tokenizer (panel A) and the AuriStream model (panel B)" width="500"/>
</p>

Schematic of the WavCoch tokenizer (panel A) and the AuriStream model (panel B).

## Citation

If you use this model, please cite:

```bibtex
@inproceedings{tuckute2025cochleartokens,
  title     = {Representing Speech Through Autoregressive Prediction of Cochlear Tokens},
  author    = {Greta Tuckute and Klemen Kotar and Evelina Fedorenko and Daniel Yamins},
  booktitle = {Interspeech 2025},
  year      = {2025},
  pages     = {2180--2184},
  doi       = {10.21437/Interspeech.2025-2044},
  issn      = {2958-1796}
}
```