# midashenglm-7b / modeling_midashenglm.py
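# PyTorch modeling code for MiDashengLM: a Dasheng audio transformer encoder, an
# MLP-based subsampling audio projector, and a text-only Qwen2.5-Omni thinker
# decoder, combined in MiDashengLMModel.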
import collections
import collections.abc
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union, cast
import torch
import torch.nn as nn
import torchaudio.transforms as audio_transforms
from torch import Tensor
from transformers import GenerationMixin, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
Qwen2_5OmniTextConfig,
)
from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
Qwen2_5OmniThinkerTextModel,
)
from transformers.utils import can_return_tuple
from .configuration_midashenglm import DashengConfig, MiDashengLMConfig
_Tuple2 = Union[int, Tuple[int, int], Sequence[int]]
def _resolve_tuple2(x: _Tuple2) -> Tuple[int, int]:
if isinstance(x, collections.abc.Sequence):
assert len(x) == 2, (
f"Expected a sequence of length 2, got {x} with length {len(x)}"
)
return cast(Tuple[int, int], tuple(x))
return (x, x)
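# Patch embedding for spectrogram inputs: a strided Conv2d turns a
# (batch, channels, mel_bins, time) tensor into a grid of patch embeddings;
# with flatten=True the grid is flattened into a token sequence.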
class AudioPatchEmbed(nn.Module):
def __init__(
self,
input_size: _Tuple2 = 64,
patch_size: _Tuple2 = 16,
patch_stride: _Tuple2 = 16,
in_chans: int = 1,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = False,
):
super().__init__()
self.input_size = _resolve_tuple2(input_size)
self.patch_size = _resolve_tuple2(patch_size)
self.patch_stride = _resolve_tuple2(patch_stride)
self.grid_size = (
self.input_size[0] // self.patch_stride[0],
self.input_size[1] // self.patch_stride[1],
)
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=self.patch_size,
stride=self.patch_stride,
)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
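    # Shape sketch (illustrative numbers, not necessarily the configured defaults):
    # with input_size=(64, 1008) and patch_size=patch_stride=(16, 4), grid_size is
    # (4, 252) and forward maps (B, 1, 64, 1008) -> (B, embed_dim, 4, 252),
    # or (B, 4 * 252, embed_dim) when flatten=True.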
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
if self.flatten:
x = torch.permute(
torch.flatten(x, 2, 3), (0, 2, 1)
) # rearrange(x, "b c f t -> b (f t) c")
x = self.norm(x)
return x
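# LayerScale: per-channel learnable scaling of a residual branch, with gamma
# initialised to a small value.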
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class DashengMlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
drop: float = 0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
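# Standard multi-head self-attention. Supports either a fixed causal mask
# (causal=True) or an explicit key-padding mask passed to forward(), where True
# marks positions that must not be attended to.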
class DashengAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
causal: bool = False,
):
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.causal = causal
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
        # Two kinds of masking are applied below:
        #   * When self.causal is True, an upper-triangular mask prevents each position
        #     from attending to later positions.
        #   * An explicit `mask` marks padded key positions with True; those scores are
        #     filled with the lowest representable value so they vanish after the softmax.
        #     A row whose keys are all masked softmaxes to NaN, which nan_to_num zeroes out.
if self.causal:
mask_value = -torch.finfo(attn.dtype).max
i, j = attn.shape[-2:]
mask = torch.ones(i, j, device=q.device, dtype=torch.bool).triu(j - i + 1)
attn = attn.masked_fill(mask, mask_value)
        if mask is not None:
            # Fill masked scores with the lowest representable value in the attention dtype.
            mask_value = torch.finfo(attn.dtype).min
            # `mask` has shape [B, SRC_LEN]; broadcast it to [B, 1, TGT_LEN, SRC_LEN].
            attn_mask = mask[:, None, None, :].expand(B, 1, N, N)
            attn = attn.masked_fill(attn_mask, mask_value)
        attn = attn.softmax(dim=-1)
        # Rows whose keys are all masked softmax to NaN; replace those entries with zeros.
        attn = torch.nan_to_num(attn)
        attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
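# Pre-norm transformer block: LayerNorm -> attention -> LayerScale and
# LayerNorm -> MLP -> LayerScale, each wrapped in a residual connection.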
class DashengBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values: Optional[float] = None,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim, eps=1e-6)
self.attn = DashengAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.ls1 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
self.norm2 = nn.LayerNorm(dim, eps=1e-6)
self.mlp = DashengMlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
drop=drop,
)
self.ls2 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
    # The optional `mask` argument is forwarded to the attention layer (True marks padded keys).
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x = x + self.ls1(self.attn(self.norm1(x), mask))
x = x + self.ls2(self.mlp(self.norm2(x)))
return x
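# Dasheng audio encoder: a mel-spectrogram front end with dB scaling, BatchNorm
# over the mel bins, conv patch embedding with separate time and frequency
# positional embeddings, and a stack of DashengBlocks. forward() processes long
# inputs in chunks of `target_length` frames and returns frame-level features
# together with the validity mask derived from `x_length`.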
class DashengAudioTransformer(PreTrainedModel):
config_class = DashengConfig
supports_gradient_checkpointing = True
def __init__(self, config: DashengConfig):
super().__init__(config)
self.target_length = config.target_length
self.embed_dim = config.embed_dim
self.hop_length = config.hop_length
self.gradient_checkpointing = False
self.front_end = nn.Sequential(
audio_transforms.MelSpectrogram(
f_min=config.f_min,
f_max=config.f_max,
center=config.center,
win_length=config.win_length,
hop_length=config.hop_length,
sample_rate=config.sample_rate,
n_fft=config.n_fft,
n_mels=config.n_mels,
),
audio_transforms.AmplitudeToDB(top_db=120),
)
self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
self.patch_embed = AudioPatchEmbed(
input_size=(config.n_mels, config.target_length),
embed_dim=config.embed_dim,
in_chans=config.input_channels,
patch_size=config.patch_size,
flatten=False,
patch_stride=config.patch_stride,
)
self.time_pos_embed = nn.Parameter(
torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02
)
self.freq_pos_embed = nn.Parameter(
torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02
)
self.pos_drop = nn.Dropout(p=config.drop_rate)
self.blocks = nn.ModuleList(
DashengBlock(
dim=config.embed_dim,
num_heads=config.num_heads,
mlp_ratio=config.mlp_ratio,
qkv_bias=config.qkv_bias,
init_values=config.init_values,
drop=config.drop_rate,
attn_drop=config.attn_drop_rate,
)
            for _ in range(config.depth)
)
self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
self.post_init()
def forward_features(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
t = x.shape[-1]
x = x + self.time_pos_embed[:, :, :, :t]
        x = x + self.freq_pos_embed[:, :, :, :]  # full slice just to support __getitem__ in posembed
x = torch.permute(
torch.flatten(x, 2, 3), (0, 2, 1)
) # rearrange(x, "b c f t -> b (f t) c")
x = self.pos_drop(x)
for block in self.blocks:
if self.gradient_checkpointing and self.training:
x = self._gradient_checkpointing_func(block, x, mask)
else:
x = block(x, mask)
x = self.norm(x)
return x
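    # Builds a boolean validity mask from per-example lengths. Illustrative example:
    # lengths=[2, 4] with max_length=4 yields
    # [[True, True, False, False], [True, True, True, True]].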
def _to_mask(self, lengths: torch.Tensor, max_length: int) -> torch.Tensor:
batch_size = len(lengths)
idx = torch.arange(max_length, device=lengths.device)
idx = idx.repeat(batch_size).view(batch_size, max_length)
mask = (idx < lengths.unsqueeze(-1)).bool()
return mask
def forward(
self,
x: torch.Tensor,
x_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
x = self.front_end(x)
        # The chunking below assumes a time-axis patch stride of 4 (cf. `scaled_lengths` further down).
        target_length_in_patches = self.target_length // 4
x = x.unsqueeze(1)
x = torch.permute(x, (0, 2, 1, 3))
x = self.init_bn(x)
x = torch.permute(x, (0, 2, 1, 3))
x = self.patch_embed(x)
t = x.shape[-1]
input_splits = x.split(target_length_in_patches, dim=-1)
if x_length is not None:
            assert len(x_length) == len(x), (
                "Batch sizes of `x` and `x_length` must match"
            )
            assert x_length.ndim == 1, "`x_length` must be a 1-D tensor of shape (B,)"
scaled_lengths = (x_length / (self.hop_length * 4)).long()
mask = self._to_mask(max_length=t, lengths=scaled_lengths)
split_masks = mask.logical_not().split(target_length_in_patches, dim=-1)
else:
mask = None
split_masks = [None] * len(input_splits)
outputs = []
for split_x, split_mask in zip(input_splits, split_masks):
            split_x = self.forward_features(split_x, mask=split_mask)
outputs.append(split_x)
x = torch.cat(outputs, dim=1)
return x, mask
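# Audio-to-text projector: concatenates every `downsample_rate` consecutive
# encoder frames and maps them into the decoder hidden size with a two-layer
# GELU MLP, shortening the sequence by that factor. Illustrative shapes: with
# downsample_rate=5, a (B, 103, D) input is cropped to 100 frames and reshaped
# to (B, 20, 5 * D) before the MLP.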
class AudioProjectorSubsample(nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
downsample_rate=5,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.k = downsample_rate
self.net = nn.Sequential(
nn.Linear(in_dim * self.k, out_dim, dtype=dtype),
nn.GELU(),
nn.Linear(out_dim, out_dim, dtype=dtype),
)
def forward(self, x, mask=None):
batch_size, seq_len, dim = x.shape
num_frames_to_discard = seq_len % self.k
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
if mask is not None:
mask = mask[:, :-num_frames_to_discard]
if mask is None:
mask = torch.ones(x.shape[:-1], dtype=torch.long, device=x.device)
x = x.reshape(
batch_size, -1, self.k * dim
) # rearrange(x, "b (s k) d -> b s (k d)", k=self.k)
x = self.net(x)
mask = mask.reshape(
batch_size, -1, self.k
) # rearrange(mask, "b (s k) -> b s k", k=self.k)
mask = mask.any(dim=-1).long()
return x, mask
@dataclass
class Qwen25OmniTextModelOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[Cache] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
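# Text-only decoder: the Qwen2.5-Omni thinker text model topped with a causal LM
# head. When `position_ids` are not supplied, forward() derives them from the
# attention mask.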
class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
config_class = Qwen2_5OmniTextConfig
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def __init__(self, config: Qwen2_5OmniTextConfig):
super().__init__(config)
self.model = Qwen2_5OmniThinkerTextModel._from_config(config)
self.lm_head = nn.Linear(
config.hidden_size,
config.vocab_size,
bias=False,
)
self.post_init()
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
labels: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[Tuple, Qwen25OmniTextModelOutput]:
if attention_mask is not None and position_ids is None:
position_ids = (
attention_mask.long()
.cumsum(dim=-1)
.masked_fill_(attention_mask == 0, 1)
- 1
)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
return_dict=True,
)
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states)
loss = (
self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if labels is not None
else None
)
return Qwen25OmniTextModelOutput(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
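# Full MiDashengLM model: the Dasheng audio encoder feeds the subsampling
# projector, and the projected audio embeddings are spliced into the decoder's
# token embeddings at the positions marked with `audio_token_id`
# (see `_prepare_inputs_embeds`).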
class MiDashengLMModel(PreTrainedModel):
config_class = MiDashengLMConfig
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
supports_gradient_checkpointing = True
def __init__(self, config: MiDashengLMConfig):
super().__init__(config)
self.audio_token_id = config.audio_token_id
self.audio_encoder = DashengAudioTransformer._from_config(
config.audio_encoder_config,
)
self.audio_projector = AudioProjectorSubsample(
self.audio_encoder.embed_dim,
config.text_config.hidden_size,
config.subsample_factor,
)
self.decoder = Qwen25OmniThinkerTextOnlyDecoder._from_config(
config.text_config,
attn_implementation=config._attn_implementation,
)
self.post_init()
def get_input_embeddings(self):
return self.decoder.model.embed_tokens
def get_output_embeddings(self):
return self.decoder.lm_head
def _forward_audio_encoder(
self,
audios: torch.Tensor,
audio_length: Optional[Iterable[int]],
) -> torch.Tensor:
encoder_out, encoder_atts = self.audio_encoder(audios, audio_length)
# audio projector
encoder_out, encoder_atts = self.audio_projector(encoder_out, encoder_atts)
return encoder_out
def _prepare_inputs_embeds(
self,
input_ids: Optional[torch.Tensor],
input_values: Optional[torch.Tensor],
inputs_embeds: Optional[torch.Tensor],
audio_length: Optional[Iterable[int]] = None,
) -> torch.Tensor:
if input_ids is not None:
if inputs_embeds is not None:
raise ValueError(
"Both `inputs_embeds` and `input_ids` are passed. Please pass only one of them."
)
inputs_embeds = cast(
torch.Tensor, self.decoder.model.embed_tokens(input_ids)
)
if input_values is not None:
if self.audio_token_id is None:
raise ValueError(
"Audio input is provided, but `audio_token_id` is not configured."
)
audio_embeddings = self._forward_audio_encoder(
input_values,
audio_length=audio_length,
).to(inputs_embeds.dtype)
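                # Locate contiguous runs of audio placeholder tokens: on the
                # flattened boolean mask, diff (with a prepended 0) is +1 at the
                # start of a run and -1 one past its end, e.g. mask [0, 1, 1, 0]
                # gives diff [0, 1, 0, -1].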
audio_mask = (input_ids == self.audio_token_id).flatten()
diff = torch.diff(
audio_mask.long(),
prepend=torch.zeros(
(1,),
dtype=torch.long,
device=audio_mask.device,
),
)
audio_span_starts = (diff == 1).nonzero()
audio_span_ends = (diff == -1).nonzero()
embeds_view = inputs_embeds.view(-1, inputs_embeds.shape[-1])
for span_start, span_end, audio in zip(
audio_span_starts,
audio_span_ends,
audio_embeddings,
strict=True,
):
embeds_view[span_start:span_end] = audio[: span_end - span_start]
else:
if inputs_embeds is None:
raise ValueError(
"Either `input_ids` or `inputs_embeds` must be passed."
)
if input_values is not None:
raise ValueError(
"Cannot pass `input_values` when `inputs_embeds` is provided."
)
return inputs_embeds
def forward(
self,
input_ids: Optional[Tensor] = None,
input_values: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
audio_length: Optional[Iterable[int]] = None,
attention_mask: Optional[Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
**kwargs: Any,
):
inputs_embeds = self._prepare_inputs_embeds(
input_ids=input_ids,
input_values=input_values,
inputs_embeds=inputs_embeds,
audio_length=audio_length,
)
return self.decoder(
input_ids=None,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
labels=labels,
**kwargs,
)
def generate(
self,
input_ids: Optional[Tensor] = None,
input_values: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
audio_length: Optional[Iterable[int]] = None,
**kwargs,
):
inputs_embeds = self._prepare_inputs_embeds(
input_ids=input_ids,
input_values=input_values,
inputs_embeds=inputs_embeds,
audio_length=audio_length,
)
return self.decoder.generate(
inputs_embeds=inputs_embeds,
generation_config=kwargs.pop("generation_config", self.generation_config),
**kwargs,
)
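# Minimal usage sketch (illustrative only: the checkpoint path, shapes, and prompt
# construction are placeholders; in practice the accompanying processor is expected
# to build `input_ids` containing the audio placeholder tokens):
#
#     import torch
#
#     model = MiDashengLMModel.from_pretrained("path/to/checkpoint")
#     sample_rate = model.config.audio_encoder_config.sample_rate
#     waveform = torch.randn(1, sample_rate * 10)  # 10 s of mono audio
#     input_ids = ...  # prompt ids with a run of `model.config.audio_token_id`
#                      # long enough to hold the projected audio embeddings
#     audio_length = torch.tensor([waveform.shape[-1]])
#     generated = model.generate(
#         input_ids=input_ids,
#         input_values=waveform,
#         audio_length=audio_length,
#     )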