# Ethically sourced from https://github.com/xjdr-alt/entropix
import torch


def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    use_scaled: bool = False,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Precompute the rotary-embedding cosine/sine table.

    Args:
        dim: Rotary dimension. One frequency is produced per channel pair,
            i.e. ``dim // 2`` frequencies.
        end: Number of positions in the table (positions ``0 .. end - 1``).
        theta: Base of the geometric frequency progression.
        use_scaled: Accepted for signature compatibility; currently unused.
        dtype: Floating dtype used for the computation and the result.

    Returns:
        Tensor of shape ``(end, dim // 2, 2)`` where ``[..., 0]`` holds
        ``cos(pos * freq)`` and ``[..., 1]`` holds ``sin(pos * freq)``.
    """
    # inv_freq[i] = theta ** (-2i / dim); the slice drops the extra entry
    # that arange yields when dim is odd.
    exponents = torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim
    inv_freq = 1.0 / (theta**exponents)
    positions = torch.arange(end, dtype=dtype).unsqueeze(1)
    angles = positions * inv_freq.unsqueeze(0)  # (end, dim // 2)
    # cos/sin directly rather than exp(1j * angles): same values without a
    # round-trip through complex dtypes (complex32 is poorly supported).
    return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)


# rope.py


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
    position_ids: torch.Tensor,
    num_heads: int,
    rot_dim: int = 32,
    interleave: bool = False,
) -> torch.Tensor:
    """Apply rotary position embedding to the leading channels of ``x``.

    RoPE as used in the original moondream2 text stack.

    Args:
        x: ``(B, H, T, D)`` tensor; only the first ``min(rot_dim, D)``
            channels are rotated, the rest pass through unchanged.
        freqs_cis: ``(max_T, rot_dim // 2, 2)`` table with ``[..., 0] = cos``
            and ``[..., 1] = sin`` (as built by ``precompute_freqs_cis``).
        position_ids: 0-d, ``(T,)``, ``(1, T)`` or ``(B, T)`` integer positions
            indexing the first axis of ``freqs_cis``.
        num_heads: Expected size of ``x``'s head axis.
        rot_dim: Number of leading channels to rotate.
        interleave: If True the rotated channels are read as interleaved
            ``(re, im)`` pairs; otherwise as ``[re..., im...]`` halves.

    Returns:
        Tensor with the same shape and dtype as ``x``. The rotated part is
        always written interleaved ``(re0, im0, re1, im1, ...)``, even when
        ``interleave=False`` read it as halves (matching the reference stack).

    Raises:
        AssertionError: if ``rot_dim`` disagrees with ``freqs_cis`` or
            ``num_heads`` disagrees with ``x.shape[1]``.
        ValueError: if ``x`` is not 4-D.
    """
    assert rot_dim == freqs_cis.shape[-2] * 2
    assert num_heads == x.shape[1]
    if x.dim() != 4:
        raise ValueError(f"expected x of shape (B, H, T, D), got {tuple(x.shape)}")

    rd = min(rot_dim, x.shape[-1])
    half = rd // 2
    x_rot, x_pass = x[..., :rd], x[..., rd:]

    # Rotate in at least float32 whatever the input dtype, and cast back once
    # at the end: keeps bf16/fp16 precise and makes both layouts return x.dtype.
    compute_dtype = torch.promote_types(x.dtype, torch.float32)
    x_rot = x_rot.to(compute_dtype)

    # Split real/imaginary parts according to the input layout.
    if interleave:
        pairs = x_rot.reshape(*x_rot.shape[:-1], -1, 2)
        xr, xi = pairs[..., 0], pairs[..., 1]
    else:
        xr, xi = x_rot[..., :half], x_rot[..., half:]

    # Gather cos/sin for the requested positions.
    freq = freqs_cis[position_ids][..., :half, :].to(compute_dtype)
    cos, sin = freq[..., 0], freq[..., 1]
    if position_ids.dim() == 2:
        # (B or 1, T, half) -> (B or 1, 1, T, half): broadcast across heads
        # (and across the batch when only one row of positions is given).
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    # 0-d / 1-d positions give (half,) / (T, half), which broadcast against
    # (B, H, T, half) from the right without any explicit expand.

    # Complex multiply: (xr + i*xi) * (cos + i*sin).
    yr = xr * cos - xi * sin
    yi = xr * sin + xi * cos

    # NOTE: the output is re-interleaved regardless of the input layout, as in
    # the reference stack. If q and k both pass through this function, they
    # share the same channel permutation, so q.k scores are unaffected.
    y = torch.stack((yr, yi), dim=-1).flatten(-2).to(x.dtype)  # (B, H, T, rd)
    return torch.cat([y, x_pass], dim=-1)