Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2025 NVIDIA CORPORATION. | |
| # Licensed under the MIT license. | |
| # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. | |
| # LICENSE is in incl_licenses directory. | |
| from functools import partial | |
| from typing import Any, Dict, List, Optional | |
| import torch | |
| from llava.model.encoders.base import BaseEncoder | |
| __all__ = ["BasicSoundEncoder"] | |
class BasicSoundEncoder(BaseEncoder):
    """Encode sounds through the parent model and wrap each result in boundary tokens.

    The encoder delegates the actual feature extraction to
    ``parent.encode_sound`` and then optionally prepends / appends the
    embeddings of ``start_tokens`` / ``end_tokens`` to every per-sound
    feature tensor.
    """

    def __init__(
        self,
        parent: torch.nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
    ) -> None:
        """Store the boundary-token strings and register the parent model.

        Args:
            parent: Model this encoder reads ``tokenizer``, ``device``,
                ``llm`` and ``encode_sound`` from.
            start_tokens: Text whose embeddings are prepended to each sound's
                features, or ``None`` to prepend nothing.
            end_tokens: Text whose embeddings are appended to each sound's
                features, or ``None`` to append nothing. Defaults to a
                single newline.
        """
        super().__init__(parent)
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens

    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
        """Tokenize ``tokens`` and look the ids up in the LLM embedding table.

        Args:
            tokens: Text to embed, or ``None``.

        Returns:
            The embeddings produced by ``parent.llm.model.embed_tokens`` for
            the tokenized text, or ``None`` when ``tokens`` is ``None``.
        """
        if tokens is None:
            return None
        token_ids = self.parent.tokenizer(tokens).input_ids
        # Build the id tensor directly on the parent's device so the
        # embedding lookup does not need a separate transfer.
        token_ids = torch.tensor(token_ids, device=self.parent.device)
        return self.parent.llm.model.embed_tokens(token_ids)

    def _process_features(
        self,
        features: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Move one sound's features to the parent device and add boundary embeddings.

        Args:
            features: Features of a single sound.
                NOTE(review): assumes dim 0 is the token/sequence axis, since
                the boundary embeddings are concatenated along it - confirm.
            start_token_embeds: Embeddings to prepend, or ``None``.
            end_token_embeds: Embeddings to append, or ``None``.

        Returns:
            ``features`` with the given embeddings concatenated along dim 0.
        """
        features = features.to(self.parent.device)
        if start_token_embeds is not None:
            features = torch.cat([start_token_embeds, features], dim=0)
        if end_token_embeds is not None:
            features = torch.cat([features, end_token_embeds], dim=0)
        return features

    def forward(
        self,
        sounds: List[torch.Tensor],
        config: Dict[str, Any],
        masks: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        """Encode a batch of sounds and return one feature tensor per sound.

        Args:
            sounds: Per-sound input tensors; they are stacked along a new
                leading dim, so they must all share the same shape.
            config: Not read by this method.
                NOTE(review): presumably kept to match the call signature of
                sibling encoders - confirm.
            masks: Per-sound mask tensors, stacked the same way as ``sounds``
                and passed alongside them to ``parent.encode_sound``.

        Returns:
            A list with one tensor per sound, each optionally wrapped in the
            start/end token embeddings. Empty when ``sounds`` is empty.
        """
        # torch.stack rejects an empty sequence; an empty batch simply has
        # no features to return.
        if len(sounds) == 0:
            return []

        sounds = torch.stack(sounds, dim=0)
        masks = torch.stack(masks, dim=0)
        features = self.parent.encode_sound(sounds, masks)

        # Embed the boundary tokens once per call rather than once per sound.
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
        )
        return [process_features(f) for f in features]