EXPERIMENTAL
Package
pip install --upgrade transformers hf-xet timm librosa
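Gemma 3n support only ships in recent transformers releases, so it can be worth confirming the upgrade actually took before running inference. A minimal, optional sanity check (this snippet is illustrative and not part of the original recipe):

# Optional: verify the installed transformers build exposes the Gemma 3n classes.
import transformers

print("transformers version:", transformers.__version__)
for name in ("Gemma3nAudioFeatureExtractor", "Gemma3nProcessor",
             "Gemma3nForConditionalGeneration"):
    assert hasattr(transformers, name), f"{name} missing; upgrade transformers"
print("Gemma 3n classes available.")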
Inference Code
import torch, librosa, numpy as np
from transformers import (
    AutoTokenizer, AutoProcessor, AutoImageProcessor,
    Gemma3nAudioFeatureExtractor, Gemma3nProcessor,
    Gemma3nForConditionalGeneration
)
def to_device_preserve_dtypes(batch, device, model_dtype):
    """Move a processor output dict to `device`, casting only the audio
    features to the model's compute dtype and leaving integer tensors intact."""
    fixed = {}
    for k, v in batch.items():
        if not isinstance(v, torch.Tensor):
            fixed[k] = v
            continue
        if k == "input_features":
            # Mel features must match the model's compute dtype (e.g. bf16).
            fixed[k] = v.to(device=device, dtype=model_dtype)
        elif k in ("input_ids", "position_ids", "token_type_ids"):
            # Token indices stay integer.
            fixed[k] = v.to(device=device)
        elif "mask" in k:
            # Attention/padding masks are safest as bool.
            fixed[k] = v.to(device=device).to(torch.bool)
        else:
            fixed[k] = v.to(device=device)
    return fixed
def _load_mono(audio_path: str, target_sr: int):
    # librosa resamples to target_sr and downmixes to mono on load.
    wav, _ = librosa.load(audio_path, sr=target_sr, mono=True)
    return wav.astype(np.float32), target_sr
def transcribe(
    audio_path: str,
    model_id: str = "<username>/<repo-name>",
    prompt: str = "Transcribe this audio exactly.",
    max_new_tokens: int = 128,
    token: str = "YOUR_HF_TOKEN",
):
    # Load the tokenizer and audio feature extractor from the fine-tuned repo.
    tok = AutoTokenizer.from_pretrained(model_id, token=token)
    fe = Gemma3nAudioFeatureExtractor.from_pretrained(model_id, token=token)

    # Fall back to the base model's chat template if the repo does not ship one.
    chat_tmpl = getattr(tok, "chat_template", None)
    if not chat_tmpl:
        base = AutoProcessor.from_pretrained("google/gemma-3n-E2B-it", token=token)
        if getattr(base, "chat_template", None):
            chat_tmpl = base.chat_template

    # Gemma3nProcessor expects an image processor even for audio-only use,
    # so borrow the base model's.
    img_proc = AutoImageProcessor.from_pretrained("google/gemma-3n-E2B-it",
                                                  token=token)
    processor = Gemma3nProcessor(
        tokenizer=tok,
        feature_extractor=fe,
        image_processor=img_proc,
        chat_template=chat_tmpl,
    )
    model = Gemma3nForConditionalGeneration.from_pretrained(
        model_id, dtype="auto", device_map="auto", token=token,
    ).eval()

    # Audio-only inference: drop the vision tower to save memory. Best-effort,
    # skipped silently if the attribute layout differs across versions.
    if hasattr(model, "model") and hasattr(model.model, "vision_tower"):
        try:
            model.model.vision_tower = None
        except Exception:
            pass
    if hasattr(model, "config"):
        for k in ("vision_config", "vision_tower", "image_size"):
            if hasattr(model.config, k):
                setattr(model.config, k, None)
    # Resample the audio to the feature extractor's expected rate (16 kHz default).
    target_sr = getattr(fe, "sampling_rate", 16000)
    wav, _ = _load_mono(audio_path, target_sr)

    messages = [
        {"role": "system", "content": [{"type": "text", "text": prompt}]},
        {"role": "user", "content": [{"type": "audio", "audio": wav}]},
    ]
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True,
        tokenize=True, return_dict=True, return_tensors="pt",
    )
    inputs = to_device_preserve_dtypes(
        inputs, device=model.device, model_dtype=getattr(model, "dtype", torch.float16)
    )
    # Generate, then decode only the newly produced tokens (skip the prompt).
    input_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new_tokens)
    out = out[0][input_len:]
    return processor.decode(out, skip_special_tokens=True)
print(transcribe("your_audio.wav", model_id="<username>/<repo-name>", token="YOUR_HF_TOKEN"))
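For several files, a minimal batch sketch (the audio/ directory name is illustrative, not part of the original recipe):

# Hypothetical batch usage: transcribe every .wav in a folder and print results.
from pathlib import Path

for path in sorted(Path("audio").glob("*.wav")):
    print(path.name, "->", transcribe(str(path),
                                      model_id="<username>/<repo-name>",
                                      token="YOUR_HF_TOKEN"))

Note that transcribe() reloads the model and processor on every call; for many files you would hoist that loading out of the loop and reuse the objects.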