File size: 639 Bytes
ed30f9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#include "../torch-ext/torch_binding.h"
#import <Metal/Metal.h>
#include <torch/torch.h>

int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
  TORCH_CHECK(false, "get_device_attribute is not supported on Metal");
}

int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) {
  // On macOS you can have multiple GPUs; fetch the N-th one.
  NSArray<id<MTLDevice>> *all = MTLCopyAllDevices();
  TORCH_CHECK(device_id >= 0 && device_id < (int64_t)all.count,
              "Invalid Metal device index");

  id<MTLDevice> dev = all[device_id];
  return static_cast<int64_t>(dev.maxThreadgroupMemoryLength);
}