import numpy as np
import torch
from typing import List, Dict, Any, Union, Optional

from transformers.processing_utils import ProcessorMixin
from transformers import WhisperFeatureExtractor, AutoTokenizer, AutoProcessor

from configs import VLFMConfig, build_tokenizer


class VLFMProcessor(ProcessorMixin):
    """Pairs a Whisper feature extractor with a text tokenizer and splices a run of
    audio placeholder tokens into the tokenized prompt at the audio span."""

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "WhisperFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        feature_extractor: WhisperFeatureExtractor,
        tokenizer: AutoTokenizer,
        config: VLFMConfig,
    ):
        super().__init__(feature_extractor=feature_extractor, tokenizer=tokenizer)

        _, self.audio_token_id = build_tokenizer(
            config.text_model_id, config.tokenizer_padding_side
        )
        if self.audio_token_id == tokenizer.unk_token_id:
            raise ValueError(
                "Audio placeholder token resolves to the tokenizer's unk token. "
                "Add a real special token (e.g. <|audio|>) to the tokenizer vocab."
            )

        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = config.tokenizer_padding_side

        self.ds_rate = int(config.ds_rate)
        self.stack_factor = int(config.stack_factor)
        self.max_seconds = float(config.max_seconds)
        # Internal text marker for the audio span; it is partitioned out again in
        # apply_chat_template, so it only needs to be a string that survives the
        # chat template unchanged.
        self._marker = "<|audio|>"

    def __call__(self, text: Union[str, List[str]], **kwargs) -> Dict[str, Any]:
        if isinstance(text, str):
            text = [text]

        def _contains_audio_placeholder(s: str) -> bool:
            ids = self.tokenizer(s, add_special_tokens=False)["input_ids"]
            return self.audio_token_id in ids

        if any(_contains_audio_placeholder(t) for t in text):
            raise ValueError(
                "Audio placeholder token detected in raw text. "
                "Use `apply_chat_template` with {'type': 'audio', 'array': ...} instead."
            )
        return self.tokenizer(text, **kwargs)

    def _validate_audio(self, x: np.ndarray) -> np.ndarray:
        if not isinstance(x, np.ndarray):
            x = np.asarray(x)
        if x.ndim == 2:
            # Collapse the channel axis (whichever axis is smaller) to mono.
            x = x.mean(axis=0) if x.shape[0] < x.shape[1] else x.mean(axis=1)
        elif x.ndim != 1:
            raise ValueError(f"Expected 1-D mono waveform, got shape={x.shape}")
        if x.size == 0 or not np.isfinite(x).all():
            raise ValueError("Audio is empty or contains NaNs/Infs.")
        return x.astype(np.float32, copy=False)

    def apply_chat_template(
        self,
        conversation: List[Dict],
        *,
        tokenize: bool = False,
        add_generation_prompt: bool = False,
        padding: bool = False,
        truncation: bool = True,
        max_length: int = 4096,
        sampling_rate: Optional[int] = 16_000,
        return_tensors: str = "pt",
        **kwargs,
    ) -> Dict[str, Any]:
        """
        conversation: list of turns, where each turn is:
            {"role": "user" | "assistant" | "system",
             "content": str | List[{"type": "text"|"audio", ...}]}

        Exactly one audio span is supported per conversation.
        """
        if not isinstance(conversation, list) or not conversation:
            raise ValueError("Conversation must be a non-empty list of turns.")

        text_conv: List[Dict[str, str]] = []
        audio_array: Optional[np.ndarray] = None

        for turn in conversation:
            role = turn.get("role", None)
            content = turn.get("content", "")
            if not isinstance(role, str):
                raise ValueError("Each turn must have a string 'role'.")

            if isinstance(content, str):
                text_conv.append({"role": role, "content": content})
                continue

            buf: List[str] = []
            for item in content:
                t = item.get("type", None)
                if t == "text":
                    buf.append(item.get("text", ""))
                elif t == "audio":
                    if audio_array is not None:
                        raise ValueError(
                            "VLFMProcessor supports exactly one audio span per conversation."
                        )
                    arr = item.get("array", None)
                    if arr is None:
                        raise ValueError("Audio item missing 'array'.")
                    audio_array = self._validate_audio(arr)
                    buf.append(self._marker)
                else:
                    raise ValueError(f"Unsupported content type: {t}")
            text_conv.append({"role": role, "content": "".join(buf)})

        if audio_array is None:
            raise ValueError("No audio found in conversation (exactly one audio span required).")

        prompt = self.tokenizer.apply_chat_template(
            text_conv, tokenize=False, add_generation_prompt=add_generation_prompt
        )

        # Truncate the waveform to the configured maximum duration and pad very
        # short clips so the feature extractor always produces a few frames.
        sr = int(sampling_rate or self.feature_extractor.sampling_rate)
        max_samples = int(self.max_seconds * sr)
        if audio_array.shape[0] > max_samples:
            audio_array = audio_array[:max_samples]

        hop = int(self.feature_extractor.hop_length)
        if audio_array.shape[0] < 2 * hop:
            audio_array = np.pad(audio_array, (0, 2 * hop - audio_array.shape[0]))

        feat = self.feature_extractor(
            [audio_array],
            sampling_rate=sr,
            return_attention_mask=True,
            return_tensors="pt",
            **kwargs,
        )
        feats = feat.input_features
        attn = feat.attention_mask

        # Number of placeholder tokens the audio span occupies after the audio
        # encoder's downsampling (ds_rate) and frame stacking (stack_factor).
        mel_len = int(attn.sum(-1).item())
        scale = self.ds_rate * self.stack_factor
        audio_token_len = int(np.ceil(mel_len / float(scale)))
        if audio_token_len <= 0:
            raise ValueError(
                f"Computed non-positive audio_token_len={audio_token_len} (mel_len={mel_len})."
            )

        # Split the rendered prompt at the marker and splice in the audio tokens.
        left, sep, right = prompt.partition(self._marker)
        if sep == "":
            raise ValueError("Internal error: audio marker missing from rendered prompt.")

        left_ids = self.tokenizer(left, add_special_tokens=False)["input_ids"]
        right_ids = self.tokenizer(right, add_special_tokens=False)["input_ids"]

        input_ids = left_ids + [self.audio_token_id] * audio_token_len + right_ids
        attention_mask = [1] * len(input_ids)
        audio_token_start_idx = len(left_ids)

        out = {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            # Audio side:
            "input_features": feats[0].clone(),
            "audio_lens": torch.tensor([mel_len], dtype=torch.long),
            # Splicing helpers:
            "audio_token_start_idx": torch.tensor([audio_token_start_idx], dtype=torch.long),  # [1]
            "audio_token_len": torch.tensor([audio_token_len], dtype=torch.long),              # [1]
            "audio_batch_size": torch.tensor([1], dtype=torch.long),                           # [1]
            "audio_is_continuation": torch.tensor([False]),                                    # [1]
        }
        # `input_ids` are assembled manually above so the audio span stays intact;
        # `tokenize`, `padding`, `truncation` and `max_length` are accepted for API
        # compatibility, and padding/truncation are left to the downstream collator.
        return out


VLFMProcessor.register_for_auto_class()
AutoProcessor.register(VLFMConfig, VLFMProcessor)
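

# A minimal usage sketch, not part of the processor itself. It assumes a
# VLFMConfig that accepts text_model_id, tokenizer_padding_side, ds_rate,
# stack_factor and max_seconds as constructor kwargs; the checkpoint names
# ("openai/whisper-small", "meta-llama/Llama-3.2-1B-Instruct") are placeholders
# to be swapped for the project's real model IDs.
if __name__ == "__main__":
    config = VLFMConfig(
        text_model_id="meta-llama/Llama-3.2-1B-Instruct",  # placeholder text backbone
        tokenizer_padding_side="left",
        ds_rate=2,
        stack_factor=8,
        max_seconds=30.0,
    )
    feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
    tokenizer = AutoTokenizer.from_pretrained(config.text_model_id)
    processor = VLFMProcessor(feature_extractor, tokenizer, config)

    # One second of silence stands in for a real 16 kHz mono waveform.
    waveform = np.zeros(16_000, dtype=np.float32)
    batch = processor.apply_chat_template(
        [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Transcribe this clip: "},
                    {"type": "audio", "array": waveform},
                ],
            }
        ],
        add_generation_prompt=True,
    )
    print({k: tuple(v.shape) for k, v in batch.items()})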