#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSStream.h>
#include <torch/torch.h>
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#include <algorithm>
#include <dlfcn.h>
#include <mach-o/dyld.h>
#include <string>
#include <vector>
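// Reinterpret an MPS tensor's storage pointer as the underlying Metal buffer.
// On the MPS backend, the storage data pointer is the id<MTLBuffer> itself.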
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}
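// Resolve the directory containing this shared library via dladdr, so the
// compiled .metallib can be located relative to it at runtime.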
static std::string getModuleDirectory() {
  Dl_info dl_info;
  if (dladdr((void *)getModuleDirectory, &dl_info)) {
    std::string path(dl_info.dli_fname);
    size_t pos = path.find_last_of('/');
    if (pos != std::string::npos) {
      return path.substr(0, pos);
    }
  }
  return ".";
}
// Map a (source, destination) dtype pair to the name of the Metal conversion
// kernel, e.g. "convert_fp8_float_to_uchar". uchar carries the raw FP8 bytes.
static std::string getConvertKernelName(torch::ScalarType src_dtype, torch::ScalarType dst_dtype) {
  auto dtype_to_string = [](torch::ScalarType dtype) -> std::string {
    switch (dtype) {
      case torch::kFloat: return "float";
      case torch::kHalf: return "half";
      case torch::kBFloat16: return "bfloat16_t";
      case torch::kUInt8: return "uchar";
      default:
        TORCH_CHECK(false, "Unsupported dtype for convert_fp8: ", dtype);
    }
  };
  return "convert_fp8_" + dtype_to_string(src_dtype) + "_to_" + dtype_to_string(dst_dtype);
}
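// Convert between an FP8 KV cache (stored as uint8) and a higher-precision
// dtype on the MPS device, passing `scale` through to the kernel. Note that
// kv_cache_dtype is currently unused here: the kernel is selected from the
// two tensors' scalar types.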
void convert_fp8(torch::Tensor &dst_cache, torch::Tensor &src_cache,
                 const double scale, const std::string &kv_cache_dtype) {
  // Validate input tensors
  TORCH_CHECK(src_cache.device().is_mps() && dst_cache.device().is_mps(),
              "Both tensors must be on MPS device");
  TORCH_CHECK(src_cache.device() == dst_cache.device(),
              "Source and destination tensors must be on the same device");
  TORCH_CHECK(src_cache.numel() == dst_cache.numel(),
              "Source and destination tensors must have the same number of elements");
  TORCH_CHECK(src_cache.is_contiguous() && dst_cache.is_contiguous(),
              "Both tensors must be contiguous");

  const uint32_t num_elements = static_cast<uint32_t>(src_cache.numel());
  if (num_elements == 0) {
    return; // Nothing to convert
  }

  // Determine conversion kernel name
  std::string kernel_name = getConvertKernelName(src_cache.scalar_type(), dst_cache.scalar_type());
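  // Bound the lifetime of autoreleased Objective-C objects created while
  // encoding to this call.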
  @autoreleasepool {
    at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
    TORCH_CHECK(stream, "Failed to get current MPS stream");
    id<MTLDevice> device = stream->device();
    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
    TORCH_CHECK(cmdBuf, "Failed to get command buffer");

    // Load the Metal library; METALLIB_PATH is expected to be defined at
    // build time as the .metallib location relative to this library.
    std::string moduleDir = getModuleDirectory();
    std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
    NSString *metallibPathStr = [NSString stringWithUTF8String:metallibPath.c_str()];
    NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
    NSError *error = nil;
    id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
    TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath, ": ",
                error ? error.localizedDescription.UTF8String : "unknown error");

    // Create kernel function and compute pipeline state
    NSString *kernelNameStr = [NSString stringWithUTF8String:kernel_name.c_str()];
    id<MTLFunction> fn = [lib newFunctionWithName:kernelNameStr];
    TORCH_CHECK(fn, "Failed to find Metal kernel function: ", kernel_name);
    id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:fn error:&error];
    TORCH_CHECK(pso, "Failed to create compute pipeline state: ",
                error ? error.localizedDescription.UTF8String : "unknown error");

    dispatch_queue_t q = stream->queue();
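    // Encode on the stream's serial dispatch queue so this kernel is ordered
    // with PyTorch's own MPS work on the same command buffer.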
    dispatch_sync(q, ^{
      id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
      TORCH_CHECK(enc, "Failed to create compute encoder");
      [enc setComputePipelineState:pso];

      // Set buffers
      [enc setBuffer:getMTLBufferStorage(src_cache)
              offset:src_cache.storage_offset() * src_cache.element_size()
             atIndex:0];
      [enc setBuffer:getMTLBufferStorage(dst_cache)
              offset:dst_cache.storage_offset() * dst_cache.element_size()
             atIndex:1];

      // Set scale parameter
      float scale_f32 = static_cast<float>(scale);
      id<MTLBuffer> scaleBuf = [device newBufferWithBytes:&scale_f32
                                                   length:sizeof(float)
                                                  options:MTLResourceStorageModeShared];
      [enc setBuffer:scaleBuf offset:0 atIndex:2];

      // Set num_elements parameter
      id<MTLBuffer> numElementsBuf = [device newBufferWithBytes:&num_elements
                                                         length:sizeof(uint32_t)
                                                        options:MTLResourceStorageModeShared];
      [enc setBuffer:numElementsBuf offset:0 atIndex:3];
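      // Note: -[MTLComputeCommandEncoder setBytes:length:atIndex:] could pass
      // these small scalar arguments without allocating transient buffers.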
      // Dispatch enough threadgroups to cover num_elements, capping the
      // threadgroup size at what the pipeline actually supports (a fixed
      // 1024 can exceed maxTotalThreadsPerThreadgroup for some kernels).
      const uint32_t max_threads =
          static_cast<uint32_t>(pso.maxTotalThreadsPerThreadgroup);
      const uint32_t threads_per_threadgroup = std::min(max_threads, num_elements);
      const uint32_t threadgroups = (num_elements + threads_per_threadgroup - 1) / threads_per_threadgroup;
      MTLSize threadsPerThreadgroup = MTLSizeMake(threads_per_threadgroup, 1, 1);
      MTLSize threadgroupsPerGrid = MTLSizeMake(threadgroups, 1, 1);
      [enc dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];
      [enc endEncoding];
    });

    // Commit the encoded work; COMMIT does not block the CPU.
    stream->synchronize(at::mps::SyncType::COMMIT);
  }
}
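
// A minimal sketch of how this op might be exposed from a torch extension,
// assuming a standard pybind11 entry point (the module name and registration
// are illustrative, not part of this file):
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("convert_fp8", &convert_fp8, "Convert FP8 KV cache on MPS");
//   }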