In [9]:
%%capture
# Environment setup cell. The %%capture magic above swallows this cell's output,
# so a failed install will not be visible here — remove it when debugging.
# NOTE(review): package versions are unpinned; pin them for reproducible runs.
#
# Python packages used by the transcription code further down in this notebook
# (torch, torchaudio, transformers, pydub, numpy, pyctcdecode).
!pip install torch torchaudio transformers pydub numpy pyctcdecode
# System ffmpeg: only needed when the input is mp3 (that path is decoded via pydub).
!sudo apt-get update -qq
!sudo apt-get install -y ffmpeg
# KenLM Python bindings, installed from the upstream master archive;
# needed for KenLM ARPA/binary language-model support during decoding.
!pip install https://github.com/kpu/kenlm/archive/master.zip

In [7]:
# Model directory; the processor is loaded from the same location.
MODEL_PATH = "/content/drive/MyDrive/artifacts/models/hf/hf_tgt/tigre-asr-Wav2Vec2Bert"
PROCESSOR_PATH = MODEL_PATH

# Input audio, and an optional transcript destination
# (e.g. "/path/to/out.txt"; leave as None to just print the result).
AUDIO_FILE = f"{MODEL_PATH}/sample.wav"
OUTPUT_TXT = None

# KenLM + lexicon (optional but recommended for beam search).
# Set KENLM_ARPA to None to decode WITHOUT a language model;
# set LEXICON_TXT to None when no lexicon (unigram source) is available.
KENLM_ARPA = f"{MODEL_PATH}/lm.arpa"
LEXICON_TXT = f"{MODEL_PATH}/lexicon.txt"

In [2]:
import warnings
import logging

# Drop every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Raise the threshold of the chatty third-party loggers (pyctcdecode, torchaudio)
# so that only ERROR-and-above records get through.
for _noisy_logger in ("pyctcdecode", "torchaudio"):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
del _noisy_logger  # do not leave the loop variable behind in the notebook namespace

In [3]:
# --- Audio / chunking ---
# Sample rate (Hz) the audio is loaded at and declared to the processor.
TARGET_SR = 16_000
# Length of each chunk, in seconds.
CHUNK_SEC = 5
# Overlap between consecutive chunks, in seconds (0 keeps the code minimal).
OVERLAP_SEC = 0

# --- Beam search ---
BEAM_WIDTH = 150
# Passed to the decoder builder as alpha / beta when a KenLM model is used.
LM_ALPHA = 0.5
LM_BETA = 1.0

In [8]:
import os
import torch
import numpy as np
import torchaudio
from typing import List, Optional

# Use pydub for robust mp3 handling
from pydub import AudioSegment

from transformers import Wav2Vec2BertForCTC, Wav2Vec2BertProcessor

# pyctcdecode is optional: remember whether it imported so the decoder builder
# can fall back to "no decoder" instead of failing.
try:
    from pyctcdecode import build_ctcdecoder
except Exception:
    _HAS_PYCTC = False
else:
    _HAS_PYCTC = True

# Run on the GPU when one is visible to torch, otherwise on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
del _device_name

def _load_audio(path: str, target_sr: int = 16000) -> torch.Tensor:
    """Load WAV or MP3 to mono float32 tensor [1, T] at target_sr."""
    ext = os.path.splitext(path)[1].lower()
    if ext == ".mp3":
        audio = AudioSegment.from_file(path, format="mp3")
        audio = audio.set_channels(1).set_frame_rate(target_sr)
        samples = np.array(audio.get_array_of_samples()).astype(np.float32)
        # pydub gives int PCM range; normalize if needed (assume 16-bit)
        if samples.dtype != np.float32:
            samples = samples.astype(np.float32)
        # If sample_width==2 (16-bit), divide by 32768
        if audio.sample_width == 2:
            samples /= 32768.0
        return torch.from_numpy(samples).unsqueeze(0)
    else:
        wav, sr = torchaudio.load(path)
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)  # stereo -> mono
        if sr != target_sr:
            wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
        # ensure float32 in [-1,1]
        if wav.dtype != torch.float32:
            wav = wav.to(torch.float32)
        return wav

def _chunks(wave: torch.Tensor, sr: int, chunk_sec: int, overlap_sec: int):
    """Yield possibly-overlapping chunks [1, T_chunk]."""
    chunk = int(chunk_sec * sr)
    step  = max(1, chunk - int(overlap_sec * sr))
    T     = wave.size(-1)
    for start in range(0, T, step):
        end = min(start + chunk, T)
        yield wave[:, start:end]
        if end >= T:
            break

def _load_unigrams(lexicon_path: Optional[str]) -> List[str]:
    """Read first token per line from lexicon into a unigram list."""
    if not lexicon_path or not os.path.exists(lexicon_path):
        return []
    words = set()
    with open(lexicon_path, "r", encoding="utf-8") as f:
        for line in f:
            w = line.strip().split()
            if w:
                words.add(w[0])
    return sorted(words)

def _build_decoder(model, processor):
    """Build a pyctcdecode decoder from model vocab + KenLM (if configured)."""
    # Build vocab (id -> token)
    vocab_size = model.lm_head.out_features
    labels = []
    for i in range(vocab_size):
        tok = processor.tokenizer.convert_ids_to_tokens([i])[0]
        # remove common BPE markers
        tok = tok.lstrip("Ġ").lstrip("▁")
        labels.append(tok)

    # No LM? Use labels only; with LM? also pass unigrams + alpha/beta
    if not _HAS_PYCTC:
        return None

    if KENLM_ARPA and os.path.exists(KENLM_ARPA):
        unigrams = _load_unigrams(LEXICON_TXT)
        return build_ctcdecoder(
            labels=labels,
            kenlm_model_path=KENLM_ARPA,
            unigrams=unigrams if unigrams else None,
            alpha=LM_ALPHA,
            beta=LM_BETA
        )
    else:
        # Fallback to lexicon-less decoder (greedy-ish beam without LM)
        return build_ctcdecoder(labels=labels)

def _postprocess(text: str) -> str:
    """Light cleanup: strip special markers, collapse dup words, ensure end punctuation."""
    text = text.replace("<|", "").replace("|>", "").replace("<>", "").strip()
    words, cleaned = text.split(), []
    for w in words:
        if not cleaned or cleaned[-1] != w:
            cleaned.append(w)
    out = " ".join(cleaned).strip()
    if out and out[-1] not in ".!?":
        out += "."
    return out

def transcribe_one_file() -> str:
    """Transcribe ``AUDIO_FILE`` with the model at ``MODEL_PATH`` and return the text.

    The model and processor are loaded, the audio is read at ``TARGET_SR`` and
    cut into ``CHUNK_SEC``-second windows (``OVERLAP_SEC`` overlap), and every
    window is run through the CTC model. Each window is decoded with the
    pyctcdecode beam-search decoder when one could be built, otherwise with
    greedy CTC decoding through the tokenizer. Non-empty window hypotheses are
    joined with single spaces and cleaned by ``_postprocess``.

    Returns:
        The post-processed transcript ("" when no window produced any text).
    """
    # Load model + processor
    model = Wav2Vec2BertForCTC.from_pretrained(MODEL_PATH).to(device).eval()
    processor = Wav2Vec2BertProcessor.from_pretrained(PROCESSOR_PATH)

    # Optional beam-search decoder (None when pyctcdecode is unavailable)
    decoder = _build_decoder(model, processor)

    wav = _load_audio(AUDIO_FILE, TARGET_SR)

    pieces = []
    for chunk in _chunks(wav, TARGET_SR, CHUNK_SEC, OVERLAP_SEC):
        # squeeze(0) removes only the channel axis, so even a one-sample chunk
        # stays a 1-D array instead of collapsing to a scalar.
        inputs = processor(
            chunk.squeeze(0).numpy(),
            sampling_rate=TARGET_SR,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            logits = model(input_features=inputs.input_features).logits  # [1, T, V]
        # NOTE(review): these are unnormalised logits, not log-probabilities —
        # confirm the decoder normalises them internally.
        frame_scores = logits[0].cpu().numpy()

        if decoder is not None:
            hypo = decoder.decode(frame_scores, beam_width=BEAM_WIDTH)
        else:
            # Greedy CTC fallback. The per-frame argmax ids still contain
            # repeated tokens and blank/pad frames; the tokenizer's decode()
            # is relied on to collapse repeats and drop blanks, where joining
            # the raw frame tokens would emit all of them verbatim.
            ids = frame_scores.argmax(axis=-1)
            hypo = processor.tokenizer.decode(ids.tolist())

        hypo = hypo.strip()
        if hypo:
            pieces.append(hypo)

        # Release per-chunk tensors before the next iteration
        del inputs, logits, frame_scores

    return _postprocess(" ".join(pieces))

if __name__ == "__main__":
    # Run the transcription, optionally save it to OUTPUT_TXT, and always print it.
    out = transcribe_one_file()
    if OUTPUT_TXT:
        # A bare filename such as "out.txt" has an empty dirname, and
        # os.makedirs("") raises, so only create the directory when there is one.
        out_dir = os.path.dirname(OUTPUT_TXT)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(OUTPUT_TXT, "w", encoding="utf-8") as f:
            f.write(out + "\n")
    print(out)


ሕርጊጎ ምነ ምን ዘበን አትራክ እንዴ አንበተት እብ መረባቤዐ ግሩም ለትሐሌ መዲነት ተ.
