
EXPERIMENTAL

Package

pip install --upgrade transformers hf-xet timm librosa
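
The Gemma 3n classes used below ship only in recent transformers releases, so it is worth verifying the install before running inference. A minimal sanity check (the 4.53.0 threshold is an assumption based on when Gemma 3n support landed):

import transformers
print(transformers.__version__)  # assumed to need >= 4.53.0 for Gemma 3n

# These imports fail on releases that predate Gemma 3n support
from transformers import Gemma3nProcessor, Gemma3nForConditionalGeneration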

Inference Code

import torch, librosa, numpy as np
from transformers import (
    AutoTokenizer, AutoProcessor, AutoImageProcessor,
    Gemma3nAudioFeatureExtractor, Gemma3nProcessor,
    Gemma3nForConditionalGeneration
)

def to_device_preserve_dtypes(batch, device, model_dtype):
    # Move processor outputs to the model device: cast audio features to the
    # model dtype, keep token ids as integers, and force mask tensors to bool.
    fixed = {}
    for k, v in batch.items():
        if not isinstance(v, torch.Tensor):
            fixed[k] = v
            continue
        if k == "input_features":
            fixed[k] = v.to(device=device, dtype=model_dtype)
        elif k in ("input_ids", "position_ids", "token_type_ids"):
            fixed[k] = v.to(device=device)  # keep int dtype
        elif "mask" in k:
            fixed[k] = v.to(device=device).to(torch.bool)
        else:
            fixed[k] = v.to(device=device)
    return fixed

def _load_mono(audio_path: str, target_sr: int):
    # librosa resamples to target_sr and downmixes to mono on load
    wav, sr = librosa.load(audio_path, sr=target_sr, mono=True)
    return wav.astype(np.float32), target_sr

def transcribe(
    audio_path: str,
    model_id: str = "<username>/<repo-name>",
    prompt: str = "Transcribe this audio exactly.",
    max_new_tokens: int = 128,
    token: str = "YOUR_TOKEN_HF"
):
    tok = AutoTokenizer.from_pretrained(model_id, token=token)
    fe  = Gemma3nAudioFeatureExtractor.from_pretrained(model_id, token=token)

    # Fall back to the base Gemma 3n chat template if this repo does not ship one
    chat_tmpl = getattr(tok, "chat_template", None)
    if not chat_tmpl:
        base = AutoProcessor.from_pretrained("google/gemma-3n-E2B-it", token=token)
        if getattr(base, "chat_template", None):
            chat_tmpl = base.chat_template

    # Gemma3nProcessor expects an image processor even for audio-only inference
    img_proc = AutoImageProcessor.from_pretrained("google/gemma-3n-E2B-it", token=token)

    processor = Gemma3nProcessor(
        tokenizer=tok,
        feature_extractor=fe,
        image_processor=img_proc,
        chat_template=chat_tmpl,
    )
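    # The processor now mirrors what AutoProcessor builds for the base model,
    # but with this repo's tokenizer and audio feature extractor.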

    model = Gemma3nForConditionalGeneration.from_pretrained(
        model_id, dtype="auto", device_map="auto", token=token,
    ).eval()

    # Audio-only use: drop the vision tower and related config to save memory (best effort)
    if hasattr(model, "model") and hasattr(model.model, "vision_tower"):
        try:
            model.model.vision_tower = None
        except Exception:
            pass
    if hasattr(model, "config"):
        for k in ("vision_config", "vision_tower", "image_size"):
            if hasattr(model.config, k):
                setattr(model.config, k, None)

    target_sr = getattr(fe, "sampling_rate", 16000)
    wav, _ = _load_mono(audio_path, target_sr)

    messages = [
        {"role": "system", "content": [{"type": "text", "text": prompt}]},
        {"role": "user",   "content": [{"type": "audio", "audio": wav}]}
    ]
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True,
        tokenize=True, return_dict=True, return_tensors="pt",
    )

    inputs = to_device_preserve_dtypes(
        inputs, device=model.device, model_dtype=getattr(model, "dtype", torch.float16)
    )
    input_len = inputs["input_ids"].shape[-1]  # prompt length, stripped from the output below

    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new_tokens)
        out = out[0][input_len:]

    text = processor.decode(out, skip_special_tokens=True)

    return text

# Example usage:
print(transcribe("your_audio.wav", model_id="<username>/<repo-name>", token="YOUR_TOKEN_HF"))
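
Each transcribe() call reloads the tokenizer, processor, and 6B-parameter model, which dominates runtime when transcribing many files. A minimal load-once sketch (transcribe_many is a hypothetical helper; it assumes the processor and model were built once using the same steps as inside transcribe):

def transcribe_many(paths, processor, model,
                    prompt="Transcribe this audio exactly.", max_new_tokens=128):
    # Reuse the already-loaded processor and model across files
    sr = getattr(processor.feature_extractor, "sampling_rate", 16000)
    results = {}
    for path in paths:
        wav, _ = _load_mono(path, sr)
        messages = [
            {"role": "system", "content": [{"type": "text", "text": prompt}]},
            {"role": "user",   "content": [{"type": "audio", "audio": wav}]},
        ]
        inputs = processor.apply_chat_template(
            messages, add_generation_prompt=True,
            tokenize=True, return_dict=True, return_tensors="pt",
        )
        inputs = to_device_preserve_dtypes(
            inputs, model.device, getattr(model, "dtype", torch.float16)
        )
        input_len = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new_tokens)
        results[path] = processor.decode(out[0][input_len:], skip_special_tokens=True)
    return results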