AuriStream-1B

AuriStream is a biologically inspired, GPT-style autoregressive Transformer trained to predict cochlear tokens over long speech contexts. Cochlear tokens are discrete codes produced by a companion “WavCoch” tokenizer, which is trained through transformation imitation. AuriStream uses a long context window (~4096 tokens, roughly 20 s of audio) and was trained on LibriLight (~60k hours) for ~500k steps. It learns rich, time-aligned representations that are useful for linear probing, and it can roll out future tokens to generate speech continuations. Inputs are token IDs: use the model with a WavCoch quantizer for audio->tokens and with the built-in vocoder for tokens->audio.
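Taking the card's approximate figures at face value, 4096 tokens over ~20 s implies roughly 205 tokens per second (about 4.9 ms per token). The sketch below turns that into a quick budget check for prompt plus continuation length. It assumes a hard 4096-token context that the continuation should stay inside; the constants and helper names are illustrative, not part of the model's code.

```python
# Approximate figures quoted on this card; treat them as rough, not exact.
CONTEXT_TOKENS = 4096
CONTEXT_SECONDS = 20.0


def implied_token_rate(context_tokens=CONTEXT_TOKENS, context_seconds=CONTEXT_SECONDS):
    """Tokens per second implied by the card's rough context figures (hypothetical helper)."""
    return context_tokens / context_seconds


def capped_rollout_steps(prompt_len, desired_steps, context_tokens=CONTEXT_TOKENS):
    """Clamp a requested rollout so prompt + continuation fits the context.

    Hypothetical helper; assumes the model should not run past context_tokens.
    """
    budget = context_tokens - prompt_len
    return max(0, min(desired_steps, budget))
```

For example, a 3 s prompt at ~205 tokens/s is about 614 tokens, which leaves about 3480 tokens (roughly 17 s) for the continuation.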


Installation

pip install -U torch torchaudio transformers

This model uses custom code; when loading from Hugging Face, pass trust_remote_code=True.


Use Case 1) get hidden‑state embeddings from a WAV

import torch, torchaudio
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Load the WavCoch tokenizer (audio -> token IDs)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# 2) Load the AuriStream LM (tokens -> hidden states / next-token preds)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_40Pred_librilight_500k", trust_remote_code=True
).to(device).eval()

# 3) Read an audio file (mono, 16 kHz recommended)
wav, sr = torchaudio.load("sample.wav")
if wav.size(0) > 1:  # stereo -> mono
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000

# 4) Quantize to cochlear token IDs
with torch.no_grad():
    # wav is (1, T) after the mono mixdown; unsqueeze(0) adds a batch axis -> (1, 1, T).
    # quantizer.quantize returns a LongTensor of token IDs with shape (B, L).
    token_ids = quantizer.quantize(wav.unsqueeze(0).to(device))  # (1, L)

# 5) Forward pass with hidden states
with torch.no_grad():
    out = lm(token_ids, output_hidden_states=True)
    last_layer = out["hidden_states"][-1]    # final layer, (1, L, D)
    clip_embedding = last_layer.mean(dim=1)  # mean-pool over time -> (1, D)

print("Pooled embedding shape:", clip_embedding.shape)

Notes

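  • output_hidden_states=True returns per-layer hidden states in out["hidden_states"]; pick the layer you want (the example uses the last one) and pool over time to get one vector per clip.
  • For word- or phone-level embeddings, slice the time (token) axis to the segment of interest before pooling.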
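Slicing by word or phone requires mapping segment times (for example, boundaries from a forced aligner) to token indices. A minimal sketch, assuming tokens are spaced uniformly across the clip so the token rate can be estimated as the number of tokens divided by the clip duration; segment_to_token_span is a hypothetical helper, not part of the model's code:

```python
import math


def segment_to_token_span(start_s, end_s, num_tokens, duration_s):
    """Map a [start_s, end_s) time segment to a [i0, i1) token index range.

    Hypothetical helper. Assumes tokens are uniformly spaced over the clip,
    so the rate is estimated from this clip rather than a fixed frame rate.
    """
    if not (0.0 <= start_s < end_s) or start_s >= duration_s:
        raise ValueError("need 0 <= start_s < end_s and start_s < duration_s")
    rate = num_tokens / duration_s                        # tokens per second for this clip
    i0 = min(math.floor(start_s * rate), num_tokens - 1)  # first token overlapping the segment
    i1 = min(math.ceil(end_s * rate), num_tokens)         # clip segments that run past the end
    return i0, max(i1, i0 + 1)                            # always keep at least one token
```

With the Use Case 1 variables, a word spanning 0.50 to 0.82 s could then be pooled like this:

```python
i0, i1 = segment_to_token_span(0.50, 0.82, token_ids.size(1), wav.size(-1) / sr)
word_embedding = last_layer[:, i0:i1].mean(dim=1)  # (1, D)
```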

Use Case 2) generate a speech continuation (token rollout)

import torch, torchaudio
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# WavCoch tokenizer (audio->tokens, tokens->cochleagram->audio)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# AuriStream LM (tokens->next tokens)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_40Pred_librilight_500k", trust_remote_code=True
).to(device).eval()

# Load & prep a short prompt (e.g., 3s of audio at 16 kHz)
wav, sr = torchaudio.load("prompt.wav")
if wav.size(0) > 1:
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000
prompt_seconds = 3
wav = wav[:, : sr * prompt_seconds]

# Quantize prompt to token IDs
with torch.no_grad():
    prompt_tokens = quantizer.quantize(wav.unsqueeze(0).to(device))  # (1, L)

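# Decide how many future tokens to generate.
# Estimate the token rate from the actual prompt duration, in case the file is shorter than prompt_seconds.
prompt_duration = wav.size(-1) / sr
tokens_per_sec = prompt_tokens.size(1) / prompt_duration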
rollout_seconds = 3
rollout_steps = int(round(tokens_per_sec * rollout_seconds))

# Roll out future tokens
with torch.no_grad():
    # returns (pred_tokens, pred_logits); temp, top_k, top_p and seed are optional
    pred_tokens, _ = lm.generate(
        prompt_tokens, rollout_steps, temp=0.7, top_k=50, top_p=0.95, seed=0
    )
    full_tokens = torch.cat([prompt_tokens, pred_tokens], dim=1)  # (1, L+K)
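lm.generate exposes temp, top_k and top_p. The sketch below shows what those knobs conventionally do for one decoding step: scale logits by the temperature, keep the top_k most likely tokens, then keep the smallest prefix whose probability mass reaches top_p, and sample from what remains. It is the generic technique in plain Python with a hypothetical function name, not AuriStream's own generate() code, whose filtering order and edge-case handling may differ.

```python
import math
import random


def sample_next_token(logits, temp=1.0, top_k=None, top_p=None, rng=random):
    """Sample one token ID with temperature, top-k and top-p (nucleus) filtering.

    Hypothetical, generic illustration; not the model's generate() implementation.
    """
    if temp <= 0:
        raise ValueError("temp must be > 0")
    scaled = [x / temp for x in logits]
    # Rank token IDs from most to least likely.
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k is not None:
        order = order[:max(1, top_k)]
    # Softmax over the surviving candidates (subtract the max for numerical stability).
    m = scaled[order[0]]
    weights = [math.exp(scaled[i] - m) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]
    if top_p is not None:
        # Keep the smallest high-probability prefix whose cumulative mass reaches top_p.
        kept, cum = [], 0.0
        for i, p in zip(order, probs):
            kept.append((i, p))
            cum += p
            if cum >= top_p:
                break
        order = [i for i, _ in kept]
        probs = [p for _, p in kept]
    # random.choices renormalizes the weights of the kept candidates.
    return rng.choices(order, weights=probs, k=1)[0]
```

Lower temp sharpens the distribution toward the most likely token, while top_k and top_p cut off the low-probability tail; seeding a random.Random instance (as lm.generate's seed argument suggests) makes rollouts reproducible.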

Citation

If you use this model, please cite:

@misc{tuckute2025cochleartokens,
  title = {Representing Speech Through Autoregressive Prediction of Cochlear Tokens},
  author = {Tuckute, Greta and Kotar, Klemen and Fedorenko, Evelina and Yamins, Daniel L. K.},
  year = {2025},
  eprint = {2508.11598},
  archivePrefix = {arXiv},
  url = {https://arxiv.org/abs/2508.11598}
}
Model size: 1.38B params
Tensor type: F32 (Safetensors)