// gpt-oss-metal-kernels/test/bf16-f32-embeddings.cc
// Mohamed Mekkouri
// commit evtn
// 95d28ad
#include <gtest/gtest.h>
#include <cstddef>
#include "embeddings-kernel-tester.hpp"
using gptoss::EmbeddingsKernelTester;
// Threadgroup size handed to the tester in every test case in this file;
// the channel counts used by the tests are expressed relative to it.
constexpr std::size_t kThreadgroupSize = 64;
// Channel count equal to the threadgroup size (the "single tile" case, per
// the test name). The token count is left at the tester's default.
TEST(BF16_F32_EMBEDDINGS, single_token_single_tile) {
    constexpr std::size_t channel_count = kThreadgroupSize;

    EmbeddingsKernelTester()
        .num_channels(channel_count)
        .threadgroup_size(kThreadgroupSize)
        .TestBF16_F32();
}
// Channel count of four threadgroup-sized spans plus a 16-channel remainder,
// i.e. deliberately not a multiple of the threadgroup size. The token count
// is left at the tester's default.
TEST(BF16_F32_EMBEDDINGS, single_token_multi_tile) {
    constexpr std::size_t full_spans = 4;
    constexpr std::size_t remainder_channels = 16;
    constexpr std::size_t channel_count =
        kThreadgroupSize * full_spans + remainder_channels;

    EmbeddingsKernelTester()
        .num_channels(channel_count)
        .threadgroup_size(kThreadgroupSize)
        .TestBF16_F32();
}
// Same channel layout as the multi-tile case (four threadgroup-sized spans
// plus a 16-channel remainder), but with an explicit token count of three.
TEST(BF16_F32_EMBEDDINGS, multiple_tokens) {
    constexpr std::size_t full_spans = 4;
    constexpr std::size_t remainder_channels = 16;
    constexpr std::size_t channel_count =
        kThreadgroupSize * full_spans + remainder_channels;
    constexpr int token_count = 3;

    EmbeddingsKernelTester()
        .num_channels(channel_count)
        .num_tokens(token_count)
        .threadgroup_size(kThreadgroupSize)
        .TestBF16_F32();
}