CUDA 笔记集合
1. cutlass/CuTe GEMM 中矩阵的存储方式 NT / TN / NN / TT
1.1. 背景
BLAS 的约定是:所有矩阵一律按 column-major 存储,然后用 transA/transB 标志告诉 BLAS 要不要对它做转置:
\[C = \alpha \cdot op(A) \cdot op(B) + \beta \cdot C\]其中:
- 当
transX为N时:$op(X) = X$,当transX为T时:$op(X) = X^T$。 - 乘法要求 $op(A)$ 是 $M \times K$,$op(B)$ 是 $K \times N$。
1.2. GEMM 命名含义
CuTe给矩阵做了一个约定:A(M, K),B(N, K),C(M, N),即:
- A 矩阵:(M,K) – M 行 K 列
- B 矩阵:(N,K) – N 行 K 列(不同于 BLAS 及其他典型约定)
- C 矩阵:(M,N) – M 行 N 列
即 CuTe 对 B 的约定,默认即为转置形式,即$B^T$,正好与 BLAS 约定形成转置关系。
由于 CuTe 中对 A/B/C 的约束,导致在调用 BLAS 的时候,通过设置主序来表达转置关系:
- 针对 A,如果是 N,则使用 CuTe 表示的时候,A 是 (M, K),column-major;如果是 T,则使用 CuTe 表示的时候,A 是 (M, K),row-major。
- 针对 B,如果是 N,则使用 CuTe 表示的时候,B 是 (N, K),row-major;如果是 T,则使用 CuTe 表示的时候,B 是 (N, K),column-major。
- 即 B 在 CuTe 中的表示,与 A 在 CuTe 中的表示,转换规律正好相反。
CuTe 对 A/B/C 的约定,参见https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/0x_gemm_tutorial.md#the-full-tensors-shapes-strides-and-data。
gemm_nt
template <class TA, class TB, class TC, class Alpha, class Beta>
static cudaError_t gemm_nt(int m, int n, int k, Alpha alpha, TA const* A,
int ldA, TB const* B, int ldB, Beta beta, TC* C,
int ldC, cudaStream_t stream)
nt含义是:
- 此时,BLAS 约定 $A$ 是 (M, K),column-major。此时 CuTe 约定格式与 BLAS 约定一致。
- 此时,BLAS 的约定 $B^T$ 是 (K, N)。正好与 CuTe 约定的 B 矩阵形状一致,且是 column-major。
最终得到:
strider_A = cute::make_stride(cute::_1, ldA); // column-major
shape_A = cute::make_shape(M, K);
stride_B = cute::make_stride(cute::_1, ldB); // column-major
shape_B = cute::make_shape(N, K);
stride_C = cute::make_stride(cute::_1, ldC); // column-major
shape_C = cute::make_shape(M, N);
gemm_tn
template <class TA, class TB, class TC, class Alpha, class Beta>
static cudaError_t gemm_tn(int m, int n, int k, Alpha alpha, TA const* A,
int ldA, TB const* B, int ldB, Beta beta, TC* C,
int ldC, cudaStream_t stream)
- $A^T$是(M, K)。按照 CuTe 约定,CuTe 要表达 $A^T$,只能是 (M, K) + row-major。
- 由于 CuTe 对 B 的约定导致其存储格式与 BLAS 约定形成转置关系,B 的的存储格式是 (N, K) + row-major。
得到:
stride_A = cute::make_stride(ldA, cute::_1); // row-major
shape_A = cute::make_shape(M, K);
stride_B = cute::make_stride(ldB, cute::_1); // row-major
shape_B = cute::make_shape(N, K);
stride_C = cute::make_stride(cute::_1, ldC); // column-major
shape_C = cute::make_shape(M, N);
gemm_nn
template <class TA, class TB, class TC, class Alpha, class Beta>
static cudaError_t gemm_nn(int m, int n, int k, Alpha alpha, TA const* A,
int ldA, TB const* B, int ldB, Beta beta, TC* C,
int ldC, cudaStream_t stream)
- $A$ 是 (M, K),column-major。CuTe 约定格式与 BLAS 约定一致。
- $B$ 是 (N, K),row-major。由于 CuTe 对 B 的约定导致其存储格式与 BLAS 约定形成转置关系,B 的的存储格式是 (N, K) + row-major。
得到:
stride_A = cute::make_stride(cute::_1, ldA); // column-major
shape_A = cute::make_shape(M, K);
stride_B = cute::make_stride(ldB, cute::_1); // row-major
shape_B = cute::make_shape(N, K);
stride_C = cute::make_stride(cute::_1, ldC); // column-major
shape_C = cute::make_shape(M, N);
gemm_tt
template <class TA, class TB, class TC, class Alpha, class Beta>
static cudaError_t gemm_tt(int m, int n, int k, Alpha alpha, TA const* A,
int ldA, TB const* B, int ldB, Beta beta, TC* C,
int ldC, cudaStream_t stream)
- $A^T$ 是 (M, K),row-major。CuTe 约定格式与 BLAS 约定形成转置关系,因此 A 的存储格式是 (M, K) + row-major。
- $B^T$ 是 (K, N),column-major。CuTe 约定格式与 BLAS 约定一致,即 CuTe 默认表达 BLAS 的转置形式。
得到:
stride_A = cute::make_stride(ldA, cute::_1); // row-major
shape_A = cute::make_shape(M, K);
stride_B = cute::make_stride(cute::_1, ldB); // column-major
shape_B = cute::make_shape(N, K);
stride_C = cute::make_stride(cute::_1, ldC); // column-major
shape_C = cute::make_shape(M, N);
1.3. 内存访问效率分析–访存合并
在划分 A/B 的的过程中,一般按照 M 方向划分 tile(针对 A),或者按照 N 方向划分 tile(针对 B)。比如如下 Thread-Value Layout 划分:
// Define thread layouts.
auto const thread_shape_A{cute::make_shape(cute::Int<16>{}, cute::Int<8>{})}; // (THR_M, THR_K)
auto const thread_shape_B{cute::make_shape(cute::Int<16>{}, cute::Int<8>{})}; // (THR_N, THR_K)
auto const thread_shape_C{cute::make_shape(cute::Int<32>{}, cute::Int<4>{})}; // (THR_M, THR_N)
auto const thread_stride_A{cute::make_stride(cute::Int<1>{}, cute::size<0>(thread_shape_A))}; // column-major
auto const thread_stride_B{cute::make_stride(cute::Int<1>{}, cute::size<0>(thread_shape_B))}; // column-major
auto const thread_stride_C{cute::make_stride(cute::Int<1>{}, cute::size<0>(thread_shape_C))}; // column-major
auto const thread_layout_A{cute::make_layout(thread_shape_A, thread_stride_A)}; // (THR_M, THR_K)
auto const thread_layout_B{cute::make_layout(thread_shape_B, thread_stride_B)}; // (THR_N, THR_K)
auto const thread_layout_C{cute::make_layout(thread_shape_C, thread_stride_C)}; // (THR_M, THR_N)
此时,得到:$\text{thr_id} = m \times 1 + k \times 16$。
即,同一个 warp 内的 32 个连续线程,它们的 thread ID 沿第一个维度(M 或 N)连续变化。所以:
加载 A tile 时:warp 内线程沿 M 维度连续 → 如果 A 在 M 维度内存连续(column-major, stride=1),就是合并访存 ✅ 加载 B tile 时:warp 内线程沿 N 维度连续 → 如果 B 在 N 维度内存连续(column-major, stride=1),就是合并访存 ✅
对这几种形式的 GEMM,可以得到其访存能否合并:
| 变体 | A 的 M 维度 | B 的 N 维度 | 加载 A | 加载 B |
|---|---|---|---|---|
| gemm_nt | stride=1 (连续) | stride=1 (连续) | uint128_t 合并 ✅ | uint128_t 合并 ✅ |
| gemm_nn | stride=1 (连续) | stride=ldB (不连续) | uint128_t 合并 ✅ | 逐元素拷贝 ❌ |
| gemm_tn | stride=ldA (不连续) | stride=ldB (不连续) | 逐元素 ❌ | 逐元素 ❌ |
| gemm_tt | stride=ldA (不连续) | stride=1 (连续) | 逐元素 ❌ | uint128_t 合并 ✅ |
Enjoy Reading This Article?
Here are some more articles you might like to read next: