C++ Traits

C++ Traits 是一种元编程技术,将类型信息封装在一个类中,供算法或其他模板使用,以期达到从算法/逻辑中分离类型信息的目的。比如元素类型、一些常量定义、指令选择、对齐要求等等。

问题 Traits 如何解决
内置类型(int, float)不能添加成员 Traits 外挂信息
不同类型需要不同的算法策略 模板特化选择策略
想要零开销的编译期多态 全部在编译期决议
接口与实现解耦 算法只依赖 Traits 接口

1. 自定义 Traits 常见模式

1.1 类型信息萃取

目的:从不同类型中统一提取需要的信息。

// 主模板 (可以留空或给默认值)
template <typename T>
struct NumericTraits;

// 对 float 特化
template <>
struct NumericTraits<float> {
    using type          = float;
    using compute_type  = float;      // 计算时用的类型
    using accum_type    = float;      // 累加器类型
    static constexpr int bits = 32;
    static constexpr float epsilon = 1e-7f;
    static constexpr float max_val = 3.4028235e+38f;
};

// 对 half 特化
template <>
struct NumericTraits<__half> {
    using type          = __half;
    using compute_type  = float;      // half 计算时提升到 float
    using accum_type    = float;
    static constexpr int bits = 16;
};

// 对 int8_t 特化
template <>
struct NumericTraits<int8_t> {
    using type          = int8_t;
    using compute_type  = int32_t;    // int8 用 int32 累加
    using accum_type    = int32_t;
    static constexpr int bits = 8;
};

// --- 使用 ---
template <typename T>
void gemm_kernel() {
    using Traits = NumericTraits<T>;
    using AccumType = typename Traits::accum_type;

    AccumType accumulator = 0;  // 自动选择合适的累加器类型
    // ...
}

1.2. 策略选择

目的:根据类型选择不同的算法实现。

// 内存拷贝策略 Traits
template <typename T, int Size>
struct CopyTraits;

// 小数据: 逐元素拷贝
template <typename T>
struct CopyTraits<T, 1> {
    static void copy(T* dst, const T* src) {
        *dst = *src;
    }
};

// 4字节对齐: 用 uint32_t 一次拷贝
template <>
struct CopyTraits<float, 4> {
    static void copy(float* dst, const float* src) {
        // 使用向量化加载
        *reinterpret_cast<float4*>(dst) = *reinterpret_cast<const float4*>(src);
    }
};

// 128位对齐: 用 LDG.128 指令
template <>
struct CopyTraits<__half, 8> {
    static void copy(__half* dst, const __half* src) {
        // 8 个 half = 128 bit, 用 128-bit load
        uint4 tmp = *reinterpret_cast<const uint4*>(src);
        *reinterpret_cast<uint4*>(dst) = tmp;
    }
};

1.3. 特征探测

目的:检测类型是否具有某些特征,以启用/禁用特定代码路径。

// 检测类型是否有 .size() 方法
template <typename T, typename = void>
struct has_size_method : std::false_type {};

template <typename T>
struct has_size_method<T, std::void_t<decltype(std::declval<T>().size())>>
    : std::true_type {};

// --- 使用 ---
static_assert(has_size_method<std::vector<int>>::value, "");  // true
static_assert(!has_size_method<int>::value, "");               // true

// 检测类型是否有嵌套类型 value_type
template <typename T, typename = void>
struct has_value_type : std::false_type {};

template <typename T>
struct has_value_type<T, std::void_t<typename T::value_type>>
    : std::true_type {};

1.4. 标签分发

目的:通过标签类型分发不同的实现。

// Tags
struct RowMajorTag {};
struct ColMajorTag {};

// Layout Traits
template <typename Layout>
struct LayoutTraits;

template <>
struct LayoutTraits<RowMajorTag> {
    static int index(int row, int col, int ldim) {
        return row * ldim + col;
    }
    static constexpr char name[] = "RowMajor";
};

template <>
struct LayoutTraits<ColMajorTag> {
    static int index(int row, int col, int ldim) {
        return col * ldim + row;
    }
    static constexpr char name[] = "ColMajor";
};

// 算法根据 tag 自动适配
template <typename LayoutTag>
void fill_matrix(float* data, int M, int N, int ld) {
    using LTraits = LayoutTraits<LayoutTag>;
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j)
            data[LTraits::index(i, j, ld)] = static_cast<float>(i * N + j);
}

2. Traits 核心技术手段

2.1. 特化/偏特化

  • 完全特化(Template Specialization):为特定类型提供完整实现。
  • 偏特化(Partial Specialization):为满足某些条件的类型提供实现(如指针类型、数组类型等)。

模板特化例子:

template<typename T>
struct my_is_void {
    static const bool value = false;
};

template<>
struct my_is_void<void> {
    static const bool value = true;
};

偏特化例子:

template<typename T>
struct is_pointer {
    static const bool value = false;
};

template<typename T>
struct is_pointer<T*> {
    static const bool value = true;
};

2.2. typename

在 Traits 中经常需要定义嵌套类型(如 typecompute_type 等),使用 typename显式告诉编译器该名字是一个类型

template <typename T>
struct NumericTraits {
    using type = T;  // 定义一个嵌套类型
};

template <typename T>
void foo() {
    using MyType = typename NumericTraits<T>::type;  // 使用嵌套类型
    // ...
}

2.3. SFINAE(Substitution Failure Is Not An Error)

在 SFINAE 中,如果某个表达式无效(如访问了不存在的成员),编译器会将该特化视为无效,而不是报错,从而允许其他特化继续匹配。

在 SFINAE 中,编译器优先特化/偏特化版本,C++ 偏特化规则:更特化(更具体)的版本优先。最后才是主模板版本。

2.3.1. void_t 和 SFINAE 结合实现特征探测

// void_t 是 SFINAE 的瑞士军刀
template <typename... Ts>
using void_t = void;   // C++17 标准已有

// 主模板:利用 void_t 探测嵌套类型
template <typename T, typename = void>
struct element_type_of {
    using type = T;  // 默认: 类型本身
};

// 偏特化:如果 T 有 element_type 嵌套类型,就用它
template <typename T>
struct element_type_of<T, void_t<typename T::element_type>> {
    using type = typename T::element_type;  // 如果有 element_type 就用它
};

// 使用
// element_type_of<std::unique_ptr<int>>::type  → 得到 int
// element_type_of<float>::type              → 得到 float

void_t是一个通用的的类型萃取容器,能将任意数量的表达式包裹为 void 类型。void_t<X> 本身永远是 void,但若 X 是非法类型表达式,模板替换就会失败。

上述代码片段中,以 float 为例,偏特化得到:

element_type_of< float,  void_t<typename float::element_type> >
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                float 没有 ::element_type 成员
                                 类型表达式非法  替换失败

以 std::unique_ptr<int> 为例,偏特化得到:

void_t<typename unique_ptr<int>::element_type>
     = void_t<int>               ::element_type = int,合法!
     = void                      void_t 永远返回 void

最终得到:

element_type_of< unique_ptr<int>,  void >

2.3.2. SFINAE 结合 enable_if 实现函数重载

// enable_if 的原理:
template <bool Cond, typename T = void>
struct enable_if {};             // 默认: 没有 type 成员

template <typename T>
struct enable_if<true, T> {
    using type = T;              // 条件为 true 时才有 type
};

// 使用: 条件不满足 → enable_if<false>::type 不存在 → SFINAE
template <typename T>
typename std::enable_if<std::is_integral<T>::value, T>::type
double_it(T x) {
    return x * 2;
}

template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, T>::type
double_it(T x) {
    return x * 2.0;
}

double_it(42);    // is_integral = true  → 第一个 ✅
double_it(3.14);  // is_floating = true  → 第二个 ✅

2.4. constexpr 与 static 编译期常量

template <typename T>
struct AlignmentTraits {
    // 编译期常量
    static constexpr int alignment = alignof(T);
    static constexpr int size_bytes = sizeof(T);

    // 编译期函数 (C++14)
    static constexpr int max_vector_width() {
        if constexpr (sizeof(T) == 1) return 16;  // 128-bit / 8-bit
        if constexpr (sizeof(T) == 2) return 8;   // 128-bit / 16-bit
        if constexpr (sizeof(T) == 4) return 4;   // 128-bit / 32-bit
        return 1;
    }
};

2.5. 嵌套类型别名

template <typename T>
struct AlignmentTraits {
    // 编译期常量
    static constexpr int alignment = alignof(T);
    static constexpr int size_bytes = sizeof(T);

    // 编译期函数 (C++14)
    static constexpr int max_vector_width() {
        if constexpr (sizeof(T) == 1) return 16;  // 128-bit / 8-bit
        if constexpr (sizeof(T) == 2) return 8;   // 128-bit / 16-bit
        if constexpr (sizeof(T) == 4) return 4;   // 128-bit / 32-bit
        return 1;
    }
};

3. CuTe 中使用 Traits 的例子

SM80_CP_ASYNC_CACHEALWAYS为例,其 Traits 定义如下:

// CuTe 风格的 Copy_Traits (简化示意)
// 对应 cp.async 指令的 Traits
template <typename CopyInstr>
struct Copy_Traits;

// 针对 SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t> 的特化
template <>
struct Copy_Traits<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>> {
    // ---- 关联类型 ----

    // 指令操作的寄存器类型
    //   源: 1个 gmem 指针 (uint128_t)
    //   目标: 1个 smem 指针 (uint128_t)
    using SRegisters = cute::tuple<uint128_t const*>;  // source (gmem ptr)
    using DRegisters = cute::tuple<uint128_t*>;        // dest (smem ptr)

    // Layout 描述: 一个线程搬运多少数据、如何排布
    // ThrID: 哪些线程参与
    // ValLayoutSrc / ValLayoutDst: 值的布局
    using ThrID        = Layout<_1>;          // 1个线程
    using ValLayoutSrc = Layout<_1>;          // 搬运1个128-bit值
    using ValLayoutDst = Layout<_1>;

    // ---- 核心操作 ----
    template <typename TS, typename TD>
    CUTE_HOST_DEVICE static void copy(TS const& src, TD& dst) {
        // 调用 PTX 内联汇编
        // cp.async.ca.shared.global [dst], [src], 16;
        cute::cp_async<16>(dst, src);
    }
};

Copy_Atom 如何使用 Traits

// 简化版 Copy_Atom
template <typename Traits, typename Element>
struct Copy_Atom {
    using TraitsType = Traits;
    using ElementType = Element;

    // 从 Traits 中提取信息
    using SrcRegisters = typename Traits::SRegisters;
    using DstRegisters = typename Traits::DRegisters;

    // 执行拷贝 — 委托给 Traits
    template <typename SrcEngine, typename SrcLayout,
              typename DstEngine, typename DstLayout>
    CUTE_HOST_DEVICE
    void copy(Tensor<SrcEngine, SrcLayout> const& src,
              Tensor<DstEngine, DstLayout>      & dst) const
    {
        // 实际调用 Traits::copy
        Traits::copy(src.data(), dst.data());
    }
};

// ===== 用户代码 =====
// 选择 Traits → 决定用什么指令
using CopyTraits = Copy_Traits<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>>;

// 创建 Atom
using CopyAtom = Copy_Atom<CopyTraits, cute::half_t>;

// 构建 TiledCopy (多线程协作版本)
using TiledCopy = decltype(
    make_tiled_copy(
        CopyAtom{},
        Layout<Shape<_32, _4>>{},   // 线程布局: 32x4 = 128线程
        Layout<Shape<_1, _8>>{}     // 每线程搬运的值布局
    )
);

5. C++ 标准库中 Traits 相关的常用工具

5.1. 类型判断

5.1.1. 基本类型判断

返回 ::value,bool 类型,C++17 起可用 _v 后缀简写。

td::is_void<T>                  // T 是 void?
std::is_null_pointer<T>          // T 是 std::nullptr_t?       (C++14)
std::is_integral<T>              // T 是整数类型?(bool/char/int/long...)
std::is_floating_point<T>        // T 是浮点?(float/double/long double)
std::is_arithmetic<T>            // T 是算术类型?(整数 或 浮点)
std::is_enum<T>                  // T 是枚举?
std::is_class<T>                 // T 是 class/struct?
std::is_function<T>              // T 是函数类型?
std::is_array<T>                 // T 是数组类型?(T[] 或 T[N])

5.1.2. 复合类型判断

std::is_pointer<T>               // T 是指针?(int*, 但不含成员指针)
std::is_reference<T>             // T 是引用?(左值引用 或 右值引用)
std::is_lvalue_reference<T>      // T 是左值引用?
std::is_rvalue_reference<T>      // T 是右值引用?
std::is_member_pointer<T>        // T 是成员指针?
std::is_const<T>                 // T 有顶层 const?
std::is_volatile<T>              // T 有顶层 volatile?
std::is_signed<T>                // T 是有符号类型?
std::is_unsigned<T>              // T 是无符号类型?

5.1.3. 类型关系判断

// ===== 两个类型的关系 =====

std::is_same<T, U>              // T 和 U 是完全相同的类型?
std::is_base_of<Base, Derived>  // Base 是 Derived 的基类?
std::is_convertible<From, To>   // From 能隐式转换为 To?
std::is_assignable<T, U>        // T = U 赋值合法?
std::is_constructible<T, Args...> // T(Args...) 构造合法?
std::is_invocable<F, Args...>   // F(Args...) 可调用?           (C++17)
std::is_invocable_r<R, F, Args...> // F(Args...) 可调用且返回 R? (C++17)

5.2. 类型变换

5.2.1. 修饰符增删

// ===== 去除修饰 =====
std::remove_const<T>           // const int      → int
std::remove_volatile<T>        // volatile int   → int
std::remove_cv<T>              // const volatile int → int
std::remove_reference<T>       // int&/int&&     → int
std::remove_pointer<T>         // int*           → int
std::remove_extent<T>          // int[10]        → int
std::remove_all_extents<T>     // int[3][4]      → int

// ===== 添加修饰 =====
std::add_const<T>              // int            → const int
std::add_volatile<T>           // int            → volatile int
std::add_cv<T>                 // int            → const volatile int
std::add_lvalue_reference<T>   // int            → int&
std::add_rvalue_reference<T>   // int            → int&&
std::add_pointer<T>            // int            → int*

5.2.2. decay

// decay 模拟"按值传参"时的类型退化:
//   - 去掉引用
//   - 去掉顶层 cv
//   - 数组 → 指针
//   - 函数 → 函数指针

std::decay<const int&>      // → int
std::decay<int[10]>         // → int*
std::decay<int(double)>     // → int(*)(double)
std::decay<const int&&>     // → int

// --- 实用示例: 存储任意传入的值 ---
template <typename T>
struct Storage {
    using StoredType = std::decay_t<T>;
    StoredType value;

    Storage(T&& v) : value(std::forward<T>(v)) {}
};

5.2.3. std::conditional

// ===== conditional: 编译期三目运算符 =====
std::conditional<true,  int, double>::type  // → int
std::conditional<false, int, double>::type  // → double

// --- 实用示例: 根据大小选计算类型 ---
template <typename T>
struct ComputeTypeSelector {
    using type = std::conditional_t<
        (sizeof(T) <= 2),   // 半精度/int8 等小类型
        float,               // → 提升到 float 计算
        T                    // → 原类型计算
    >;
};

// ComputeTypeSelector<half>::type   → float
// ComputeTypeSelector<float>::type  → float
// ComputeTypeSelector<double>::type → double

5.2.4. common_type

// 求多个类型的"公共类型" (类似三目运算符的推导)
std::common_type<int, double>::type         // → double
std::common_type<int, long, float>::type    // → float
std::common_type<int, unsigned int>::type   // → unsigned int

// --- 实用示例: 通用 max ---
template <typename T, typename U>
std::common_type_t<T, U> generic_max(T a, U b) {
    return a > b ? a : b;
}

generic_max(1, 2.5);   // 返回 double
generic_max(1L, 2);    // 返回 long

5.3. SFINAE 相关工具

5.3.1. std::enable_if

// 原理:
//   enable_if<true,  T>::type = T
//   enable_if<false, T>::type = 不存在 → SFINAE

// ---- 用法1: 放在返回类型 ----
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T>
bit_count(T val) {
    return __builtin_popcount(val);
}

// ---- 用法2: 放在模板参数默认值 (更简洁) ----
template <typename T,
          std::enable_if_t<std::is_integral_v<T>, int> = 0>
T bit_count_v2(T val) {
    return __builtin_popcount(val);
}

// ---- 用法3: 放在函数参数 ----
template <typename T>
T bit_count_v3(T val,
               std::enable_if_t<std::is_integral_v<T>>* = nullptr) {
    return __builtin_popcount(val);
}

5.3.2. std::void_t

// void_t<Ts...> = void (无论 Ts 是什么)
// 但如果 Ts 中有无效类型 → 替换失败 → SFINAE

// 检测是否有 .size() 方法
template <typename T, typename = void>
struct has_size : std::false_type {};

template <typename T>
struct has_size<T, std::void_t<decltype(std::declval<T>().size())>>
    : std::true_type {};

// 检测是否有嵌套 value_type
template <typename T, typename = void>
struct has_value_type : std::false_type {};

template <typename T>
struct has_value_type<T, std::void_t<typename T::value_type>>
    : std::true_type {};

// 检测是否支持 << 输出
template <typename T, typename = void>
struct is_printable : std::false_type {};

template <typename T>
struct is_printable<T, std::void_t<
    decltype(std::declval<std::ostream&>() << std::declval<T>())
>> : std::true_type {};

5.3.3. std::declval

// std::declval<T>() 在不构造对象的情况下,"假装"有一个 T 类型的值
// 只能在 decltype / sizeof 等不求值上下文中使用

// 用途: 探测表达式是否合法
template <typename T, typename U>
using add_result_t = decltype(std::declval<T>() + std::declval<U>());

// add_result_t<int, double> → double
// add_result_t<string, string> → string
// add_result_t<int, string> → 替换失败

// 也可以检测成员函数返回类型
template <typename Container>
using iterator_t = decltype(std::declval<Container>().begin());

// iterator_t<std::vector<int>> → std::vector<int>::iterator

5.4. 编译期计算工具

5.4.1. std::integral_constant

// integral_constant: 将编译期常量包装成类型
std::integral_constant<int, 42>       // ::value = 42, ::type = 自身

// 最常见的两个特化:
std::true_type   // = integral_constant<bool, true>
std::false_type  // = integral_constant<bool, false>

// --- 实用: 自定义 trait 继承它们 ---
template <typename T>
struct is_gpu_type : std::false_type {};  // 默认 false

template <>
struct is_gpu_type<__half> : std::true_type {};   // 特化为 true

template <>
struct is_gpu_type<__nv_bfloat16> : std::true_type {};

// 使用:
static_assert(is_gpu_type<__half>::value);
if constexpr (is_gpu_type<T>::value) { /* ... */ }

5.4.2. std::integer_sequence

// 生成编译期整数序列
std::integer_sequence<int, 0, 1, 2, 3, 4>
std::index_sequence<0, 1, 2, 3, 4>          // 简写 (size_t)
std::make_index_sequence<5>                  // 自动生成 0,1,2,3,4

// --- 实用: 展开 tuple ---
template <typename Tuple, size_t... Is>
void print_tuple_impl(const Tuple& t, std::index_sequence<Is...>) {
    ((std::cout << std::get<Is>(t) << " "), ...);  // fold expression
}

template <typename... Ts>
void print_tuple(const std::tuple<Ts...>& t) {
    print_tuple_impl(t, std::make_index_sequence<sizeof...(Ts)>{});
}

print_tuple(std::make_tuple(1, 3.14, "hello"));
// 输出: 1 3.14 hello

// --- 实用: 编译期循环 ---
template <size_t... Is>
void unrolled_copy(float* dst, const float* src, std::index_sequence<Is...>) {
    ((dst[Is] = src[Is]), ...);  // 展开成 dst[0]=src[0]; dst[1]=src[1]; ...
}

5.4.3. if constexpr (C++17)

// 编译期 if — 不满足的分支完全不编译 — 替代大量 enable_if

template <typename T>
auto convert(T val) {
    if constexpr (std::is_same_v<T, std::string>) {
        return std::stoi(val);                      // string → int
    } else if constexpr (std::is_floating_point_v<T>) {
        return static_cast<int>(val);                // float → int
    } else if constexpr (std::is_integral_v<T>) {
        return val;                                  // int → int
    } else {
        static_assert(always_false<T>, "Unsupported type");
    }
}

// 辅助: 永远为 false 的依赖模板 (防止 static_assert 直接触发)
template <typename> constexpr bool always_false = false;

5.6. 常用场景速查表

我想做什么 用什么
判断是不是整数/浮点 is_integral_v, is_floating_point_v
判断两个类型是否相同 is_same_v<T, U>
去掉 const/引用 remove_cv_t<remove_reference_t<T>>decay_t
根据条件选类型 conditional_t<cond, A, B>
求公共类型 common_type_t<T, U>
按条件启用/禁用重载 enable_if_t (C++11) / if constexpr (C++17) / concept (C++20)
检测类型有没有某成员 void_t + 偏特化 / is_detected
检测表达式是否合法 decltype(expr) + declval
判断能不能用 memcpy is_trivially_copyable_v
判断是不是空基类(EBO) is_empty_v
判断能不能调用 is_invocable_v<F, Args...>
编译期循环/展开 index_sequence + fold expression
自定义 bool trait 继承 true_type / false_type
有符号↔无符号 make_signed_t / make_unsigned_t
组合多个条件 conjunction_v / disjunction_v (C++17)

A. 资料




    Enjoy Reading This Article?

    Here are some more articles you might like to read next:

  • NVIDIA GPU 架构:SP、SM 与 LSU 工作原理详解
  • al-folio 模板定制修改总结
  • al-folio 本地部署记录(Ubuntu 24.04)
  • 道格拉斯-普克算法(Douglas–Peucker algorithm)
  • CMake支持库收集