# CharmAGX_G1/core/data_architecture/attention_mechanism.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from typing import Optional
# Framework-specific imports
try:
from flash_attn import flash_attn_func
use_flash_attention = True
except ImportError:
use_flash_attention = False
try:
import jax
import jax.numpy as jnp
from jax import jit
use_jax = True
except ImportError:
use_jax = False
try:
import tensorflow as tf
use_tensorflow = True
except ImportError:
use_tensorflow = False
try:
import nemo.collections.nlp as nemo_nlp
use_nemo = True
except ImportError:
use_nemo = False
try:
import onnx
import onnxruntime as ort
use_onnx = True
except ImportError:
use_onnx = False
class MultiFrameworkAttention(nn.Module):
def __init__(self, dim: int, num_heads: int, dropout: float = 0.1,
causal: bool = False, bias: bool = False,
framework: str = "pytorch"):
super().__init__()
self.framework = framework.lower()
self.num_heads = num_heads
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = dim // num_heads
self.dropout = dropout
self.causal = causal
self.scale = self.head_dim ** -0.5
# Framework-specific initialization
if self.framework == "pytorch":
self.qkv_proj = nn.Linear(dim, dim * 3, bias=bias)
self.out_proj = nn.Linear(dim, dim, bias=bias)
self.dropout_layer = nn.Dropout(dropout)
elif self.framework == "tensorflow" and use_tensorflow:
self.qkv_proj = tf.keras.layers.Dense(dim * 3, use_bias=bias)
self.out_proj = tf.keras.layers.Dense(dim, use_bias=bias)
self.dropout_layer = tf.keras.layers.Dropout(dropout)
elif self.framework == "jax" and use_jax:
            # JAX uses a functional paradigm; weights are initialized later
pass
elif self.framework == "nemo" and use_nemo:
# NeMo-specific initialization
pass
    def forward(self, x, mask=None):
if self.framework == "pytorch":
return self._pytorch_forward(x, mask)
elif self.framework == "tensorflow" and use_tensorflow:
return self._tensorflow_forward(x, mask)
elif self.framework == "jax" and use_jax:
return self._jax_forward(x, mask)
elif self.framework == "nemo" and use_nemo:
return self._nemo_forward(x, mask)
else:
raise ValueError(f"Unsupported framework: {self.framework}")
    def _pytorch_forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        B, S, D = x.size()
        qkv = self.qkv_proj(x).reshape(B, S, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # each (B, S, num_heads, head_dim)
        if use_flash_attention and mask is None and x.is_cuda and x.dtype in (torch.float16, torch.bfloat16):
            # flash_attn_func expects (B, S, H, head_dim) tensors and only runs on CUDA in fp16/bf16
            out = flash_attn_func(
                q, k, v,
                dropout_p=self.dropout if self.training else 0.0,
                causal=self.causal
            )
        else:
            q, k, v = [rearrange(t, "b s h d -> b h s d") for t in (q, k, v)]
            scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
            if self.causal:
                causal_mask = torch.triu(
                    torch.ones(S, S, dtype=torch.bool, device=x.device), diagonal=1
                )
                scores = scores.masked_fill(causal_mask, float('-inf'))
            if mask is not None:
                scores = scores.masked_fill(mask == 0, float('-inf'))
            attn = F.softmax(scores, dim=-1)
            attn = self.dropout_layer(attn)
            out = rearrange(torch.matmul(attn, v), "b h s d -> b s h d")
        out = rearrange(out, "b s h d -> b s (h d)")
        return self.out_proj(out)
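    # Alternative sketch (not in the original file): PyTorch 2.x exposes a fused
    # kernel through F.scaled_dot_product_attention, which covers the same cases
    # as the flash-attn path above without the extra dependency. `_sdpa_forward`
    # is a hypothetical name; it assumes PyTorch >= 2.0.
    def _sdpa_forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, S, D = x.size()
        qkv = self.qkv_proj(x).reshape(B, S, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q, k, v = [rearrange(t, "b s h d -> b h s d") for t in (q, k, v)]
        attn_mask = None if mask is None else (mask != 0)  # bool mask: True = attend
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=self.causal and mask is None,
        )
        return self.out_proj(rearrange(out, "b h s d -> b s (h d)"))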
    def _tensorflow_forward(self, x, mask=None):
        B, S, D = x.shape
        qkv = self.qkv_proj(x)
        qkv = tf.reshape(qkv, [-1, S, 3, self.num_heads, self.head_dim])
        q, k, v = tf.unstack(qkv, axis=2)
        q = tf.transpose(q, [0, 2, 1, 3])  # (B, H, S, head_dim)
        k = tf.transpose(k, [0, 2, 3, 1])  # (B, H, head_dim, S)
        v = tf.transpose(v, [0, 2, 1, 3])
        scores = tf.matmul(q, k) * self.scale
        if mask is not None:
            scores = tf.where(mask == 0, float('-inf'), scores)
        attn = tf.nn.softmax(scores, axis=-1)
        attn = self.dropout_layer(attn, training=self.training)
        out = tf.matmul(attn, v)
        out = tf.transpose(out, [0, 2, 1, 3])
        out = tf.reshape(out, [-1, S, D])
        return self.out_proj(out)
    def _jax_forward(self, x, mask=None):
        B, S, D = x.shape
        def attention_fn(params, x):
            qkv = jnp.dot(x, params['qkv_proj'])
            qkv = qkv.reshape(B, S, 3, self.num_heads, self.head_dim)
            # Index out q/k/v; jnp.split would keep a singleton axis here
            q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
            q = jnp.transpose(q, (0, 2, 1, 3))  # (B, H, S, head_dim)
            k = jnp.transpose(k, (0, 2, 3, 1))  # (B, H, head_dim, S)
            v = jnp.transpose(v, (0, 2, 1, 3))
            scores = jnp.matmul(q, k) * self.scale
            if mask is not None:
                scores = jnp.where(mask == 0, float('-inf'), scores)
            attn = jax.nn.softmax(scores, axis=-1)
            out = jnp.matmul(attn, v)
            out = jnp.transpose(out, (0, 2, 1, 3))
            out = out.reshape(B, S, D)
            return jnp.dot(out, params['out_proj'])
        # Dummy zero params as a placeholder; a hypothetical random initializer
        # is sketched in init_jax_params below
        params = {
            'qkv_proj': jnp.zeros((D, D * 3)),
            'out_proj': jnp.zeros((D, D))
        }
        return jit(attention_fn)(params, x)
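    # A minimal sketch (not in the original file) of the "proper initialization"
    # the placeholder above refers to. `init_jax_params` is a hypothetical helper
    # and the scaled-normal initialization is an assumption.
    def init_jax_params(self, dim: int, key):
        if not use_jax:
            raise ImportError("JAX support not available")
        k1, k2 = jax.random.split(key)
        scale = dim ** -0.5
        return {
            'qkv_proj': scale * jax.random.normal(k1, (dim, dim * 3)),
            'out_proj': scale * jax.random.normal(k2, (dim, dim)),
        }
    # Example: params = model.init_jax_params(512, jax.random.PRNGKey(0))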
    def _nemo_forward(self, x, mask=None):
        # NeMo integration would typically use the PyTorch backend;
        # this is a placeholder showing how it could be structured.
        if use_nemo:
            nemo_model = nemo_nlp.models.TransformerEncoderModel(
                hidden_size=x.size(-1),
                num_attention_heads=self.num_heads
            )
            return nemo_model(input_ids=x, attention_mask=mask)
        return self._pytorch_forward(x, mask)
def to_onnx(self, dummy_input, output_path: str):
if not use_onnx:
raise ImportError("ONNX support not available")
if self.framework == "pytorch":
torch.onnx.export(
self,
dummy_input,
output_path,
opset_version=13,
input_names=['input'],
output_names=['output']
)
        elif self.framework == "tensorflow" and use_tensorflow:
            # Converting TF to ONNX requires tf2onnx (see the sketch below)
            pass
        elif self.framework == "jax" and use_jax:
            # JAX to ONNX would go through jax2tf and then tf2onnx
            pass
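    # A hedged sketch (not in the original file) of the tf2onnx route mentioned
    # above: trace the TensorFlow forward pass as a tf.function with a fixed
    # input signature and convert it. `_tf_to_onnx`, `seq_len`, and the float32
    # signature are illustrative assumptions; requires the tf2onnx package.
    def _tf_to_onnx(self, seq_len: int, dim: int, output_path: str):
        if not (use_tensorflow and use_onnx):
            raise ImportError("TensorFlow/ONNX support not available")
        import tf2onnx  # assumed to be installed
        spec = (tf.TensorSpec((None, seq_len, dim), tf.float32, name="input"),)
        fn = tf.function(lambda inp: self._tensorflow_forward(inp))
        tf2onnx.convert.from_function(
            fn, input_signature=spec, opset=13, output_path=output_path
        )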
class TransformerBlock(nn.Module):
def __init__(self, dim: int, num_heads: int, ff_dim: int,
num_experts: int = 16, top_k: int = 2, dropout: float = 0.1,
framework: str = "pytorch"):
super().__init__()
self.framework = framework
self.attn_norm = nn.LayerNorm(dim) if framework == "pytorch" else None
self.attn = MultiFrameworkAttention(dim, num_heads, dropout, framework=framework)
self.ff_norm = nn.LayerNorm(dim) if framework == "pytorch" else None
        # Add MoE and other components similarly (a top-k MoE feed-forward is
        # sketched after this class)
        self.dropout = nn.Dropout(dropout) if framework == "pytorch" else None
    def forward(self, x, mask=None):
if self.framework == "pytorch":
x = x + self.dropout(self.attn(self.attn_norm(x), mask))
# Add FF/MoE logic
return x
# Add other framework implementations
return x
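# A minimal sketch (not in the original file) of the top-k mixture-of-experts
# feed-forward suggested by TransformerBlock's currently unused `num_experts`
# and `top_k` arguments. The gating scheme (softmax over the top-k router
# logits) and the expert MLP shape are assumptions, not the author's design.
class TopKMoEFeedForward(nn.Module):
    def __init__(self, dim: int, ff_dim: int, num_experts: int = 16, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, ff_dim), nn.GELU(), nn.Linear(ff_dim, dim))
            for _ in range(num_experts)
        ])
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, D = x.shape
        flat = x.reshape(-1, D)                                  # (B*S, D)
        gate_logits = self.gate(flat)                            # (B*S, num_experts)
        weights, indices = gate_logits.topk(self.top_k, dim=-1)  # route each token to k experts
        weights = F.softmax(weights, dim=-1)
        out = torch.zeros_like(flat)
        for slot in range(self.top_k):
            for e, expert in enumerate(self.experts):
                sel = indices[:, slot] == e                      # tokens whose slot-th choice is expert e
                if sel.any():
                    out[sel] += weights[sel, slot].unsqueeze(-1) * expert(flat[sel])
        return out.reshape(B, S, D)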
# Usage example
if __name__ == "__main__":
# PyTorch example
model = MultiFrameworkAttention(dim=512, num_heads=8, framework="pytorch")
x = torch.randn(2, 64, 512)
output = model(x)
# Export to ONNX
if use_onnx:
model.to_onnx(x, "attention.onnx")
# TensorFlow example
if use_tensorflow:
tf_model = MultiFrameworkAttention(dim=512, num_heads=8, framework="tensorflow")
x_tf = tf.random.normal((2, 64, 512))
output_tf = tf_model(x_tf)
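    # TransformerBlock usage (PyTorch path); a small illustrative addition,
    # ff_dim=2048 is an arbitrary choice
    block = TransformerBlock(dim=512, num_heads=8, ff_dim=2048)
    block_out = block(torch.randn(2, 64, 512))
    print(block_out.shape)  # torch.Size([2, 64, 512])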