import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from typing import Any, Optional
import numpy as np

# Optional back ends: each dependency is detected at import time and the
# corresponding code path is only taken when the import succeeds.
try:
    from flash_attn import flash_attn_func
    use_flash_attention = True
except ImportError:
    use_flash_attention = False

try:
    import jax
    import jax.numpy as jnp
    from jax import jit
    use_jax = True
except ImportError:
    use_jax = False

try:
    import tensorflow as tf
    use_tensorflow = True
except ImportError:
    use_tensorflow = False

try:
    import nemo.collections.nlp as nemo_nlp
    use_nemo = True
except ImportError:
    use_nemo = False

try:
    import onnx
    import onnxruntime as ort
    use_onnx = True
except ImportError:
    use_onnx = False


class MultiFrameworkAttention(nn.Module):
    """Multi-head attention with interchangeable back ends (PyTorch,
    TensorFlow, JAX, NeMo) selected via the ``framework`` argument."""

    def __init__(self, dim: int, num_heads: int, dropout: float = 0.1,
                 causal: bool = False, bias: bool = False,
                 framework: str = "pytorch"):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.framework = framework.lower()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.dropout = dropout
        self.causal = causal
        self.scale = self.head_dim ** -0.5

        if self.framework == "pytorch":
            self.qkv_proj = nn.Linear(dim, dim * 3, bias=bias)
            self.out_proj = nn.Linear(dim, dim, bias=bias)
            self.dropout_layer = nn.Dropout(dropout)
        elif self.framework == "tensorflow" and use_tensorflow:
            self.qkv_proj = tf.keras.layers.Dense(dim * 3, use_bias=bias)
            self.out_proj = tf.keras.layers.Dense(dim, use_bias=bias)
            self.dropout_layer = tf.keras.layers.Dropout(dropout)
        elif self.framework == "jax" and use_jax:
            # JAX parameters are created as placeholders in _jax_forward.
            pass
        elif self.framework == "nemo" and use_nemo:
            # The NeMo encoder is instantiated on the fly in _nemo_forward.
            pass

    def forward(self, x, mask=None) -> Any:
        # Dispatch to the back end selected at construction time.
        if self.framework == "pytorch":
            return self._pytorch_forward(x, mask)
        elif self.framework == "tensorflow" and use_tensorflow:
            return self._tensorflow_forward(x, mask)
        elif self.framework == "jax" and use_jax:
            return self._jax_forward(x, mask)
        elif self.framework == "nemo" and use_nemo:
            return self._nemo_forward(x, mask)
        else:
            raise ValueError(f"Unsupported framework: {self.framework}")

    def _pytorch_forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        B, S, D = x.size()
        qkv = self.qkv_proj(x).reshape(B, S, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # each is (B, S, num_heads, head_dim)

        # flash_attn_func expects (batch, seqlen, nheads, headdim) tensors and
        # only supports fp16/bf16 inputs on CUDA, so fall back otherwise.
        if (use_flash_attention and mask is None and x.is_cuda
                and x.dtype in (torch.float16, torch.bfloat16)):
            out = flash_attn_func(
                q, k, v,
                dropout_p=self.dropout if self.training else 0.0,
                causal=self.causal
            )
            out = rearrange(out, "b s h d -> b s (h d)")
        else:
            q, k, v = [rearrange(t, "b s h d -> b h s d") for t in (q, k, v)]
            scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
            if self.causal:
                causal_mask = torch.ones(S, S, dtype=torch.bool, device=x.device).tril()
                scores = scores.masked_fill(~causal_mask, float('-inf'))
            if mask is not None:
                scores = scores.masked_fill(mask == 0, float('-inf'))
            attn = F.softmax(scores, dim=-1)
            attn = self.dropout_layer(attn)
            out = torch.matmul(attn, v)
            out = rearrange(out, "b h s d -> b s (h d)")
        return self.out_proj(out)
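
    # Hedged alternative (not part of the original API): on PyTorch >= 2.0 the
    # fallback math above can be replaced by F.scaled_dot_product_attention,
    # which fuses the softmax(QK^T)V computation. A minimal sketch, assuming
    # the same qkv_proj/out_proj layout as _pytorch_forward; the method name
    # _sdpa_forward is illustrative.
    def _sdpa_forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, S, D = x.size()
        qkv = self.qkv_proj(x).reshape(B, S, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q, k, v = [rearrange(t, "b s h d -> b h s d") for t in (q, k, v)]
        # attn_mask must broadcast to (B, H, S, S); attn_mask and is_causal are
        # mutually exclusive, so is_causal is only set when no mask is given.
        attn_mask = mask.bool() if mask is not None else None
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=self.causal and mask is None,
        )
        out = rearrange(out, "b h s d -> b s (h d)")
        return self.out_proj(out)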

    def _tensorflow_forward(self, x, mask=None):
        B, S, D = x.shape
        qkv = self.qkv_proj(x)
        qkv = tf.reshape(qkv, [B, S, 3, self.num_heads, self.head_dim])
        q, k, v = tf.unstack(qkv, axis=2)
        q = tf.transpose(q, [0, 2, 1, 3])  # (B, H, S, head_dim)
        k = tf.transpose(k, [0, 2, 3, 1])  # (B, H, head_dim, S), pre-transposed for matmul
        v = tf.transpose(v, [0, 2, 1, 3])

        scores = tf.matmul(q, k) * self.scale
        if self.causal:
            causal_mask = tf.linalg.band_part(tf.ones((S, S), dtype=scores.dtype), -1, 0)
            scores = tf.where(causal_mask == 0, float('-inf'), scores)
        if mask is not None:
            scores = tf.where(mask == 0, float('-inf'), scores)
        attn = tf.nn.softmax(scores, axis=-1)
        attn = self.dropout_layer(attn, training=self.training)
        out = tf.matmul(attn, v)
        out = tf.transpose(out, [0, 2, 1, 3])
        out = tf.reshape(out, [B, S, D])
        return self.out_proj(out)

    def _jax_forward(self, x, mask=None):
        B, S, D = x.shape

        def attention_fn(params, x):
            qkv = jnp.dot(x, params['qkv_proj'])
            qkv = qkv.reshape(B, S, 3, self.num_heads, self.head_dim)
            # jnp.split keeps the split axis, so squeeze it away afterwards.
            q, k, v = [jnp.squeeze(t, axis=2) for t in jnp.split(qkv, 3, axis=2)]
            q = jnp.transpose(q, (0, 2, 1, 3))  # (B, H, S, head_dim)
            k = jnp.transpose(k, (0, 2, 3, 1))  # (B, H, head_dim, S)
            v = jnp.transpose(v, (0, 2, 1, 3))

            scores = jnp.matmul(q, k) * self.scale
            if self.causal:
                causal_mask = jnp.tril(jnp.ones((S, S), dtype=bool))
                scores = jnp.where(causal_mask, scores, float('-inf'))
            if mask is not None:
                scores = jnp.where(mask == 0, float('-inf'), scores)
            attn = jax.nn.softmax(scores, axis=-1)
            out = jnp.matmul(attn, v)
            out = jnp.transpose(out, (0, 2, 1, 3))
            out = out.reshape(B, S, D)
            return jnp.dot(out, params['out_proj'])

        # Placeholder, untrained parameters: with all-zero weights the output
        # is zero as well. Real use would initialize or load these weights.
        params = {
            'qkv_proj': jnp.zeros((D, D * 3)),
            'out_proj': jnp.zeros((D, D))
        }
        return jit(attention_fn)(params, x)

    def _nemo_forward(self, x, mask=None):
        # Illustrative only: builds a fresh, untrained NeMo encoder per call
        # and falls back to the PyTorch path when NeMo is not installed.
        if use_nemo:
            nemo_model = nemo_nlp.models.TransformerEncoderModel(
                hidden_size=self.dim,
                num_attention_heads=self.num_heads
            )
            return nemo_model(input_ids=x, attention_mask=mask)
        return self._pytorch_forward(x, mask)

    def to_onnx(self, dummy_input, output_path: str):
        if not use_onnx:
            raise ImportError("ONNX support not available")

        if self.framework == "pytorch":
            torch.onnx.export(
                self,
                dummy_input,
                output_path,
                opset_version=13,
                input_names=['input'],
                output_names=['output']
            )
        elif self.framework == "tensorflow" and use_tensorflow:
            # TensorFlow export (e.g. via tf2onnx) is not implemented here.
            raise NotImplementedError("TensorFlow -> ONNX export is not implemented")
        elif self.framework == "jax" and use_jax:
            # JAX export (e.g. via jax2tf + tf2onnx) is not implemented here.
            raise NotImplementedError("JAX -> ONNX export is not implemented")
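

# Hedged usage sketch (not part of the original module): after to_onnx() has
# produced a file, the export can be sanity-checked with onnxruntime by
# comparing against the PyTorch output. The helper name verify_onnx_export is
# illustrative; it assumes the single 'input'/'output' names used above.
def verify_onnx_export(model: MultiFrameworkAttention, dummy_input: torch.Tensor,
                       onnx_path: str, atol: float = 1e-4) -> bool:
    if not use_onnx:
        raise ImportError("ONNX support not available")
    session = ort.InferenceSession(onnx_path)
    # Run the exported graph on the same input used for the export.
    onnx_out = session.run(None, {"input": dummy_input.detach().cpu().numpy()})[0]
    with torch.no_grad():
        torch_out = model(dummy_input).detach().cpu().numpy()
    return np.allclose(onnx_out, torch_out, atol=atol)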


class TransformerBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, ff_dim: int,
                 num_experts: int = 16, top_k: int = 2, dropout: float = 0.1,
                 framework: str = "pytorch"):
        super().__init__()
        self.framework = framework
        self.attn_norm = nn.LayerNorm(dim) if framework == "pytorch" else None
        self.attn = MultiFrameworkAttention(dim, num_heads, dropout, framework=framework)
        self.ff_norm = nn.LayerNorm(dim) if framework == "pytorch" else None
        # num_experts / top_k suggest a mixture-of-experts feed-forward; the
        # routing itself is not implemented here (see the MoE sketch below),
        # so a plain dense feed-forward is used instead.
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, dim),
        ) if framework == "pytorch" else None
        self.dropout = nn.Dropout(dropout) if framework == "pytorch" else None

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        if self.framework == "pytorch":
            # Pre-norm residual attention followed by a residual feed-forward.
            x = x + self.dropout(self.attn(self.attn_norm(x), mask))
            x = x + self.dropout(self.ff(self.ff_norm(x)))
        return x
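

# Hedged sketch (not part of the original module): the num_experts / top_k
# arguments of TransformerBlock hint at a mixture-of-experts feed-forward.
# A minimal top-k routed MoE layer could look like the following; the class
# name MoEFeedForward and its internals are illustrative assumptions.
class MoEFeedForward(nn.Module):
    def __init__(self, dim: int, ff_dim: int, num_experts: int = 16, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, ff_dim), nn.GELU(), nn.Linear(ff_dim, dim))
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, D = x.shape
        flat = x.reshape(-1, D)                               # (B*S, D)
        # Route every token to its top-k experts and renormalize the gates.
        gates = F.softmax(self.router(flat), dim=-1)          # (B*S, num_experts)
        top_gates, top_idx = gates.topk(self.top_k, dim=-1)   # (B*S, top_k)
        top_gates = top_gates / top_gates.sum(dim=-1, keepdim=True)
        out = torch.zeros_like(flat)
        for expert_id, expert in enumerate(self.experts):
            # Tokens (and their gate slots) routed to this expert.
            token_pos, slot = (top_idx == expert_id).nonzero(as_tuple=True)
            if token_pos.numel() == 0:
                continue
            out[token_pos] += top_gates[token_pos, slot].unsqueeze(-1) * expert(flat[token_pos])
        return out.reshape(B, S, D)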


if __name__ == "__main__":
    # Minimal smoke test of the PyTorch back end.
    model = MultiFrameworkAttention(dim=512, num_heads=8, framework="pytorch")
    x = torch.randn(2, 64, 512)
    output = model(x)
    print("pytorch output:", tuple(output.shape))

    # Export the PyTorch model to ONNX when onnxruntime/onnx are available.
    if use_onnx:
        model.to_onnx(x, "attention.onnx")

    # Repeat the smoke test with the TensorFlow back end when available.
    if use_tensorflow:
        tf_model = MultiFrameworkAttention(dim=512, num_heads=8, framework="tensorflow")
        x_tf = tf.random.normal((2, 64, 512))
        output_tf = tf_model(x_tf)
        print("tensorflow output:", tuple(output_tf.shape))