diff --git a/build/torch25-cxx11-cu118-x86_64-linux/attention/__init__.py b/build/torch25-cxx11-cu118-x86_64-linux/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch25-cxx11-cu118-x86_64-linux/attention/_attention_6yvgebnqctora.abi3.so b/build/torch25-cxx11-cu118-x86_64-linux/attention/_attention_6yvgebnqctora.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..29733cfb726d11a1d278fb0f3679c010cf5210e2 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/attention/_attention_6yvgebnqctora.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aee255dc2618e23d4e2076ff3d16c4fbd12d63742fde84252cfb6bfe55c5376e +size 78886392 diff --git a/build/torch25-cxx11-cu118-x86_64-linux/attention/_custom_ops.py b/build/torch25-cxx11-cu118-x86_64-linux/attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + 
blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch25-cxx11-cu118-x86_64-linux/attention/_ops.py b/build/torch25-cxx11-cu118-x86_64-linux/attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..1379d7cc10c5fafa877e3ea73be33d3eed57b449 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _attention_6yvgebnqctora +ops = torch.ops._attention_6yvgebnqctora + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_attention_6yvgebnqctora::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx11-cu118-x86_64-linux/attention/platforms.py b/build/torch25-cxx11-cu118-x86_64-linux/attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06132e74cd7fb634044a76e528979b02a3559b --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/attention/platforms.py @@ -0,0 +1,62 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... 
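
The `_custom_ops.py` wrappers added above are thin pass-throughs to the compiled `ops` namespace, so driving them is mostly a matter of building the paged KV-cache tensors the kernels expect. The sketch below is a minimal, hypothetical driver for `reshape_and_cache` and `paged_attention_v1`; the import path `attention`, the cache shapes, and the key-cache packing factor `x` follow vLLM-style conventions and are assumptions, not something this diff itself documents.

```python
import torch

from attention import paged_attention_v1, reshape_and_cache

num_seqs, num_heads, num_kv_heads, head_size = 2, 8, 8, 128
block_size, num_blocks, max_seq_len = 16, 64, 256
dtype, device = torch.float16, "cuda"

# Assumed vLLM-style cache layout: keys are packed along the last dim in
# groups of x = 16 // element_size; values are [blocks, heads, head_size, block].
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache = torch.zeros(
    num_blocks, num_kv_heads, head_size // x, block_size, x, dtype=dtype, device=device
)
value_cache = torch.zeros(
    num_blocks, num_kv_heads, head_size, block_size, dtype=dtype, device=device
)

# Write one new token per sequence into its cache slot
# (slot = physical_block * block_size + offset).
key = torch.randn(num_seqs, num_kv_heads, head_size, dtype=dtype, device=device)
value = torch.randn(num_seqs, num_kv_heads, head_size, dtype=dtype, device=device)
slot_mapping = torch.tensor([0, block_size], dtype=torch.long, device=device)
reshape_and_cache(
    key, value, key_cache, value_cache, slot_mapping,
    kv_cache_dtype="auto", k_scale=1.0, v_scale=1.0,
)

# Decode step: attend over the cached blocks listed in each row of block_tables.
query = torch.randn(num_seqs, num_heads, head_size, dtype=dtype, device=device)
out = torch.empty_like(query)
block_tables = torch.zeros(
    num_seqs, max_seq_len // block_size, dtype=torch.int32, device=device
)
block_tables[1, 0] = 1  # sequence 1's first logical block lives in physical block 1
seq_lens = torch.tensor([1, 1], dtype=torch.int32, device=device)
paged_attention_v1(
    out, query, key_cache, value_cache, num_kv_heads,
    scale=head_size**-0.5, block_tables=block_tables, seq_lens=seq_lens,
    block_size=block_size, max_seq_len=max_seq_len, alibi_slopes=None,
    kv_cache_dtype="auto", k_scale=1.0, v_scale=1.0,
)
```

`paged_attention_v2` takes the same tail of arguments but additionally receives `exp_sum`, `max_logits`, and `tmp_out` scratch tensors, which the kernel uses to reduce over partitioned KV blocks for longer sequences.
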
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch25-cxx11-cu121-x86_64-linux/attention/__init__.py b/build/torch25-cxx11-cu121-x86_64-linux/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch25-cxx11-cu121-x86_64-linux/attention/_attention_4jg2igd54wzge.abi3.so b/build/torch25-cxx11-cu121-x86_64-linux/attention/_attention_4jg2igd54wzge.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..a58d380aa758b8e6842e89013229bee3711286ef --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/attention/_attention_4jg2igd54wzge.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:22599ebe9d209fcc82068054caf39f93e6828bb3889b344e655fee50e7a98864 +size 75398808 diff --git a/build/torch25-cxx11-cu121-x86_64-linux/attention/_custom_ops.py b/build/torch25-cxx11-cu121-x86_64-linux/attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: 
Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch25-cxx11-cu121-x86_64-linux/attention/_ops.py b/build/torch25-cxx11-cu121-x86_64-linux/attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9dee16955e9d988953733fae4e743d92886c92b1 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _attention_4jg2igd54wzge +ops = torch.ops._attention_4jg2igd54wzge + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_attention_4jg2igd54wzge::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx11-cu121-x86_64-linux/attention/platforms.py b/build/torch25-cxx11-cu121-x86_64-linux/attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06132e74cd7fb634044a76e528979b02a3559b --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/attention/platforms.py @@ -0,0 +1,62 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. 
+ + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch25-cxx11-cu124-x86_64-linux/attention/__init__.py b/build/torch25-cxx11-cu124-x86_64-linux/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch25-cxx11-cu124-x86_64-linux/attention/_attention_syg6kbhkhc4xk.abi3.so b/build/torch25-cxx11-cu124-x86_64-linux/attention/_attention_syg6kbhkhc4xk.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..369150e0964eaca52c0c7906addf9f18d8ec7270 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/attention/_attention_syg6kbhkhc4xk.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42a3b2b450b7e284694e8e6d7398627b977d1e5da12bb79d93c6009c192922f9 +size 75568320 diff --git a/build/torch25-cxx11-cu124-x86_64-linux/attention/_custom_ops.py b/build/torch25-cxx11-cu124-x86_64-linux/attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + 
blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch25-cxx11-cu124-x86_64-linux/attention/_ops.py b/build/torch25-cxx11-cu124-x86_64-linux/attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0bac0403831e313bcf9cbab1a35c2cbe4d5ef08f --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _attention_syg6kbhkhc4xk +ops = torch.ops._attention_syg6kbhkhc4xk + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_attention_syg6kbhkhc4xk::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx11-cu124-x86_64-linux/attention/platforms.py b/build/torch25-cxx11-cu124-x86_64-linux/attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06132e74cd7fb634044a76e528979b02a3559b --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/attention/platforms.py @@ -0,0 +1,62 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch25-cxx98-cu118-x86_64-linux/attention/__init__.py b/build/torch25-cxx98-cu118-x86_64-linux/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch25-cxx98-cu118-x86_64-linux/attention/_attention_hhzgzhvc7zviy.abi3.so b/build/torch25-cxx98-cu118-x86_64-linux/attention/_attention_hhzgzhvc7zviy.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..05529e8bcee239db92984acb3e19926697c64a3f --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/attention/_attention_hhzgzhvc7zviy.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffad04fc3e82be818bafed25c1be1e9e6145f99eb0ef89ab87ef5ab8c8366f9b +size 78850608 diff --git a/build/torch25-cxx98-cu118-x86_64-linux/attention/_custom_ops.py b/build/torch25-cxx98-cu118-x86_64-linux/attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def 
paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch25-cxx98-cu118-x86_64-linux/attention/_ops.py b/build/torch25-cxx98-cu118-x86_64-linux/attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..270fd3d0005a3e44dc6625c3ab4948a7fa7892bb --- /dev/null +++ 
b/build/torch25-cxx98-cu118-x86_64-linux/attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _attention_hhzgzhvc7zviy +ops = torch.ops._attention_hhzgzhvc7zviy + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_attention_hhzgzhvc7zviy::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx98-cu118-x86_64-linux/attention/platforms.py b/build/torch25-cxx98-cu118-x86_64-linux/attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06132e74cd7fb634044a76e528979b02a3559b --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/attention/platforms.py @@ -0,0 +1,62 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch25-cxx98-cu121-x86_64-linux/attention/__init__.py b/build/torch25-cxx98-cu121-x86_64-linux/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch25-cxx98-cu121-x86_64-linux/attention/_attention_gbi5gm244waic.abi3.so b/build/torch25-cxx98-cu121-x86_64-linux/attention/_attention_gbi5gm244waic.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..cb6cccabe445cbf7bfd797b4645300e5a2a4ec38 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/attention/_attention_gbi5gm244waic.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ed1c9c4c080a10f7d7f8c18e8e96613020851f769a1bf5e2b92bf19b4e01fb6 +size 75359216 diff --git a/build/torch25-cxx98-cu121-x86_64-linux/attention/_custom_ops.py b/build/torch25-cxx98-cu121-x86_64-linux/attention/_custom_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git 
a/build/torch25-cxx98-cu121-x86_64-linux/attention/_ops.py b/build/torch25-cxx98-cu121-x86_64-linux/attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a517876400c08f9800107c61d6ca3f57e0bdc2e6 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _attention_gbi5gm244waic +ops = torch.ops._attention_gbi5gm244waic + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_attention_gbi5gm244waic::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx98-cu121-x86_64-linux/attention/platforms.py b/build/torch25-cxx98-cu121-x86_64-linux/attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06132e74cd7fb634044a76e528979b02a3559b --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/attention/platforms.py @@ -0,0 +1,62 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch25-cxx98-cu124-x86_64-linux/attention/__init__.py b/build/torch25-cxx98-cu124-x86_64-linux/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch25-cxx98-cu124-x86_64-linux/attention/_attention_ill75rmpj7yds.abi3.so b/build/torch25-cxx98-cu124-x86_64-linux/attention/_attention_ill75rmpj7yds.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..bf93abf5555357ad397844421fcfc66ae0743166 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/attention/_attention_ill75rmpj7yds.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:f263e022ef503e7fffcbc15ef59e515b84889d4c473b9113f3fea292725b9e37 +size 75532912 diff --git a/build/torch25-cxx98-cu124-x86_64-linux/attention/_custom_ops.py b/build/torch25-cxx98-cu124-x86_64-linux/attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: 
torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch25-cxx98-cu124-x86_64-linux/attention/_ops.py b/build/torch25-cxx98-cu124-x86_64-linux/attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..f49b90de8bda122b2049bf57f5012b60e05364fe --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _attention_ill75rmpj7yds +ops = torch.ops._attention_ill75rmpj7yds + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_attention_ill75rmpj7yds::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx98-cu124-x86_64-linux/attention/platforms.py b/build/torch25-cxx98-cu124-x86_64-linux/attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06132e74cd7fb634044a76e528979b02a3559b --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/attention/platforms.py @@ -0,0 +1,62 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... 
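
Each `_ops.py` binds `ops` to the hash-suffixed Torch op namespace of its shared object, and `add_op_namespace_prefix` turns a bare op name into the qualified `namespace::op` string so callers never hard-code that hash. Below is a small, hypothetical sketch of how this helper and `platforms.py` might be used together; registering a shape-only fake implementation through `torch.library.register_fake` is one plausible application and is an assumption here, not something this diff itself sets up. The package name `attention` is likewise assumed.

```python
import torch

from attention._ops import add_op_namespace_prefix
from attention.platforms import current_platform

# Seed Python's, NumPy's and torch's RNGs in one call, then branch on platform.
current_platform.seed_everything(42)
backend = "ROCm" if current_platform.is_rocm() else "CUDA"
print(f"running on {backend}: {current_platform.get_device_name()}")

# Hypothetical: give paged_attention_v1 a do-nothing fake so tracing with fake
# tensors (e.g. under torch.compile) does not need to launch the real kernel.
# The op mutates `out` in place and returns nothing, so the fake returns None.
@torch.library.register_fake(add_op_namespace_prefix("paged_attention_v1"))
def _paged_attention_v1_fake(
    out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
    seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
    k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
    blocksparse_block_size, blocksparse_head_sliding_step,
):
    return None
```

Because the namespace suffix differs per build variant (`_attention_ill75rmpj7yds` here, other hashes elsewhere in this diff), going through `add_op_namespace_prefix` keeps such registrations portable across the prebuilt wheels.
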
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/attention/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/attention/_attention_6qe5ft3kiteru.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/attention/_attention_6qe5ft3kiteru.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..0bbd1dc682174c9d7fba2ee7426e1183e668ab79 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/attention/_attention_6qe5ft3kiteru.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e66eca8e825e5cee2dc18c1235319a4e5b1372d843cab74660e8d94792e02f7c +size 78857896 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/attention/_custom_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: 
Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/attention/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b2a39308433746718b31f0d9830b27f72f5242 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _attention_6qe5ft3kiteru +ops = torch.ops._attention_6qe5ft3kiteru + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_attention_6qe5ft3kiteru::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/attention/platforms.py b/build/torch26-cxx11-cu118-x86_64-linux/attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06132e74cd7fb634044a76e528979b02a3559b --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/attention/platforms.py @@ -0,0 +1,62 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. 
+ + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/attention/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/attention/_attention_ftq3cjdxqfw4m.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/attention/_attention_ftq3cjdxqfw4m.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d7fa42c3682924a46e9c5b4a7e847a6b4415c5c8 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/attention/_attention_ftq3cjdxqfw4m.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:913ba8f5166dc4e84ed8a2da4b1dc44c178a93eeb16aae9782176fb089a459a7 +size 75552112 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/attention/_custom_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + 
blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/attention/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..27b44593d2252bfe5399c8dcd883aa497223f158 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _attention_ftq3cjdxqfw4m +ops = torch.ops._attention_ftq3cjdxqfw4m + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_attention_ftq3cjdxqfw4m::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/attention/platforms.py b/build/torch26-cxx11-cu124-x86_64-linux/attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06132e74cd7fb634044a76e528979b02a3559b --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/attention/platforms.py @@ -0,0 +1,62 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/attention/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/attention/_attention_lkibbjh726iwm.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/attention/_attention_lkibbjh726iwm.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..4a4cccfd49090ac213bbf562a9c4bb2ff2920eb0 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/attention/_attention_lkibbjh726iwm.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91380eebc7db2ff85f92e687d388055f210123bac602a6bc273172834bf49012 +size 75376640 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/attention/_custom_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def 
paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/attention/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ac89377661ed1c5f2eca40cf199a15209af0c05c --- /dev/null +++ 
b/build/torch26-cxx11-cu126-x86_64-linux/attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _attention_lkibbjh726iwm +ops = torch.ops._attention_lkibbjh726iwm + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_attention_lkibbjh726iwm::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/attention/platforms.py b/build/torch26-cxx11-cu126-x86_64-linux/attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06132e74cd7fb634044a76e528979b02a3559b --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/attention/platforms.py @@ -0,0 +1,62 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/attention/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/attention/_attention_vbhagz24hyij6.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/attention/_attention_vbhagz24hyij6.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..4d87629674e87a746aaec4ccadb26bb2a72f2d43 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/attention/_attention_vbhagz24hyij6.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3746697abeeb7f829661c0912ccb36a7f7bb16c1f9eb7f14b1ee5e52c93ec055 +size 78830632 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/attention/_custom_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/attention/_custom_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git 
a/build/torch26-cxx98-cu118-x86_64-linux/attention/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..2f05f1ffd05c49971dfc9b45971efb5a055c7e52 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _attention_vbhagz24hyij6 +ops = torch.ops._attention_vbhagz24hyij6 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_attention_vbhagz24hyij6::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/attention/platforms.py b/build/torch26-cxx98-cu118-x86_64-linux/attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06132e74cd7fb634044a76e528979b02a3559b --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/attention/platforms.py @@ -0,0 +1,62 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/attention/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/attention/_attention_sfjvhlixssyce.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/attention/_attention_sfjvhlixssyce.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..ee6153972f28bd997e1fc4a7eaaf425fd5adc918 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/attention/_attention_sfjvhlixssyce.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:68ecca9bc82b5fb7bf290f0c91ff86b65d25f7c5534f607b98bec8557922cf84 +size 75521080 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/attention/_custom_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: 
torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/attention/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..530d483cdf8243f6c863ca49c0e87018634e69d0 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _attention_sfjvhlixssyce +ops = torch.ops._attention_sfjvhlixssyce + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_attention_sfjvhlixssyce::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/attention/platforms.py b/build/torch26-cxx98-cu124-x86_64-linux/attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06132e74cd7fb634044a76e528979b02a3559b --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/attention/platforms.py @@ -0,0 +1,62 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/attention/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/attention/_attention_g7oqtcveiuapk.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/attention/_attention_g7oqtcveiuapk.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..fe58b4ce4158bf5ee55371329396ac8e573cfc85 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/attention/_attention_g7oqtcveiuapk.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:adb35fabc23d8caa55f061d32ee48688c32e3efa0b4bf9aaed58cc59620e422c +size 75341504 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/attention/_custom_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: 
Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/attention/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..1e504e67dd25c4aa79bcc509316f3f23e6e3e6ef --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _attention_g7oqtcveiuapk +ops = torch.ops._attention_g7oqtcveiuapk + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_attention_g7oqtcveiuapk::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/attention/platforms.py b/build/torch26-cxx98-cu126-x86_64-linux/attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06132e74cd7fb634044a76e528979b02a3559b --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/attention/platforms.py @@ -0,0 +1,62 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. 
+ + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
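
Note: the per-variant packages in this diff all expose the same Python surface (`paged_attention_v1`/`paged_attention_v2`, `reshape_and_cache`, `reshape_and_cache_flash`, `copy_blocks`, `swap_blocks`, `convert_fp8`, and the `platforms` helpers); only the hashed native namespace (`_attention_<hash>`) differs per torch/CUDA/ABI combination, which keeps the registered torch ops from colliding when several variants are installed side by side. The sketch below is an illustrative, non-authoritative usage example: the tensor shapes, the vLLM-style paged KV-cache layout, and the `kv_cache_dtype="auto"` value are assumptions chosen for demonstration and are not specified by the files in this diff.

```python
# Minimal usage sketch (assumptions: vLLM-style paged cache layout, fp16, one CUDA device).
import torch

from attention import paged_attention_v1
from attention.platforms import current_platform

current_platform.seed_everything(0)  # seeds the random, numpy and torch RNGs

num_seqs, num_heads, num_kv_heads, head_size = 2, 8, 8, 128
block_size, num_blocks = 16, 32
max_seq_len = 64
dtype, device = torch.float16, "cuda"

query = torch.randn(num_seqs, num_heads, head_size, dtype=dtype, device=device)
out = torch.empty_like(query)

# Assumed paged-cache layout (as used by vLLM's paged attention kernels):
#   key_cache:   [num_blocks, num_kv_heads, head_size // x, block_size, x], x = 16 // itemsize
#   value_cache: [num_blocks, num_kv_heads, head_size, block_size]
x = 16 // query.element_size()
key_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x,
                        dtype=dtype, device=device)
value_cache = torch.randn(num_blocks, num_kv_heads, head_size, block_size,
                          dtype=dtype, device=device)

# One block table row per sequence; here each sequence simply owns consecutive blocks.
block_tables = torch.arange(num_seqs * (max_seq_len // block_size), dtype=torch.int32,
                            device=device).reshape(num_seqs, -1)
seq_lens = torch.full((num_seqs,), max_seq_len, dtype=torch.int32, device=device)

paged_attention_v1(
    out, query, key_cache, value_cache,
    num_kv_heads=num_kv_heads,
    scale=head_size ** -0.5,
    block_tables=block_tables,
    seq_lens=seq_lens,
    block_size=block_size,
    max_seq_len=max_seq_len,
    alibi_slopes=None,
    kv_cache_dtype="auto",   # assumed: "auto" keeps the cache in the model dtype
    k_scale=1.0,
    v_scale=1.0,
)
```

Because `paged_attention_v1` and the other wrappers forward directly to the variant-specific `torch.ops._attention_<hash>` namespace, the same calling code works unchanged against whichever build variant (torch/CUDA/ABI combination) happens to be installed.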