|
# CUTLASS Epilogues |
|
|
|
## Introduction |
|
This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. |
|
|
|
Currently, we only support symmetric quantization for weights, |
|
and symmetric and asymmetric quantization for activations. |
|
Both can be quantized per-tensor or per-channel (weights) / per-token (activations). |
|
|
|
There are 4 epilogues: |
|
1. ScaledEpilogue: symmetric quantization for activations, no bias. |
|
1. ScaledEpilogueBias: symmetric quantization for activations, supports bias. |
|
1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias. |
|
1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias. |
|
|
|
We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size. |
|
Instead, if no bias is passed, the epilogue will use 0 as the bias. |
|
That induces a redundant addition operation (and runtime check), but the performance impact is minor. |
|
|
|
## Underlying Linear Algebra |
|
|
|
More details available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975). |
|
|
|
If $` \widehat X `$ is the quantized $` X `$, our matrices become the following |
|
|
|
```math |
|
A = s_a (\widehat A - J_a z_a) |
|
``` |
|
```math |
|
B = s_b \widehat B |
|
``` |
|
```math |
|
D = A B + C |
|
``` |
|
```math |
|
D = s_a s_b \widehat D + C |
|
``` |
|
|
|
Here, D is the output of the GEMM, and C is the bias. |
|
A is the activations and supports asymmetric quantization, |
|
and B is the weights and only supports symmetric quantization. |
|
$ s_a $ and $s_b$ are the scales for activations and weights, respectively. |
|
$ z_a $ is the zero-point for activations, and $ J_a $ is the matrix of all ones with dimensions of A. |
|
Additional epilogues would be required to support asymmetric quantization for weights. |
|
|
|
Expanding further, we can calculate $` \widehat D `$ as follows: |
|
|
|
```math |
|
A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B |
|
``` |
|
```math |
|
A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right) |
|
``` |
|
```math |
|
\widehat D = \widehat A \widehat B - z_a J_a \widehat B |
|
``` |
|
|
|
Note that $` \widehat A \widehat B `$ is the raw output of the GEMM, |
|
and $` J_a \widehat B `$ is known ahead of time. |
|
Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$. |
|
|
|
## Epilogues |
|
|
|
### ScaledEpilogue |
|
This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$. |
|
The output of the GEMM is: |
|
|
|
```math |
|
\widehat D = \widehat A \widehat B |
|
``` |
|
```math |
|
D = s_a s_b \widehat D |
|
``` |
|
```math |
|
D = s_a s_b \widehat A \widehat B |
|
``` |
|
|
|
Epilogue parameters: |
|
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). |
|
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). |
|
|
|
### ScaledEpilogueBias |
|
This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$. |
|
The output of the GEMM is: |
|
|
|
```math |
|
\widehat D = \widehat A \widehat B |
|
``` |
|
```math |
|
D = s_a s_b \widehat D + C |
|
``` |
|
```math |
|
D = s_a s_b \widehat A \widehat B + C |
|
``` |
|
|
|
|
|
Epilogue parameters: |
|
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). |
|
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). |
|
- `bias` is the bias, is always per-channel (row-vector). |
|
|
|
### ScaledEpilogueAzp |
|
This epilogue computes the asymmetric per-tensor quantization for activations with bias. |
|
The output of the GEMM is: |
|
|
|
```math |
|
\widehat D = \widehat A \widehat B - z_a J_a \widehat B |
|
``` |
|
```math |
|
D = s_a s_b \widehat D + C |
|
``` |
|
```math |
|
D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C |
|
``` |
|
|
|
Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$. |
|
That is precomputed and stored in `azp_with_adj` as a row-vector. |
|
|
|
Epilogue parameters: |
|
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). |
|
- Generally this will be per-tensor as the zero-points are per-tensor. |
|
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). |
|
- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$), is per-channel (row-vector). |
|
- `bias` is the bias, is always per-channel (row-vector). |
|
|
|
To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel. |
|
|
|
### ScaledEpilogueAzpPerToken |
|
This epilogue computes the asymmetric per-token quantization for activations with bias. |
|
|
|
The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector. |
|
That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$. |
|
|
|
Epilogue parameters: |
|
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). |
|
- Generally this will be per-token as the zero-points are per-token. |
|
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). |
|
- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), is per-channel (row-vector). |
|
- `azp` is the zero-point (`z_a`), is per-token (column-vector). |
|
- `bias` is the bias, is always per-channel (row-vector). |
|
|
|
To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel. |
|
|
|
The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM): |
|
``` |
|
out = scale_a * scale_b * (Dq - azp_adj * azp) + bias |
|
``` |
|
|