CUTLASS CuTe Primer (3.1): TiledCopy and TiledMMA Configuration Examples
```cpp
// Configure data type.
using TA = cute::half_t;
using TB = cute::half_t;
using TC = cute::half_t;
// Configure static "shared memory".
// The "shared memory" is actually on host for preview purpose.
// For tiled mma, the shared memory layout has to be static.
constexpr int bM{128 * 2 / sizeof(TA)};
constexpr int bN{128 * 2 / sizeof(TB)};
constexpr int bK{32};
auto const blk_M = cute::Int<bM>{};
auto const blk_N = cute::Int<bN>{};
auto const blk_K = cute::Int<bK>{};
auto const smem_shape_A{cute::make_shape(blk_M, blk_K)};
auto const smem_shape_B{cute::make_shape(blk_N, blk_K)};
auto const smem_shape_C{cute::make_shape(blk_M, blk_N)};
auto const smem_stride_A{cute::make_stride(cute::Int<1>{}, blk_M)}; // Column-major
auto const smem_stride_B{cute::make_stride(cute::Int<1>{}, blk_N)}; // Column-major
auto const smem_stride_C{cute::make_stride(cute::Int<1>{}, blk_M)}; // Column-major
auto const smem_layout_A{cute::make_layout(smem_shape_A, smem_stride_A)}; // (blk_M, blk_K)
auto const smem_layout_B{cute::make_layout(smem_shape_B, smem_stride_B)}; // (blk_N, blk_K)
auto const smem_layout_C{cute::make_layout(smem_shape_C, smem_stride_C)}; // (blk_M, blk_N)
auto const size_a{blk_M * blk_K};
auto const size_b{blk_N * blk_K};
auto const size_c{blk_M * blk_N};
auto h_A = thrust::host_vector<TA>(size_a);
auto h_B = thrust::host_vector<TB>(size_b);
auto h_C = thrust::host_vector<TC>(size_c);
// Make tensors for smem_A, smem_B, and smem_C.
auto smem_tensor_A{cute::make_tensor(h_A.data(), smem_layout_A)};
auto smem_tensor_B{cute::make_tensor(h_B.data(), smem_layout_B)};
auto smem_tensor_C{cute::make_tensor(h_C.data(), smem_layout_C)};
```
1. TiledMMA Configuration
The tile residing in SMEM has size $M \times N \times K = 128 \times 128 \times 32$, where:
- the A matrix is $M \times K = 128 \times 32$, row-major layout;
- the B matrix is $K \times N = 32 \times 128$, column-major layout;
- the C matrix is $M \times N = 128 \times 128$.
1.1. MMA_Atom Configuration
The MMA_Atom configuration used is cute::SM80_16x8x16_F16F16F16F16_TN, processed by one warp, i.e., 32 threads per MMA atom. The MNK extent it covers is $M' \times N' \times K' = 16 \times 8 \times 16$, where:
- the A sub-tile is $M' \times K' = 16 \times 16$;
- the B sub-tile is $K' \times N' = 16 \times 8$;
- the C sub-tile is $M' \times N' = 16 \times 8$.
```
mma_atom
MMA_Atom
  ThrID:      _32:_1
  Shape_MNK:  (_16,_8,_16)
  LayoutA_TV: ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
  LayoutB_TV: ((_4,_8),(_2,_2)):((_16,_1),(_8,_64))
  LayoutC_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))
```
Distributed across the threads, each thread handles 2 x 2 x 2 = 8 elements of the A matrix and 2 x 2 = 4 elements of the B matrix, and produces 2 x 2 = 4 elements of the C matrix (matching the value modes of the TV layouts above).
1.2. Tiled MMA Configuration
```cpp
// Configure tiled MMA.
using MmaTraits = cute::MMA_Traits<cute::SM80_16x8x16_F16F16F16F16_TN>;
using MmaAtomShape = MmaTraits::Shape_MNK;
auto const mma_atom = cute::MMA_Atom<MmaTraits>{};
auto const mma_atom_shape = MmaAtomShape{};
// Repeating the mma atom along the M, N, and K dimensions.
// This increases the number of threads to process the tiled MMA.
constexpr int MMA_LAYOUT_M{2};
constexpr int MMA_LAYOUT_N{2};
constexpr int MMA_LAYOUT_K{1};
auto mma_layout{cute::make_layout(
    cute::make_shape(cute::Int<MMA_LAYOUT_M>{}, cute::Int<MMA_LAYOUT_N>{}, cute::Int<MMA_LAYOUT_K>{}))};
// Repeating the mma processing along the M, N, and K dimensions.
// This does not increase the number of threads to process the tiled MMA.
// But the number of registers required for processing the tiled MMA increases.
constexpr int NUM_MMA_TILE_M{1};
constexpr int NUM_MMA_TILE_N{2};
constexpr int NUM_MMA_TILE_K{1};
constexpr int MMA_TILE_M{cute::get<0>(mma_atom_shape) * MMA_LAYOUT_M * NUM_MMA_TILE_M};
constexpr int MMA_TILE_N{cute::get<1>(mma_atom_shape) * MMA_LAYOUT_N * NUM_MMA_TILE_N};
constexpr int MMA_TILE_K{cute::get<2>(mma_atom_shape) * MMA_LAYOUT_K * NUM_MMA_TILE_K};
auto mma_tile{cute::make_tile(cute::Int<MMA_TILE_M>{}, cute::Int<MMA_TILE_N>{}, cute::Int<MMA_TILE_K>{})};
auto tiled_mma{cute::make_tiled_mma(mma_atom, mma_layout, mma_tile)};
```
The MMA atom is repeated 2 times along the M dimension, 2 times along N, and 1 time along K, so 2 x 2 x 1 = 4 MMA atoms process this tiled MMA. Each atom is handled by one warp (32 threads), so the whole tiled MMA is handled by 4 warps (128 threads), i.e., ThrLayoutVMNK = (_32,_2,_2,_1):(_1,_32,_64,_0). With this configuration, the MNK extent covered is $(M' \times 2) \times (N' \times 2) \times (K' \times 1) = 32 \times 16 \times 16$.
In addition, by configuring PermutationMNK (corresponding to ValLayoutMNK in earlier versions), a single tiled MMA can cover more elements along the M/N/K dimensions (i.e., each thread processes more elements). Here the N dimension is multiplied by 2, so the tiled MMA covers an MNK extent of $32 \times 32 \times 16$, i.e., **PermutationMNK: (_32,_32,_16)**.
```
tiled_mma
TiledMMA
  ThrLayoutVMNK:  (_32,_2,_2,_1):(_1,_32,_64,_0)
  PermutationMNK: (_32,_32,_16)
MMA_Atom
  ThrID:      _32:_1
  Shape_MNK:  (_16,_8,_16)
  LayoutA_TV: ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
  LayoutB_TV: ((_4,_8),(_2,_2)):((_16,_1),(_8,_64))
  LayoutC_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))
```
References
- CuTe Tiled MMA: Mao Lei's blog