import collections
import collections.abc
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union, cast

import torch
import torch.nn as nn
import torchaudio.transforms as audio_transforms
from torch import Tensor
from transformers import GenerationMixin, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
    Qwen2_5OmniTextConfig,
)
from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
    Qwen2_5OmniThinkerTextModel,
)
from transformers.utils import can_return_tuple

from .configuration_midashenglm import DashengConfig, MiDashengLMConfig

_Tuple2 = Union[int, Tuple[int, int], Sequence[int]]


def _resolve_tuple2(x: _Tuple2) -> Tuple[int, int]:
    if isinstance(x, collections.abc.Sequence):
        assert len(x) == 2, (
            f"Expected a sequence of length 2, got {x} with length {len(x)}"
        )
        return cast(Tuple[int, int], tuple(x))
    return (x, x)


class AudioPatchEmbed(nn.Module):
    def __init__(
        self,
        input_size: _Tuple2 = 64,
        patch_size: _Tuple2 = 16,
        patch_stride: _Tuple2 = 16,
        in_chans: int = 1,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = False,
    ):
        super().__init__()
        self.input_size = _resolve_tuple2(input_size)
        self.patch_size = _resolve_tuple2(patch_size)
        self.patch_stride = _resolve_tuple2(patch_stride)
        self.grid_size = (
            self.input_size[0] // self.patch_stride[0],
            self.input_size[1] // self.patch_stride[1],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_stride,
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        if self.flatten:
            x = torch.permute(
                torch.flatten(x, 2, 3), (0, 2, 1)
            )  # rearrange(x, "b c f t -> b (f t) c")
        x = self.norm(x)
        return x


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class DashengMlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class DashengAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        causal: bool = False,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.causal = causal

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Note on [B, T, T] masks: unlike the self.causal case below, such a mask might look like
        #     [False, False, True]
        #     [False, False, True]
        #     [True,  True,  True]
        # We pad with -inf here: padding with any finite number would give rows containing only
        # True entries weights such as [0.33, 0.33, 0.33], which is not correct.
        if self.causal:
            mask_value = -torch.finfo(attn.dtype).max
            i, j = attn.shape[-2:]
            mask = torch.ones(i, j, device=q.device, dtype=torch.bool).triu(j - i + 1)
            attn = attn.masked_fill(mask, mask_value)
        if mask is not None:
            # Mask value as the lowest possible value in fp32
            mask_value = torch.finfo(attn.dtype).min
            # Mask is of shape [B, SRC_LEN]; the attention mask should be of shape
            # [B, 1, Target_len, Source_len]
            attn_mask = mask[:, None, None, :].expand(B, 1, N, N)
            attn = attn.masked_fill(attn_mask, mask_value)
        attn = attn.softmax(dim=-1)
        # Only needed when a mask with all-True entries on a row is passed,
        # in which case the softmax of that row is NaN.
        attn = torch.nan_to_num(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class DashengBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = DashengAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = DashengMlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            drop=drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )

    # The optional mask is passed through to DashengAttention.
    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        x = x + self.ls1(self.attn(self.norm1(x), mask))
        x = x + self.ls2(self.mlp(self.norm2(x)))
        return x


class DashengAudioTransformer(PreTrainedModel):
    config_class = DashengConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: DashengConfig):
        super().__init__(config)
        self.target_length = config.target_length
        self.embed_dim = config.embed_dim
        self.hop_length = config.hop_length
        self.gradient_checkpointing = False

        self.front_end = nn.Sequential(
            audio_transforms.MelSpectrogram(
                f_min=config.f_min,
                f_max=config.f_max,
                center=config.center,
                win_length=config.win_length,
                hop_length=config.hop_length,
                sample_rate=config.sample_rate,
                n_fft=config.n_fft,
                n_mels=config.n_mels,
            ),
            audio_transforms.AmplitudeToDB(top_db=120),
        )

        self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)

        self.patch_embed = AudioPatchEmbed(
            input_size=(config.n_mels, config.target_length),
            embed_dim=config.embed_dim,
            in_chans=config.input_channels,
            patch_size=config.patch_size,
            flatten=False,
            patch_stride=config.patch_stride,
        )

        self.time_pos_embed = nn.Parameter(
            torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02
        )
        self.freq_pos_embed = nn.Parameter(
            torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02
        )
        self.pos_drop = nn.Dropout(p=config.drop_rate)
        self.blocks = nn.ModuleList(
            DashengBlock(
                dim=config.embed_dim,
                num_heads=config.num_heads,
                mlp_ratio=config.mlp_ratio,
                qkv_bias=config.qkv_bias,
                init_values=config.init_values,
                drop=config.drop_rate,
                attn_drop=config.attn_drop_rate,
            )
            for _ in range(config.depth)
        )
        self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)

        self.post_init()

    def forward_features(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        t = x.shape[-1]
        x = x + self.time_pos_embed[:, :, :, :t]
        x = (
            x + self.freq_pos_embed[:, :, :, :]
        )  # Just to support __getitem__ in posembed
        x = torch.permute(
            torch.flatten(x, 2, 3), (0, 2, 1)
        )  # rearrange(x, "b c f t -> b (f t) c")
        x = self.pos_drop(x)
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                x = self._gradient_checkpointing_func(block, x, mask)
            else:
                x = block(x, mask)
        x = self.norm(x)
        return x

    def _to_mask(self, lengths: torch.Tensor, max_length: int) -> torch.Tensor:
        batch_size = len(lengths)
        idx = torch.arange(max_length, device=lengths.device)
        idx = idx.repeat(batch_size).view(batch_size, max_length)
        mask = (idx < lengths.unsqueeze(-1)).bool()
        return mask

    def forward(
        self,
        x: torch.Tensor,
        x_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        x = self.front_end(x)
        target_length_in_patches = self.target_length // 4
        x = x.unsqueeze(1)
        x = torch.permute(x, (0, 2, 1, 3))
        x = self.init_bn(x)
        x = torch.permute(x, (0, 2, 1, 3))
        x = self.patch_embed(x)
        t = x.shape[-1]
        # Split the patch sequence along time into windows of at most
        # target_length_in_patches patches and encode each window separately.
        input_splits = x.split(target_length_in_patches, dim=-1)

        if x_length is not None:
            assert len(x_length) == len(x), (
                "x and x_length must have the same batch size"
            )
            assert x_length.ndim == 1, "x_length must be a 1-D tensor of shape (B,)"
            scaled_lengths = (x_length / (self.hop_length * 4)).long()
            mask = self._to_mask(max_length=t, lengths=scaled_lengths)
            split_masks = mask.logical_not().split(target_length_in_patches, dim=-1)
        else:
            mask = None
            split_masks = [None] * len(input_splits)

        outputs = []
        for split_x, split_mask in zip(input_splits, split_masks):
            forward_kwargs = {}
            forward_kwargs["mask"] = split_mask
            split_x = self.forward_features(split_x, **forward_kwargs)
            outputs.append(split_x)
        x = torch.cat(outputs, dim=1)
        return x, mask


class AudioProjectorSubsample(nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        downsample_rate=5,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.k = downsample_rate
        self.net = nn.Sequential(
            nn.Linear(in_dim * self.k, out_dim, dtype=dtype),
            nn.GELU(),
            nn.Linear(out_dim, out_dim, dtype=dtype),
        )

    def forward(self, x, mask=None):
        batch_size, seq_len, dim = x.shape
        # Drop trailing frames so the sequence length is divisible by the downsample rate k.
        num_frames_to_discard = seq_len % self.k
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
            if mask is not None:
                mask = mask[:, :-num_frames_to_discard]
        if mask is None:
            mask = torch.ones(x.shape[:-1], dtype=torch.long, device=x.device)

        x = x.reshape(
            batch_size, -1, self.k * dim
        )  # rearrange(x, "b (s k) d -> b s (k d)", k=self.k)
        x = self.net(x)
        mask = mask.reshape(
            batch_size, -1, self.k
        )  # rearrange(mask, "b (s k) -> b s k", k=self.k)
        mask = mask.any(dim=-1).long()
        return x, mask


@dataclass
class Qwen25OmniTextModelOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
    config_class = Qwen2_5OmniTextConfig
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True

    def __init__(self, config: Qwen2_5OmniTextConfig):
        super().__init__(config)
        self.model = Qwen2_5OmniThinkerTextModel._from_config(config)
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, Qwen25OmniTextModelOutput]:
        if attention_mask is not None and position_ids is None:
            position_ids = (
                attention_mask.long()
                .cumsum(dim=-1)
                .masked_fill_(attention_mask == 0, 1)
                - 1
            )

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=True,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss = (
            self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )
            if labels is not None
            else None
        )

        return Qwen25OmniTextModelOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class MiDashengLMModel(PreTrainedModel):
    config_class = MiDashengLMConfig
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True
    supports_gradient_checkpointing = True

    def __init__(self, config: MiDashengLMConfig):
        super().__init__(config)
        self.audio_token_id = config.audio_token_id

        self.audio_encoder = DashengAudioTransformer._from_config(
            config.audio_encoder_config,
        )
        self.audio_projector = AudioProjectorSubsample(
            self.audio_encoder.embed_dim,
            config.text_config.hidden_size,
            config.subsample_factor,
        )
        self.decoder = Qwen25OmniThinkerTextOnlyDecoder._from_config(
            config.text_config,
            attn_implementation=config._attn_implementation,
        )

        self.post_init()

    def get_input_embeddings(self):
        return self.decoder.model.embed_tokens

    def get_output_embeddings(self):
        return self.decoder.lm_head

    def _forward_audio_encoder(
        self,
        audios: torch.Tensor,
        audio_length: Optional[Iterable[int]],
    ) -> torch.Tensor:
        encoder_out, encoder_atts = self.audio_encoder(audios, audio_length)
        # audio projector
        encoder_out, encoder_atts = self.audio_projector(encoder_out, encoder_atts)
        return encoder_out

    def _prepare_inputs_embeds(
        self,
        input_ids: Optional[torch.Tensor],
        input_values: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor],
        audio_length: Optional[Iterable[int]] = None,
    ) -> torch.Tensor:
        if input_ids is not None:
            if inputs_embeds is not None:
                raise ValueError(
                    "Both `inputs_embeds` and `input_ids` are passed. Please pass only one of them."
                )
            inputs_embeds = cast(
                torch.Tensor, self.decoder.model.embed_tokens(input_ids)
            )

            if input_values is not None:
                if self.audio_token_id is None:
                    raise ValueError(
                        "Audio input is provided, but `audio_token_id` is not configured."
                    )
                audio_embeddings = self._forward_audio_encoder(
                    input_values,
                    audio_length=audio_length,
                ).to(inputs_embeds.dtype)
                # Find contiguous runs of audio placeholder tokens in `input_ids` and
                # overwrite their embeddings with the projected audio features.
                audio_mask = (input_ids == self.audio_token_id).flatten()
                diff = torch.diff(
                    audio_mask.long(),
                    prepend=torch.zeros(
                        (1,),
                        dtype=torch.long,
                        device=audio_mask.device,
                    ),
                )
                audio_span_starts = (diff == 1).nonzero()
                audio_span_ends = (diff == -1).nonzero()
                embeds_view = inputs_embeds.view(-1, inputs_embeds.shape[-1])
                for span_start, span_end, audio in zip(
                    audio_span_starts,
                    audio_span_ends,
                    audio_embeddings,
                    strict=True,
                ):
                    embeds_view[span_start:span_end] = audio[: span_end - span_start]
        else:
            if inputs_embeds is None:
                raise ValueError(
                    "Either `input_ids` or `inputs_embeds` must be passed."
                )
            if input_values is not None:
                raise ValueError(
                    "Cannot pass `input_values` when `inputs_embeds` is provided."
                )
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        input_values: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        audio_length: Optional[Iterable[int]] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ):
        inputs_embeds = self._prepare_inputs_embeds(
            input_ids=input_ids,
            input_values=input_values,
            inputs_embeds=inputs_embeds,
            audio_length=audio_length,
        )
        return self.decoder(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            labels=labels,
            **kwargs,
        )

    def generate(
        self,
        input_ids: Optional[Tensor] = None,
        input_values: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        audio_length: Optional[Iterable[int]] = None,
        **kwargs,
    ):
        inputs_embeds = self._prepare_inputs_embeds(
            input_ids=input_ids,
            input_values=input_values,
            inputs_embeds=inputs_embeds,
            audio_length=audio_length,
        )
        return self.decoder.generate(
            inputs_embeds=inputs_embeds,
            generation_config=kwargs.pop("generation_config", self.generation_config),
            **kwargs,
        )
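
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch for the audio tower defined above. It is not part
# of the modeling API, and it rests on assumptions not stated in this file:
# that `DashengConfig()` is default-constructible with usable mel/patch
# settings, that 16 kHz is the expected sample rate, that the frequency axis
# is reduced to a single patch row (which the masking path in
# DashengAttention requires), and that a projector width of 1024 is an
# arbitrary illustrative choice. Run it as a module, e.g.
# `python -m <package>.modeling_midashenglm`, so the relative import resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = DashengConfig()  # assumption: the config ships usable defaults
    encoder = DashengAudioTransformer(config)
    projector = AudioProjectorSubsample(encoder.embed_dim, out_dim=1024)

    # Two clips of random 16 kHz audio; the second is treated as 5 s long,
    # with the remaining samples considered padding via the length argument.
    waveform = torch.randn(2, 16_000 * 10)
    lengths = torch.tensor([16_000 * 10, 16_000 * 5])

    with torch.no_grad():
        # features: (B, T_patches, embed_dim); mask: (B, T_patches), True = valid
        features, mask = encoder(waveform, lengths)
        # projected: (B, T_patches // 5, 1024); proj_mask: (B, T_patches // 5)
        projected, proj_mask = projector(features, mask)

    print(features.shape, projected.shape, proj_mask.sum(dim=-1))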