Spaces:
Runtime error
Runtime error
| """ | |
| Projection layers for multimodal fusion | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from typing import Optional | |
class VisionProjector(nn.Module):
    """Map vision-encoder features into the language model's embedding space.

    The mapping is a small feed-forward stack:
    Linear -> GELU -> Dropout -> Linear -> LayerNorm.
    """

    def __init__(
        self,
        vision_dim: int,
        language_dim: int,
        hidden_dim: Optional[int] = None,
        dropout: float = 0.1
    ):
        """
        Args:
            vision_dim: size of the incoming vision feature vectors
            language_dim: size of the outgoing (language embedding) vectors
            hidden_dim: width of the intermediate layer; when it is omitted
                (or otherwise falsy) ``language_dim`` is used instead
            dropout: probability handed to ``nn.Dropout``
        """
        super().__init__()

        # Fall back to the output width when no intermediate width is given.
        inner_dim = hidden_dim if hidden_dim else language_dim

        stages = [
            nn.Linear(vision_dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, language_dim),
            nn.LayerNorm(language_dim),
        ]
        self.projector = nn.Sequential(*stages)

        # Replace the default parameter initialisation.
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear and LayerNorm contained in this module.

        Linear weights are drawn from a normal distribution with std 0.02 and
        their biases (when present) are zeroed; LayerNorm is reset to
        weight = 1, bias = 0.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, std=0.02)
                if layer.bias is None:
                    continue
                nn.init.zeros_(layer.bias)
                continue
            if isinstance(layer, nn.LayerNorm):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vision_features: [batch_size, vision_dim]
        Returns:
            projected_features: [batch_size, language_dim]
        """
        projected_features = self.projector(vision_features)
        return projected_features
class AudioProjector(nn.Module):
    """Map audio-encoder features into the language model's embedding space.

    The mapping is a single projection stack:
    Linear -> GELU -> Dropout -> LayerNorm.
    """

    def __init__(
        self,
        audio_dim: int,
        language_dim: int,
        dropout: float = 0.1
    ):
        """
        Args:
            audio_dim: size of the incoming audio feature vectors
            language_dim: size of the outgoing (language embedding) vectors
            dropout: probability handed to ``nn.Dropout``
        """
        super().__init__()

        to_language = nn.Linear(audio_dim, language_dim)
        activation = nn.GELU()
        regulariser = nn.Dropout(dropout)
        normaliser = nn.LayerNorm(language_dim)
        self.projector = nn.Sequential(to_language, activation, regulariser, normaliser)

        # Replace the default parameter initialisation.
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear and LayerNorm contained in this module.

        Linear weights are drawn from a normal distribution with std 0.02 and
        their biases (when present) are zeroed; LayerNorm is reset to
        weight = 1, bias = 0.
        """
        for submodule in self.modules():
            if isinstance(submodule, nn.Linear):
                nn.init.normal_(submodule.weight, std=0.02)
                has_bias = submodule.bias is not None
                if has_bias:
                    nn.init.zeros_(submodule.bias)
            elif isinstance(submodule, nn.LayerNorm):
                nn.init.ones_(submodule.weight)
                nn.init.zeros_(submodule.bias)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            audio_features: [batch_size, audio_dim]
        Returns:
            projected_features: [batch_size, language_dim]
        """
        projected_features = self.projector(audio_features)
        return projected_features