CUDA 笔记集合

1. cutlass/CuTe GEMM 中矩阵的存储方式 NT / TN / NN / TT

1.1. 背景

BLAS 的约定是:所有矩阵一律按 column-major 存储,然后用 transA/transB 标志告诉 BLAS 要不要对它做转置:

\[C = \alpha \cdot op(A) \cdot op(B) + \beta \cdot C\]

其中:

  • transXN时:$op(X) = X$,当transXT时:$op(X) = X^T$。
  • 乘法要求 $op(A)$ 是 $M \times K$,$op(B)$ 是 $K \times N$。

1.2. GEMM 命名含义

CuTe给矩阵做了一个约定:A(M, K),B(N, K),C(M, N),即:

  • A 矩阵:(M,K) – M 行 K 列
  • B 矩阵:(N,K) – N 行 K 列(不同于 BLAS 及其他典型约定)
  • C 矩阵:(M,N) – M 行 N 列

即 CuTe 对 B 的约定,默认即为转置形式,即$B^T$,正好与 BLAS 约定形成转置关系。

由于 CuTe 中对 A/B/C 的约束,导致在调用 BLAS 的时候,通过设置主序来表达转置关系:

  • 针对 A,如果是 N,则使用 CuTe 表示的时候,A 是 (M, K),column-major;如果是 T,则使用 CuTe 表示的时候,A 是 (M, K),row-major。
  • 针对 B,如果是 N,则使用 CuTe 表示的时候,B 是 (N, K),row-major;如果是 T,则使用 CuTe 表示的时候,B 是 (N, K),column-major。
  • 即 B 在 CuTe 中的表示,与 A 在 CuTe 中的表示,转换规律正好相反。

CuTe 对 A/B/C 的约定,参见https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/0x_gemm_tutorial.md#the-full-tensors-shapes-strides-and-data

gemm_nt

template <class TA, class TB, class TC, class Alpha, class Beta>
static cudaError_t gemm_nt(int m, int n, int k, Alpha alpha, TA const* A,
                           int ldA, TB const* B, int ldB, Beta beta, TC* C,
                           int ldC, cudaStream_t stream)

nt含义是:

  • 此时,BLAS 约定 $A$ 是 (M, K),column-major。此时 CuTe 约定格式与 BLAS 约定一致。
  • 此时,BLAS 的约定 $B^T$ 是 (K, N)。正好与 CuTe 约定的 B 矩阵形状一致,且是 column-major。

最终得到:

strider_A = cute::make_stride(cute::_1, ldA); // column-major
shape_A = cute::make_shape(M, K);

stride_B = cute::make_stride(cute::_1, ldB); // column-major
shape_B = cute::make_shape(N, K);

stride_C = cute::make_stride(cute::_1, ldC); // column-major
shape_C = cute::make_shape(M, N);

gemm_tn

template <class TA, class TB, class TC, class Alpha, class Beta>
static cudaError_t gemm_tn(int m, int n, int k, Alpha alpha, TA const* A,
                           int ldA, TB const* B, int ldB, Beta beta, TC* C,
                           int ldC, cudaStream_t stream)
  • $A^T$是(M, K)。按照 CuTe 约定,CuTe 要表达 $A^T$,只能是 (M, K) + row-major。
  • 由于 CuTe 对 B 的约定导致其存储格式与 BLAS 约定形成转置关系,B 的的存储格式是 (N, K) + row-major。

得到:

stride_A = cute::make_stride(ldA, cute::_1); // row-major
shape_A = cute::make_shape(M, K);

stride_B = cute::make_stride(ldB, cute::_1); // row-major
shape_B = cute::make_shape(N, K);

stride_C = cute::make_stride(cute::_1, ldC); // column-major
shape_C = cute::make_shape(M, N);

gemm_nn

template <class TA, class TB, class TC, class Alpha, class Beta>
static cudaError_t gemm_nn(int m, int n, int k, Alpha alpha, TA const* A,
                           int ldA, TB const* B, int ldB, Beta beta, TC* C,
                           int ldC, cudaStream_t stream)
  • $A$ 是 (M, K),column-major。CuTe 约定格式与 BLAS 约定一致。
  • $B$ 是 (N, K),row-major。由于 CuTe 对 B 的约定导致其存储格式与 BLAS 约定形成转置关系,B 的的存储格式是 (N, K) + row-major。

得到:

stride_A = cute::make_stride(cute::_1, ldA); // column-major
shape_A = cute::make_shape(M, K);

stride_B = cute::make_stride(ldB, cute::_1); // row-major
shape_B = cute::make_shape(N, K);

stride_C = cute::make_stride(cute::_1, ldC); // column-major
shape_C = cute::make_shape(M, N);

gemm_tt

template <class TA, class TB, class TC, class Alpha, class Beta>
static cudaError_t gemm_tt(int m, int n, int k, Alpha alpha, TA const* A,
                           int ldA, TB const* B, int ldB, Beta beta, TC* C,
                           int ldC, cudaStream_t stream)
  • $A^T$ 是 (M, K),row-major。CuTe 约定格式与 BLAS 约定形成转置关系,因此 A 的存储格式是 (M, K) + row-major。
  • $B^T$ 是 (K, N),column-major。CuTe 约定格式与 BLAS 约定一致,即 CuTe 默认表达 BLAS 的转置形式。

得到:

stride_A = cute::make_stride(ldA, cute::_1); // row-major
shape_A = cute::make_shape(M, K);

stride_B = cute::make_stride(cute::_1, ldB); // column-major
shape_B = cute::make_shape(N, K);

stride_C = cute::make_stride(cute::_1, ldC); // column-major
shape_C = cute::make_shape(M, N);

1.3. 内存访问效率分析–访存合并

在划分 A/B 的的过程中,一般按照 M 方向划分 tile(针对 A),或者按照 N 方向划分 tile(针对 B)。比如如下 Thread-Value Layout 划分:

// Define thread layouts.
auto const thread_shape_A{cute::make_shape(cute::Int<16>{}, cute::Int<8>{})}; // (THR_M, THR_K)
auto const thread_shape_B{cute::make_shape(cute::Int<16>{}, cute::Int<8>{})}; // (THR_N, THR_K)
auto const thread_shape_C{cute::make_shape(cute::Int<32>{}, cute::Int<4>{})}; // (THR_M, THR_N)
auto const thread_stride_A{cute::make_stride(cute::Int<1>{}, cute::size<0>(thread_shape_A))}; // column-major
auto const thread_stride_B{cute::make_stride(cute::Int<1>{}, cute::size<0>(thread_shape_B))}; // column-major
auto const thread_stride_C{cute::make_stride(cute::Int<1>{}, cute::size<0>(thread_shape_C))}; // column-major
auto const thread_layout_A{cute::make_layout(thread_shape_A, thread_stride_A)}; // (THR_M, THR_K)
auto const thread_layout_B{cute::make_layout(thread_shape_B, thread_stride_B)}; // (THR_N, THR_K)
auto const thread_layout_C{cute::make_layout(thread_shape_C, thread_stride_C)}; // (THR_M, THR_N)

此时,得到:$\text{thr_id} = m \times 1 + k \times 16$。

即,同一个 warp 内的 32 个连续线程,它们的 thread ID 沿第一个维度(M 或 N)连续变化。所以:

加载 A tile 时:warp 内线程沿 M 维度连续 → 如果 A 在 M 维度内存连续(column-major, stride=1),就是合并访存 ✅ 加载 B tile 时:warp 内线程沿 N 维度连续 → 如果 B 在 N 维度内存连续(column-major, stride=1),就是合并访存 ✅

对这几种形式的 GEMM,可以得到其访存能否合并:

变体 A 的 M 维度 B 的 N 维度 加载 A 加载 B
gemm_nt stride=1 (连续) stride=1 (连续) uint128_t 合并 ✅ uint128_t 合并 ✅
gemm_nn stride=1 (连续) stride=ldB (不连续) uint128_t 合并 ✅ 逐元素拷贝 ❌
gemm_tn stride=ldA (不连续) stride=ldB (不连续) 逐元素 ❌ 逐元素 ❌
gemm_tt stride=ldA (不连续) stride=1 (连续) 逐元素 ❌ uint128_t 合并 ✅

参考代码来源:https://github.com/leimao/CUTLASS-Examples/blob/main/examples/cute_general_matrix_multiplication/cute_general_matrix_multiplication_tensor_core_gmem_tiled_copy_smem_tiled_copy_tiled_mma_sm80_pipeline.cu




    Enjoy Reading This Article?

    Here are some more articles you might like to read next:

  • al-folio 模板定制修改总结
  • al-folio 本地部署记录(Ubuntu 24.04)
  • C++ Traits
  • 道格拉斯-普克算法(Douglas–Peucker algorithm)
  • CMake支持库收集