文章

CUTLASS-Cute 初步(3.1):TiledCopy 以及 TiledMMA 配置示例

CUTLASS-Cute 初步(3.1):TiledCopy 以及 TiledMMA 配置示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
  // Configure data type.
  using TA = cute::half_t;
  using TB = cute::half_t;
  using TC = cute::half_t;

  // Configure static "shared memory".
  // The "shared memory" is actually on host for preview purpose.
  // For tiled mma, the shared memory layout has to be static.
  constexpr int bM{128 * 2 / sizeof(TA)};
  constexpr int bN{128 * 2 / sizeof(TB)};
  constexpr int bK{32};
  auto const blk_M = cute::Int<bM>{};
  auto const blk_N = cute::Int<bN>{};
  auto const blk_K = cute::Int<bK>{};

  auto const smem_shape_A{cute::make_shape(blk_M, blk_K)};
  auto const smem_shape_B{cute::make_shape(blk_N, blk_K)};
  auto const smem_shape_C{cute::make_shape(blk_M, blk_N)};
  auto const smem_stride_A{cute::make_stride(cute::Int<1>{}, blk_M)};        // Column-major
  auto const smem_stride_B{cute::make_stride(cute::Int<1>{}, blk_N)};        // Column-major
  auto const smem_stride_C{cute::make_stride(cute::Int<1>{}, blk_M)};        // Column-major
  auto const smem_layout_A{cute::make_layout(smem_shape_A, smem_stride_A)};  // (blk_M, blk_K)
  auto const smem_layout_B{cute::make_layout(smem_shape_B, smem_stride_B)};  // (blk_N, blk_K)
  auto const smem_layout_C{cute::make_layout(smem_shape_C, smem_stride_C)};  // (blk_M, blk_N)

  auto const size_a{blk_M * blk_K};
  auto const size_b{blk_N * blk_K};
  auto const size_c{blk_M * blk_N};

  auto h_A = thrust::host_vector<TA>(size_a);
  auto h_B = thrust::host_vector<TB>(size_b);
  auto h_C = thrust::host_vector<TC>(size_c);

  // Make tensor for smem_A and smem_B.
  auto smem_tensor_A{cute::make_tensor(h_A.data(), smem_layout_A)};
  auto smem_tensor_B{cute::make_tensor(h_B.data(), smem_layout_B)};
  auto smem_tensor_C{cute::make_tensor(h_C.data(), smem_layout_C)};

1. TiledMMA 配置

位于 SMEM 中的 tile 大小为 $M \times N \times K = 128 \times 128 \times 32$,其中:

  • A 矩阵为 $M \times K = 128 \times 32$,row-major layout;
  • B 矩阵为 $K \times N = 32 \times 128$,column-major layout;
  • C 矩阵为 $M \times N = 128 \times 128$。

1.1. MMA_Atom 配置

MMA_Atom 使用的配置为 cute::SM80_16x8x16_F16F16F16F16_TN,使用一个 warp,即 32 个线程处理这个 MMA Atom。处理的 MNK 规模为:$M’ \times N’ \times K’ = 16 \times 8 \times 16$,其中:

  • A sub-tile 为 $M’ \times K’ = 16 \times 16$;
  • B sub-tile 为 $K’ \times N’ = 16 \times 8$;
  • C sub-tile 为 $M’ \times N’ = 16 \times 8$。
1
2
3
4
5
6
7
mma_atom
MMA_Atom
  ThrID:      _32:_1
  Shape_MNK:  (_16,_8,_16)
  LayoutA_TV: ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
  LayoutB_TV: ((_4,_8),(_2,_2)):((_16,_1),(_8,_64))
  LayoutC_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))

分配到线程,每个线程处理的元素数量为:A 矩阵为 2 x 2 x 2 = 8 个元素,B 矩阵为 2 x 2 = 4 个元素,得到 C 矩阵中 2 x 2 = 4 个元素。

inverse_tv_layout_SM80_16x8x16_F16F16F16F16_TN

1.2. Tiled MMA 配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
  // Configure tiled MMA.
  using MmaTraits           = cute::MMA_Traits<cute::SM80_16x8x16_F16F16F16F16_TN>;
  using MmaAtomShape        = MmaTraits::Shape_MNK;
  auto const mma_atom       = cute::MMA_Atom<MmaTraits>{};
  auto const mma_atom_shape = MmaAtomShape{};
  // Repeating the mma atom along the M, N, and K dimensions.
  // This increases the number of threads to process the tiled MMA.
  constexpr int MMA_LAYOUT_M{2};
  constexpr int MMA_LAYOUT_N{2};
  constexpr int MMA_LAYOUT_K{1};
  auto mma_layout{cute::make_layout(
    cute::make_shape(cute::Int<MMA_LAYOUT_M>{}, cute::Int<MMA_LAYOUT_N>{}, cute::Int<MMA_LAYOUT_K>{}))};
  // Repeating the mma processing along the M, N, and K dimensions.
  // This does not increase the number of threads to process the tiled MMA.
  // But the number of registers required for processing the tiled MMA increases.
  constexpr int NUM_MMA_TILE_M{1};
  constexpr int NUM_MMA_TILE_N{2};
  constexpr int NUM_MMA_TILE_K{1};
  constexpr int MMA_TILE_M{cute::get<0>(mma_atom_shape) * MMA_LAYOUT_M * NUM_MMA_TILE_M};
  constexpr int MMA_TILE_N{cute::get<1>(mma_atom_shape) * MMA_LAYOUT_N * NUM_MMA_TILE_N};
  constexpr int MMA_TILE_K{cute::get<2>(mma_atom_shape) * MMA_LAYOUT_K * NUM_MMA_TILE_K};
  auto mma_tile{cute::make_tile(cute::Int<MMA_TILE_M>{}, cute::Int<MMA_TILE_N>{}, cute::Int<MMA_TILE_K>{})};
  auto tiled_mma{cute::make_tiled_mma(mma_atom, mma_layout, mma_tile)};

在 M 维度上,MMA Atom 重复 2 次,在 N 维度上重复 2 次,在 K 维度上重复 1 次。一共需要 2 x 2 x 1 = 4 个 MMA Atom 来处理这个 tiled MMA。每个 Atom 由一个 warp(32 个线程)处理,整个 tiled MMA 由 4 个 warp(128 个线程)处理。即 ThrLayoutVMNK = (_32,_2,_2,_1):(_1,_32,_64,_0)。经过此配置后,能处理的 MNK 规模为 $(M’ \times 2) \times (N’ \times 2) \times (K’ \times 1) = 32 \times 16 \times 16$。

另外,通过配置 PermutationMNK(对应以前版本的 ValLayoutMNK),使得一个 tiled MMA 在 M/N/K 方向上处理更多的元素(即一个线程处理更多的元素)。这里配置 N 维度上乘以 2,得到该 tiled MMA 处理的 MNK 规模为 $32 \times 32 \times 16$,即 ** PermutationMNK: (_32,_32,_16)**。

1
2
3
4
5
6
7
8
9
10
tiled_mma
TiledMMA
  ThrLayoutVMNK:  (_32,_2,_2,_1):(_1,_32,_64,_0)
  PermutationMNK: (_32,_32,_16)
MMA_Atom
  ThrID:      _32:_1
  Shape_MNK:  (_16,_8,_16)
  LayoutA_TV: ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
  LayoutB_TV: ((_4,_8),(_2,_2)):((_16,_1),(_8,_64))
  LayoutC_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))

tile_mma_SM80_16x8x16_F16F16F16F16_TN

参考及资料

本文由作者按照 CC BY 4.0 进行授权