AuriStream-1B
AuriStream is a biologically inspired, GPT-style autoregressive Transformer trained to predict cochlear tokens over long speech contexts. The tokens are discrete codes produced by the companion “WavCoch” tokenizer (trained via transformation imitation). AuriStream uses a long context window (~20 s, ~4096 tokens) and is trained on LibriLight (~60k h) for ~500k steps. It learns rich, time-aligned representations (useful for linear probing) and can roll out future tokens to generate speech continuations. Inputs are token IDs; pair the model with a WavCoch quantizer for audio->tokens and with the built-in vocoder for tokens->audio.
Installation
pip install -U torch torchaudio transformers
This model uses custom code; when loading from Hugging Face, pass trust_remote_code=True.
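A quick way to confirm the install works is to load both repos up front. This is only a sanity check; it uses nothing beyond the repo IDs and the trust_remote_code flag shown in the examples below.

# Sanity check: both models should instantiate without errors (weights download on first run)
from transformers import AutoModel
quantizer = AutoModel.from_pretrained("TuKoResearch/WavCochV8192", trust_remote_code=True)
lm = AutoModel.from_pretrained("TuKoResearch/AuriStream1B_40Pred_librilight_500k", trust_remote_code=True)
print(type(quantizer).__name__, type(lm).__name__)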
Use Case 1) Get hidden-state embeddings from a WAV
import torch, torchaudio
from transformers import AutoModel
device = "cuda" if torch.cuda.is_available() else "cpu"
# 1) Load the WavCoch tokenizer (audio -> token IDs)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# 2) Load the AuriStream LM (tokens -> hidden states / next-token preds)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_40Pred_librilight_500k", trust_remote_code=True
).to(device).eval()

# 3) Read an audio file (mono, 16 kHz recommended)
wav, sr = torchaudio.load("sample.wav")
if wav.size(0) > 1:  # stereo -> mono
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000

# 4) Quantize to cochlear token IDs
with torch.no_grad():
    # quantizer.quantize expects (B, T); returns LongTensor (B, L)
    token_ids = quantizer.quantize(wav.unsqueeze(0).to(device))  # (1, L)

# 5) Forward pass with hidden states
with torch.no_grad():
    out = lm(token_ids, output_hidden_states=True)
    last_layer = out["hidden_states"][-1]    # (1, T, D)
    clip_embedding = last_layer.mean(dim=1)  # time mean-pool -> (1, D)
print("Pooled embedding shape:", clip_embedding.shape)
Notes
- output_hidden_states=True returns all layers; choose a layer or pool over time.
- For word/phone segments, slice the time axis before pooling (see the sketch below).
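Continuing with the variables from Use Case 1 (wav, sr, last_layer), a segment embedding can be obtained by converting start/end times into token indices and pooling only that slice. The sketch below assumes a uniform token rate estimated from the clip duration; the exact time-to-token mapping depends on the WavCoch tokenizer, so treat this as an approximation.

# Segment pooling sketch: pool hidden states over a word/phone span instead of the whole clip
clip_seconds = wav.size(-1) / sr                     # duration of the input clip in seconds
tokens_per_sec = last_layer.size(1) / clip_seconds   # approximate token rate (assumed uniform)

seg_start, seg_end = 0.45, 0.80                      # hypothetical segment boundaries in seconds
i0 = int(seg_start * tokens_per_sec)
i1 = max(i0 + 1, int(seg_end * tokens_per_sec))

segment_embedding = last_layer[:, i0:i1].mean(dim=1)  # (1, D)
print("Segment embedding shape:", segment_embedding.shape)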
Use Case 2) Generate a speech continuation (token rollout)
import torch, torchaudio
from transformers import AutoModel
device = "cuda" if torch.cuda.is_available() else "cpu"
# WavCoch tokenizer (audio->tokens, tokens->cochleagram->audio)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# AuriStream LM (tokens->next tokens)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_40Pred_librilight_500k", trust_remote_code=True
).to(device).eval()

# Load & prep a short prompt (e.g., 3 s of audio at 16 kHz)
wav, sr = torchaudio.load("prompt.wav")
if wav.size(0) > 1:  # stereo -> mono
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000
prompt_seconds = 3
wav = wav[:, : sr * prompt_seconds]

# Quantize prompt to token IDs
with torch.no_grad():
    prompt_tokens = quantizer.quantize(wav.unsqueeze(0).to(device))  # (1, L)

# Decide how many future tokens to generate
tokens_per_sec = prompt_tokens.size(1) / float(prompt_seconds)
rollout_seconds = 3
rollout_steps = int(round(tokens_per_sec * rollout_seconds))

# Roll out future tokens
with torch.no_grad():
    # returns (pred_tokens, pred_logits); temperature/top_k/top_p/seed optional
    pred_tokens, _ = lm.generate(
        prompt_tokens, rollout_steps, temp=0.7, top_k=50, top_p=0.95, seed=0
    )
full_tokens = torch.cat([prompt_tokens, pred_tokens], dim=1) # (1, L+K)
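To actually listen to the continuation, full_tokens has to go back through the WavCoch side (tokens -> cochleagram -> audio) using the built-in vocoder mentioned above. The decoding call is not documented here, so the sketch below is hypothetical: decode is a placeholder name for whatever tokens-to-audio method the WavCochV8192 implementation exposes, and the output shape handling may differ.

# Hypothetical tokens->audio step; check the WavCochV8192 repo for the actual method name/signature
with torch.no_grad():
    audio = quantizer.decode(full_tokens)  # placeholder call: tokens -> cochleagram -> waveform
torchaudio.save("continuation.wav", audio.squeeze(0).cpu(), 16_000)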
Citation
If you use this model, please cite:
@misc{tuckute2025cochleartokens,
  title         = {Representing Speech Through Autoregressive Prediction of Cochlear Tokens},
  author        = {Tuckute, Greta and Kotar, Klemen and Fedorenko, Evelina and Yamins, Daniel L. K.},
  year          = {2025},
  eprint        = {2508.11598},
  archivePrefix = {arXiv},
  url           = {https://arxiv.org/abs/2508.11598}
}