---
library_name: transformers
tags:
  - gemma3n
  - text-generation
  - asr
  - pruned
  - audio-only
  - no-vision
base_model: google/gemma-3n-E2B-it
license: gemma
model-index:
  - name: gemma3n-text-audio
    results: []
---

# gemma3n-TextAudio

**EXPERIMENTAL.** An audio-only variant of `google/gemma-3n-E2B-it` with the vision tower pruned: text generation and speech recognition (ASR) only.

## Package

```bash
pip install --upgrade transformers hf-xet timm librosa
```
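
Gemma 3n support only exists in recent `transformers` releases (4.53+ at the time of writing; treat the exact minimum as an assumption and check the release notes if the import fails). A quick sanity check:

```python
import transformers

print(transformers.__version__)

# Raises ImportError on releases that predate Gemma 3n support:
from transformers import Gemma3nForConditionalGeneration  # noqa: F401
```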

## Inference Code

The snippet below loads the pruned checkpoint, rebuilds a `Gemma3nProcessor` around its tokenizer and audio feature extractor (borrowing the image processor and chat template from the base `google/gemma-3n-E2B-it`), detaches the vision tower, and transcribes a single file:

```python
import torch, librosa, os, numpy as np
from transformers import (
    AutoTokenizer, AutoProcessor, AutoImageProcessor,
    Gemma3nAudioFeatureExtractor, Gemma3nProcessor,
    Gemma3nForConditionalGeneration
)

def to_device_preserve_dtypes(batch, device, model_dtype):
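    # Keep token ids and masks integral/bool; cast only the audio features to the model's dtype.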
    fixed = {}
    for k, v in batch.items():
        if not isinstance(v, torch.Tensor):
            fixed[k] = v
            continue
        if k == "input_features":
            fixed[k] = v.to(device=device, dtype=model_dtype)
        elif k in ("input_ids", "position_ids", "token_type_ids"):
            fixed[k] = v.to(device=device)  # keep int dtype
        elif "mask" in k:
            fixed[k] = v.to(device=device).to(torch.bool)
        else:
            fixed[k] = v.to(device=device)
    return fixed

def _load_mono(audio_path: str, target_sr: int):
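    # Resample to the feature extractor's expected rate (16 kHz for Gemma 3n) as mono float32.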
    wav, sr = librosa.load(audio_path, sr=target_sr, mono=True)
    return wav.astype(np.float32), target_sr

def transcribe(
    audio_path: str,
    model_id: str = "<username>/<repo-name>",
    prompt: str = "Transcribe this audio exactly.",
    max_new_tokens: int = 128,
    token: str = "YOUR_TOKEN_HF"
):
    tok = AutoTokenizer.from_pretrained(model_id, token=token)
    fe  = Gemma3nAudioFeatureExtractor.from_pretrained(model_id, token=token)

    chat_tmpl = getattr(tok, "chat_template", None)
    if not chat_tmpl:
        base = AutoProcessor.from_pretrained("google/gemma-3n-E2B-it", token=token)
        if getattr(base, "chat_template", None):
            chat_tmpl = base.chat_template

    img_proc = AutoImageProcessor.from_pretrained("google/gemma-3n-E2B-it", token=token)

    processor = Gemma3nProcessor(
        tokenizer=tok,
        feature_extractor=fe,
        image_processor=img_proc,
        chat_template=chat_tmpl,
    )

    model = Gemma3nForConditionalGeneration.from_pretrained(
        model_id, dtype="auto", device_map="auto", token=token,
    ).eval()

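    # Defensively detach any vision components still referenced by the checkpoint;
    # this build is audio + text only, so the tower is never used.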
    if hasattr(model, "model") and hasattr(model.model, "vision_tower"):
        try:
            model.model.vision_tower = None
        except Exception:
            pass
    if hasattr(model, "config"):
        for k in ("vision_config", "vision_tower", "image_size"):
            if hasattr(model.config, k):
                setattr(model.config, k, None)

    target_sr = getattr(fe, "sampling_rate", 16000)
    wav, _ = _load_mono(audio_path, target_sr)

    messages = [
        {"role": "system", "content": [{"type": "text", "text": prompt}]},
        {"role": "user",   "content": [{"type": "audio", "audio": wav}]}
    ]
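    # apply_chat_template tokenizes the prompt and runs the audio feature
    # extractor on the waveform, returning input_ids plus audio input_features.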
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True,
        tokenize=True, return_dict=True, return_tensors="pt",
    )

    inputs = to_device_preserve_dtypes(
        inputs, device=model.device, model_dtype=getattr(model, "dtype", torch.float16)
    )
    input_len = inputs["input_ids"].shape[-1]

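    # generate() echoes the prompt tokens; slice them off to keep only the new text.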
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new_tokens)
        out = out[0][input_len:]

    text = processor.decode(out, skip_special_tokens=True)

    return text

# Example usage:
print(transcribe("your_audio.wav", model_id="<username>/<repo-name>", token="YOUR_TOKEN_HF"))
```
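
The `prompt` argument steers the task, so the same entry point can be pointed at variations of plain transcription; whether the pruned model follows a given instruction well is untested here. A sketch with placeholder path and repo id:

```python
# Ask for a transcript plus an English translation instead of a verbatim transcript.
print(transcribe(
    "your_audio.wav",                   # placeholder path
    model_id="<username>/<repo-name>",  # placeholder repo id
    prompt="Transcribe this audio, then translate it to English.",
    token="YOUR_TOKEN_HF",
))
```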