CUTLASS-Cute 初步(4):Swizzle

Swizzle作用于SMEMlayout。给定layout范围内,Swizzle通过列异或操作(icol = irow ^ icol),周期性的coord重排,映射到新的物理地址offsetSwizzle定义了三个参数:

  • $MBase$:以 $2^M$ 个一维坐标连续的元素为单位,将其当做一个元素;
  • $SShift$:从Offset中提取的高位偏移,用于提取Offsetlead dimension
  • $BBits$:参与XOR的位数,用于提取一维index中的lead dimensionfast dimension中的部分bits。

引用reed解释及图示,其输入为一个一维坐标的layout,通过swizzle将其拆分为二维坐标表示形式:

swizzle 逻辑示意

给定义一个输入Offset:<LeadBits:FastBits>

  • 提取关系为:BBits+MBase -> FastBitsSShift+MBase -> LeadBits
  • 参与XOR操作的位宽为:BBits,即LeadBits中的低BBits位,FastBits中的高BBits位。

BBits表示有$2^B$个交换模式,SShift表示交换模式的周期。通常$\mid{S}\mid \ge B$,如果$\mid{S}\mid \gt B$,则此交换模式重复$2^{\mid{S}\mid - B}$次,如果$\mid{S}\mid = B$,则只套用一次此交换模式。

一般在设置Swizzle参数时,按输入的layout一行(准确的说是fast dimension)为周期进行swizzle,$2^{S+M}$ = 输入layout的列长度(此处仅指逻辑上的,实际完整的计算公式还需要考虑到元素存储字节数,具体见下面章Swizzle 参数设计规则)。

比如 half 类型的layout (8, 32):(32, 1),定义swizzle<3, 3, 3>,即 8 个元素形成新的最小单位(M),8 个最小单位为一行(B),所以swizzle从$8 \times 8 = 64$个元素开始。见下面示例。B 为 8,则整个swizzle周期为 8 行。

  • 设计Swizzle参数时,要求S >= B,否则不能提取到LeadBits
  • 设计Layout时,如果Layoutfast dimension长度小于 $2^{S+M}$,Swizzle不能完整的提取LeadBits,导致Swizzle失效或部分失效。
  • 异或操作数学符号为 $\oplus$。

1. Cute Swizzle 示例

定义 layout (8, 32):(32, 1),定义swizzle<3, 2, 3>。定义的Swizzle含义如下:

  • $2^{B+M}=2^{3+2}$ = 32:fast dimension 长度,32个elements。
  • $2^{M}=2^2=4$:4个element组成一个最小单位。
  • $2^{B}=2^3=8$:8个交换模式。
  • $2^{S}=2^3=8$:交换模式周期为8。

代码如下:

from cutlass import cute
from cute_viz import render_layout_svg, render_swizzle_layout_svg

@cute.jit
def test_swizzle_layout():
    layout_2d = cute.make_layout((8, 32), stride=(32, 1))
    sw = cute.make_swizzle(3, 2, 3)
    swizzled_layout = cute.make_composed_layout(sw, 0, layout_2d)
    render_layout_svg(layout_2d, "out/original_layout.svg")
    render_swizzle_layout_svg(swizzled_layout, "out/swizzled_layout.svg")

test_swizzle_layout()

结果如下:

swizzle_8_32_3_3_3

2. Swizzle 逻辑及规律

class Swizzle {
 public:
  Swizzle(int num_bits, int num_base, int num_shft) : m_num_bits(num_bits), m_num_base(num_base), m_num_shft(num_shft) {
    CHECK2(m_num_bits >= 0, "BBits must be positive.");
    CHECK2(m_num_base >= 0, "MBase must be positive.");
    CHECK2(std::abs(m_num_shft) >= m_num_bits, "abs(SShift) must be more than BBits.");
  }

  template <class Offset>
  auto apply(Offset offset) const noexcept {
    return offset ^ shiftr(offset & m_yyy_msk);  // ZZZ ^= YYY
  }

  template <class Offset>
  auto operator()(Offset offset) const noexcept {
    return apply(offset);
  }

 private:
  template <class Offset>
  auto shiftr(Offset offset) const noexcept {
    return m_msk_sft >= 0 ? offset >> m_msk_sft : offset << -m_msk_sft;
  }

  int m_num_bits;
  int m_num_base;
  int m_num_shft;

  int m_bit_msk = (1 << m_num_bits) - 1;
  int m_yyy_msk = m_bit_msk << (m_num_base + std::max(0, m_num_shft));
  int m_zzz_msk = m_bit_msk << (m_num_base - std::min(0, m_num_shft));
  int m_msk_sft = m_num_shft;
};
  • m_zzz_msk没有参与swizzle计算,在CuTe库中用作检查作用。

Swizzle根据参数BBitsSShift,生成offset的掩码:高位部分提取掩码yyy_msk、低位部分提取zzz_msk,即分别对应行提取掩码、列提取掩码。此时,可以将MBase看作是BBits的一部分。直观展示如下:

                 bits       bits
                  --         --
0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
                    <--------->---
                       shift   base

针对每个offset,经过swizzle映射之后,异或更新低位掩码对应的值:

0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx
其中 AA = ZZ ^ YY。

2.1. Swizzle 参数影响规律分析

分别以行混淆周期,以及列混淆周期,这两个层次来分析。

以一个 layout (32, 16):(16, 1)为例,分析swizzle<B=4, M=0, S4> 参数变动对结果的影响规律。

经过offset & yyy_mk提取行号低四位(以及shiftr操作得到最终行号提取掩码),得到第 0 行、第 16 行(0x10)由于行号掩码提取过后的低 4 位为 0,导致swizzle无效。

2.1.1. B 参数对周期的影响

如果使用 swizzle<B=3, M=0, S=4>,导致高位掩码提取行号的低三位,行号为 0、8、16、24 时,得到的高位 YY 部分均为 0,swizzle 异或操作不生效。

如果设 S = 3,则layout的左半部分(列 0 ~ 8)呈现 0、8、16、24 的行混淆周期。右半部分暂没有理清规律。

2.1.2. S 参数以及列长度对周期的影响

如果使用 S = 5,行号提取范围扩大到 32,由于 layout 的行号、列号范围对应掩码为 4 位,导致的结果为行号有效的掩码位为 0bxxxx0 >> 1,即最低一位被丢弃,且有效的掩码位为 4 位(注意 shiftr 的实现是提取高位之后右移S位)。最终的规律为:第 0、1 行维持不变,行周期变为 32 行,即第32、33行维持不变。中间的行,则每两行异或计算结果一致,即如果应用于解决 bank conflicts,此时只能消除一半的 bank conflicts。

2.2. 测试代码

from cutlass import cute
from cute_viz import render_swizzle_layout_svg

@cute.jit
def test_swizzle_layout():
    layout_2d = cute.make_layout((32, 16), stride=(16, 1))
    sw = cute.make_swizzle(4, 0, 4)
    swizzled_layout = cute.make_composed_layout(sw, 0, layout_2d)
    # render_layout_svg(layout_2d, "out/original_layout.svg")
    render_swizzle_layout_svg(swizzled_layout, "out/swizzled_layout.svg")

test_swizzle_layout()

2.3. Swizzle 与 Layout 的关系

Swizzle 的实现与 Layout 没有关系,但是从其实现看,最终形成的 index,fast dimension (即内存连续的维度)位于低位,得到的结果就是对 fast dimension 形成XOR操作,即改变其映射顺序。演示代码片段如下:

  for (int i = 0; i < cute::size<0>(layout); i++) {
    for (int j = 0; j < cute::size<1>(layout); j++) {
      int idx          = layout(i, j);
      int swizzled_idx = swizzle(idx);
      int bank_id      = (swizzled_idx * element_size / 4) % 32;
    }
  }

创建一个列主序的 Layout (32, 32):(1, 32)XOR此时作用在行号上(即 fast dimension),但是从物理存储上看,其结果与行主序的 layout 是相同的。这个列主序 Layout 原始图示,以及 Swizzle 映射之后图示如下:

origin_column_major_32_32_505 swizzle_32_32_505

另外,Swizzle就是一个复合映射,演示代码如下:

  constexpr auto smem_atom_layout_A = cute::make_layout(cute::make_shape(cute::Int<32>{}, cute::Int<8>{}));
  constexpr auto smem_atom_layout_A_swizzled = cute::composition(swizzle_A, smem_atom_layout_A);
  constexpr auto smem_shape_A  = cute::make_shape(bM, bK, bP);  // (bM, bK, bP)
  constexpr auto smem_layout_A = cute::tile_to_shape(smem_atom_layout_A_swizzled, smem_shape_A);

3. Swizzle 参数设计规则

假定矩阵每个元素大小为 S-byte,向量化访问的宽度为N个元素,shared memoryfast dimensionX个元素。即:

符号 含义
S_elem 每个元素的大小(bytes)
N 向量化访问的元素个数
X Fast dimension 的元素个数

$\text{MBase} = log_{2}\text{N}$,即向量化访问的元素个数。

$\text{SShift} = log_{2}\text{X} - \text{MBase}$,即 X = $2^{\text{MBase} + \text{SShift}}$,这样使得针对每一行的 swizzle 操作,掩码偏移对齐到行号的位置。即,将提取行号的掩码分为两部分:低 MBase 位不参与 swizzle,高 SShift 位参与 swizzle。

$\text{BBits} = log_{2}\text{(32 * 4 / S)} - \text{MBase}$。其原因为要确保 BBits 对应覆盖一次 shared memory 的访问字宽:128 字节,即 $\text{S} \times 2^\text{MBase + BBits}$ = 32 * 4 = 128B。

即要求 32 个连续的 word 地址,经过 swizzle 之后,分别落入 shared memory32bank中。

3.1. Swizzle 参数设计示例1

假定 half 类型(S_elem = 2 bytes)行主序矩阵,矩阵大小为 8 * 64。采用 128-bit 向量化访问指令,即每次访问 8 个 half 元素(4 个 word)。

MBase = log2(8) = 3。

SShift = log2(64) - MBase = 6 - 3 = 3。即其约束在于 shared memoryfast dimension 为 64 个元素,即掩码偏移量为 fast dimension 长度。

BBits = log2(32 * 2) - MBase = 6 - 3 = 3。即要求 32 个连续的 word 地址,经过 swizzle 之后,分别落入 shared memory 的32个 bank 中。

最终,得到 swizzle<3, 3, 3>

3.2. 一个错误的设计示例,以及修复

输入 layout,其 Offset 组成为:<LeadBits>:<FastBits>。Swizzle 对输入 layout 的约束为:

  • 输入 layout 的 fast dimension 长度,即连续内存访问的维度长度,应该等于 $2^{MBase + SShift}$。
  • 如果输入 layout fast dimension 长度小于 $2^{MBase + SShift}$,则导致 Swizzle 只取到LeadBits的高位部分,低位部分取到FastBits里面去了。
  • 输入 layout 的 lead dimension 长度小于 $2^{MBase + SShift}$,没有影响,即一个交换模式周期没有被完整应用。

错误示例如下:

  constexpr auto bM = cute::Int<128>{};
  constexpr auto bK = cute::Int<32>{};

  constexpr auto MBase_A   = constexpr_log2(sizeof(cute::uint128_t) / sizeof(T));  // log2(8) = 3
  constexpr auto BBits_A   = constexpr_log2(32 * 4 / sizeof(T)) - MBase_A;         // log2(64) - 3 =  3
  constexpr auto SShift_A  = constexpr_log2(bM) - MBase_A;                         // log2(128) - 3 = 4
  constexpr auto swizzle_A = cute::Swizzle<BBits_A, MBase_A, SShift_A>{};

  constexpr auto smem_atom_layout_A          = cute::make_layout(cute::make_shape(cute::Int<32>{}, cute::Int<8>{}));
  constexpr auto smem_atom_layout_A_swizzled = cute::composition(swizzle_A, smem_atom_layout_A);
  constexpr auto smem_layout_A               = cute::tile_to_shape(smem_atom_layout_A_swizzled, cute::make_shape(bM, bK));

上述示例中,输入 layout 的 fast dimension=32 => 位宽为 5;LeadBits= 8 => 位宽为 3。得到 Offset 位格式为<LeadBits=3:FastBits=5>。与 Swizzle 对齐的位宽组成要求为<X:4+3>,导致 Swizzle 只从lead dimension的高位部分获取2-Bit,低位2-Bit是从Offset的fast dimension中获取的;另外,经过BBits位与之后,LeadBits只有1-Bit起作用。

其结果就是,Swizzle 之后的结果,shape 等于输入 layout 的 shape,但是没有达到预期的效果。测试代码:

@cute.jit
def test_swizzle_layout2():  # 该swizzle不起作用
    """该swizzle不起作用, 因为fast dimension长度不满足要求"""
    layout_3d = cute.make_layout((32, 8))  # 列主序布局
    sw = cute.make_swizzle(3, 3, 4)  # 要求fast dimension 长度为 2^(4+3) = 128
    swizzled_layout = cute.make_composed_layout(sw, 0, layout_3d)
    render_layout_svg(layout_3d, "out/original_layout2.svg")
    render_swizzle_layout_svg(swizzled_layout, "out/swizzled_layout2.svg")

test_swizzle_layout2()

修复代码如下:

  // constexpr auto bM = cute::Int<128>{};
  constexpr auto smem_atom_layout_A = cute::make_layout(cute::make_shape(bM, cute::Int<8>{}));

对应的测试代码:

@cute.jit
def test_swizzle_layout3():
    """修正: fast dimension长度满足要求, swizzle生效"""
    """MBase=3: 8个元素组成一组. """
    layout_3d = cute.make_layout((128, 8))  # 列主序布局
    sw = cute.make_swizzle(3, 3, 4)  # 要求fast dimension 长度为 2^(4+3) = 128
    swizzled_layout = cute.make_composed_layout(sw, 0, layout_3d)
    render_layout_svg(layout_3d, "out/original_layout3.svg")
    render_swizzle_layout_svg(swizzled_layout, "out/swizzled_layout3.svg")

test_swizzle_layout3()

Swizzle之后的Layout:

swizzle_128_8_3_3_4

相关测试代码:https://github.com/HPC02/cuda_perf/blob/master/src/study_codes/test_gmem_smem_swizzle/test_gemm_smem_swizzle.cu

4. Thread Block Swizzle

由于 CUDA 调度实际上是以 block id 的顺序进行调度的,有时候,通过对 thread block 映射的tile进行重新排序,最大限度的利用L2 cache中个的数据,提升性能。

thread block swizzle 逻辑示意

GPU 硬件调度器通常优先调度 x,然后 y,最后 z 维度(单个SM调度thread block?还是所有SM -> thread block的顺序?)。

Thread block swizzle,以及 block 调度顺序,见 github issue:[QST]how to understand “block swizzling”,以及博客Nvidia Tensor Core-CUDA HGEMM Advanced Optimization

关于使用 thread block swizzle 复用 L2 cache,见 NVIDIA 博客Optimizing Compute Shaders for L2 Locality using Thread-Group ID Swizzling

A. 参考资料

A.1. 其他资料




    Enjoy Reading This Article?

    Here are some more articles you might like to read next:

  • NVIDIA GPU 架构:SP、SM 与 LSU 工作原理详解
  • al-folio 模板定制修改总结
  • al-folio 本地部署记录(Ubuntu 24.04)
  • C++ Traits
  • 道格拉斯-普克算法(Douglas–Peucker algorithm)