C++ Traits
C++ Traits 是一种元编程技术,将类型信息封装在一个类中,供算法或其他模板使用,以期达到从算法/逻辑中分离类型信息的目的。比如元素类型、一些常量定义、指令选择、对齐要求等等。
| 问题 | Traits 如何解决 |
|---|---|
| 内置类型(int, float)不能添加成员 | Traits 外挂信息 |
| 不同类型需要不同的算法策略 | 模板特化选择策略 |
| 想要零开销的编译期多态 | 全部在编译期决议 |
| 接口与实现解耦 | 算法只依赖 Traits 接口 |
1. 自定义 Traits 常见模式
1.1 类型信息萃取
目的:从不同类型中统一提取需要的信息。
// 主模板 (可以留空或给默认值)
template <typename T>
struct NumericTraits;
// 对 float 特化
template <>
struct NumericTraits<float> {
using type = float;
using compute_type = float; // 计算时用的类型
using accum_type = float; // 累加器类型
static constexpr int bits = 32;
static constexpr float epsilon = 1e-7f;
static constexpr float max_val = 3.4028235e+38f;
};
// 对 half 特化
template <>
struct NumericTraits<__half> {
using type = __half;
using compute_type = float; // half 计算时提升到 float
using accum_type = float;
static constexpr int bits = 16;
};
// 对 int8_t 特化
template <>
struct NumericTraits<int8_t> {
using type = int8_t;
using compute_type = int32_t; // int8 用 int32 累加
using accum_type = int32_t;
static constexpr int bits = 8;
};
// --- 使用 ---
template <typename T>
void gemm_kernel() {
using Traits = NumericTraits<T>;
using AccumType = typename Traits::accum_type;
AccumType accumulator = 0; // 自动选择合适的累加器类型
// ...
}
1.2. 策略选择
目的:根据类型选择不同的算法实现。
// 内存拷贝策略 Traits
template <typename T, int Size>
struct CopyTraits;
// 小数据: 逐元素拷贝
template <typename T>
struct CopyTraits<T, 1> {
static void copy(T* dst, const T* src) {
*dst = *src;
}
};
// 4字节对齐: 用 uint32_t 一次拷贝
template <>
struct CopyTraits<float, 4> {
static void copy(float* dst, const float* src) {
// 使用向量化加载
*reinterpret_cast<float4*>(dst) = *reinterpret_cast<const float4*>(src);
}
};
// 128位对齐: 用 LDG.128 指令
template <>
struct CopyTraits<__half, 8> {
static void copy(__half* dst, const __half* src) {
// 8 个 half = 128 bit, 用 128-bit load
uint4 tmp = *reinterpret_cast<const uint4*>(src);
*reinterpret_cast<uint4*>(dst) = tmp;
}
};
1.3. 特征探测
目的:检测类型是否具有某些特征,以启用/禁用特定代码路径。
// 检测类型是否有 .size() 方法
template <typename T, typename = void>
struct has_size_method : std::false_type {};
template <typename T>
struct has_size_method<T, std::void_t<decltype(std::declval<T>().size())>>
: std::true_type {};
// --- 使用 ---
static_assert(has_size_method<std::vector<int>>::value, ""); // true
static_assert(!has_size_method<int>::value, ""); // true
// 检测类型是否有嵌套类型 value_type
template <typename T, typename = void>
struct has_value_type : std::false_type {};
template <typename T>
struct has_value_type<T, std::void_t<typename T::value_type>>
: std::true_type {};
1.4. 标签分发
目的:通过标签类型分发不同的实现。
// Tags
struct RowMajorTag {};
struct ColMajorTag {};
// Layout Traits
template <typename Layout>
struct LayoutTraits;
template <>
struct LayoutTraits<RowMajorTag> {
static int index(int row, int col, int ldim) {
return row * ldim + col;
}
static constexpr char name[] = "RowMajor";
};
template <>
struct LayoutTraits<ColMajorTag> {
static int index(int row, int col, int ldim) {
return col * ldim + row;
}
static constexpr char name[] = "ColMajor";
};
// 算法根据 tag 自动适配
template <typename LayoutTag>
void fill_matrix(float* data, int M, int N, int ld) {
using LTraits = LayoutTraits<LayoutTag>;
for (int i = 0; i < M; ++i)
for (int j = 0; j < N; ++j)
data[LTraits::index(i, j, ld)] = static_cast<float>(i * N + j);
}
2. Traits 核心技术手段
2.1. 特化/偏特化
- 完全特化(Template Specialization):为特定类型提供完整实现。
- 偏特化(Partial Specialization):为满足某些条件的类型提供实现(如指针类型、数组类型等)。
模板特化例子:
template<typename T>
struct my_is_void {
static const bool value = false;
};
template<>
struct my_is_void<void> {
static const bool value = true;
};
偏特化例子:
template<typename T>
struct is_pointer {
static const bool value = false;
};
template<typename T>
struct is_pointer<T*> {
static const bool value = true;
};
2.2. typename
在 Traits 中经常需要定义嵌套类型(如 type、compute_type 等),使用 typename 来显式告诉编译器该名字是一个类型。
template <typename T>
struct NumericTraits {
using type = T; // 定义一个嵌套类型
};
template <typename T>
void foo() {
using MyType = typename NumericTraits<T>::type; // 使用嵌套类型
// ...
}
2.3. SFINAE(Substitution Failure Is Not An Error)
在 SFINAE 中,如果某个表达式无效(如访问了不存在的成员),编译器会将该特化视为无效,而不是报错,从而允许其他特化继续匹配。
在 SFINAE 中,编译器优先特化/偏特化版本,C++ 偏特化规则:更特化(更具体)的版本优先。最后才是主模板版本。
2.3.1. void_t 和 SFINAE 结合实现特征探测
// void_t 是 SFINAE 的瑞士军刀
template <typename... Ts>
using void_t = void; // C++17 标准已有
// 主模板:利用 void_t 探测嵌套类型
template <typename T, typename = void>
struct element_type_of {
using type = T; // 默认: 类型本身
};
// 偏特化:如果 T 有 element_type 嵌套类型,就用它
template <typename T>
struct element_type_of<T, void_t<typename T::element_type>> {
using type = typename T::element_type; // 如果有 element_type 就用它
};
// 使用
// element_type_of<std::unique_ptr<int>>::type → 得到 int
// element_type_of<float>::type → 得到 float
void_t是一个通用的的类型萃取容器,能将任意数量的表达式包裹为 void 类型。void_t<X> 本身永远是 void,但若 X 是非法类型表达式,模板替换就会失败。
上述代码片段中,以 float 为例,偏特化得到:
element_type_of< float, void_t<typename float::element_type> >
^^^^^^^^^^^^^^^^^^^^^^^^^^^
float 没有 ::element_type 成员
→ 类型表达式非法 → 替换失败
以 std::unique_ptr<int> 为例,偏特化得到:
void_t<typename unique_ptr<int>::element_type>
= void_t<int> ← ::element_type = int,合法!
= void ← void_t 永远返回 void
最终得到:
element_type_of< unique_ptr<int>, void >
2.3.2. SFINAE 结合 enable_if 实现函数重载
// enable_if 的原理:
template <bool Cond, typename T = void>
struct enable_if {}; // 默认: 没有 type 成员
template <typename T>
struct enable_if<true, T> {
using type = T; // 条件为 true 时才有 type
};
// 使用: 条件不满足 → enable_if<false>::type 不存在 → SFINAE
template <typename T>
typename std::enable_if<std::is_integral<T>::value, T>::type
double_it(T x) {
return x * 2;
}
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, T>::type
double_it(T x) {
return x * 2.0;
}
double_it(42); // is_integral = true → 第一个 ✅
double_it(3.14); // is_floating = true → 第二个 ✅
2.4. constexpr 与 static 编译期常量
template <typename T>
struct AlignmentTraits {
// 编译期常量
static constexpr int alignment = alignof(T);
static constexpr int size_bytes = sizeof(T);
// 编译期函数 (C++14)
static constexpr int max_vector_width() {
if constexpr (sizeof(T) == 1) return 16; // 128-bit / 8-bit
if constexpr (sizeof(T) == 2) return 8; // 128-bit / 16-bit
if constexpr (sizeof(T) == 4) return 4; // 128-bit / 32-bit
return 1;
}
};
2.5. 嵌套类型别名
template <typename T>
struct AlignmentTraits {
// 编译期常量
static constexpr int alignment = alignof(T);
static constexpr int size_bytes = sizeof(T);
// 编译期函数 (C++14)
static constexpr int max_vector_width() {
if constexpr (sizeof(T) == 1) return 16; // 128-bit / 8-bit
if constexpr (sizeof(T) == 2) return 8; // 128-bit / 16-bit
if constexpr (sizeof(T) == 4) return 4; // 128-bit / 32-bit
return 1;
}
};
3. CuTe 中使用 Traits 的例子
以SM80_CP_ASYNC_CACHEALWAYS为例,其 Traits 定义如下:
// CuTe 风格的 Copy_Traits (简化示意)
// 对应 cp.async 指令的 Traits
template <typename CopyInstr>
struct Copy_Traits;
// 针对 SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t> 的特化
template <>
struct Copy_Traits<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>> {
// ---- 关联类型 ----
// 指令操作的寄存器类型
// 源: 1个 gmem 指针 (uint128_t)
// 目标: 1个 smem 指针 (uint128_t)
using SRegisters = cute::tuple<uint128_t const*>; // source (gmem ptr)
using DRegisters = cute::tuple<uint128_t*>; // dest (smem ptr)
// Layout 描述: 一个线程搬运多少数据、如何排布
// ThrID: 哪些线程参与
// ValLayoutSrc / ValLayoutDst: 值的布局
using ThrID = Layout<_1>; // 1个线程
using ValLayoutSrc = Layout<_1>; // 搬运1个128-bit值
using ValLayoutDst = Layout<_1>;
// ---- 核心操作 ----
template <typename TS, typename TD>
CUTE_HOST_DEVICE static void copy(TS const& src, TD& dst) {
// 调用 PTX 内联汇编
// cp.async.ca.shared.global [dst], [src], 16;
cute::cp_async<16>(dst, src);
}
};
Copy_Atom 如何使用 Traits
// 简化版 Copy_Atom
template <typename Traits, typename Element>
struct Copy_Atom {
using TraitsType = Traits;
using ElementType = Element;
// 从 Traits 中提取信息
using SrcRegisters = typename Traits::SRegisters;
using DstRegisters = typename Traits::DRegisters;
// 执行拷贝 — 委托给 Traits
template <typename SrcEngine, typename SrcLayout,
typename DstEngine, typename DstLayout>
CUTE_HOST_DEVICE
void copy(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst) const
{
// 实际调用 Traits::copy
Traits::copy(src.data(), dst.data());
}
};
// ===== 用户代码 =====
// 选择 Traits → 决定用什么指令
using CopyTraits = Copy_Traits<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>>;
// 创建 Atom
using CopyAtom = Copy_Atom<CopyTraits, cute::half_t>;
// 构建 TiledCopy (多线程协作版本)
using TiledCopy = decltype(
make_tiled_copy(
CopyAtom{},
Layout<Shape<_32, _4>>{}, // 线程布局: 32x4 = 128线程
Layout<Shape<_1, _8>>{} // 每线程搬运的值布局
)
);
5. C++ 标准库中 Traits 相关的常用工具
5.1. 类型判断
5.1.1. 基本类型判断
返回
::value,bool 类型,C++17 起可用_v后缀简写。
td::is_void<T> // T 是 void?
std::is_null_pointer<T> // T 是 std::nullptr_t? (C++14)
std::is_integral<T> // T 是整数类型?(bool/char/int/long...)
std::is_floating_point<T> // T 是浮点?(float/double/long double)
std::is_arithmetic<T> // T 是算术类型?(整数 或 浮点)
std::is_enum<T> // T 是枚举?
std::is_class<T> // T 是 class/struct?
std::is_function<T> // T 是函数类型?
std::is_array<T> // T 是数组类型?(T[] 或 T[N])
5.1.2. 复合类型判断
std::is_pointer<T> // T 是指针?(int*, 但不含成员指针)
std::is_reference<T> // T 是引用?(左值引用 或 右值引用)
std::is_lvalue_reference<T> // T 是左值引用?
std::is_rvalue_reference<T> // T 是右值引用?
std::is_member_pointer<T> // T 是成员指针?
std::is_const<T> // T 有顶层 const?
std::is_volatile<T> // T 有顶层 volatile?
std::is_signed<T> // T 是有符号类型?
std::is_unsigned<T> // T 是无符号类型?
5.1.3. 类型关系判断
// ===== 两个类型的关系 =====
std::is_same<T, U> // T 和 U 是完全相同的类型?
std::is_base_of<Base, Derived> // Base 是 Derived 的基类?
std::is_convertible<From, To> // From 能隐式转换为 To?
std::is_assignable<T, U> // T = U 赋值合法?
std::is_constructible<T, Args...> // T(Args...) 构造合法?
std::is_invocable<F, Args...> // F(Args...) 可调用? (C++17)
std::is_invocable_r<R, F, Args...> // F(Args...) 可调用且返回 R? (C++17)
5.2. 类型变换
5.2.1. 修饰符增删
// ===== 去除修饰 =====
std::remove_const<T> // const int → int
std::remove_volatile<T> // volatile int → int
std::remove_cv<T> // const volatile int → int
std::remove_reference<T> // int&/int&& → int
std::remove_pointer<T> // int* → int
std::remove_extent<T> // int[10] → int
std::remove_all_extents<T> // int[3][4] → int
// ===== 添加修饰 =====
std::add_const<T> // int → const int
std::add_volatile<T> // int → volatile int
std::add_cv<T> // int → const volatile int
std::add_lvalue_reference<T> // int → int&
std::add_rvalue_reference<T> // int → int&&
std::add_pointer<T> // int → int*
5.2.2. decay
// decay 模拟"按值传参"时的类型退化:
// - 去掉引用
// - 去掉顶层 cv
// - 数组 → 指针
// - 函数 → 函数指针
std::decay<const int&> // → int
std::decay<int[10]> // → int*
std::decay<int(double)> // → int(*)(double)
std::decay<const int&&> // → int
// --- 实用示例: 存储任意传入的值 ---
template <typename T>
struct Storage {
using StoredType = std::decay_t<T>;
StoredType value;
Storage(T&& v) : value(std::forward<T>(v)) {}
};
5.2.3. std::conditional
// ===== conditional: 编译期三目运算符 =====
std::conditional<true, int, double>::type // → int
std::conditional<false, int, double>::type // → double
// --- 实用示例: 根据大小选计算类型 ---
template <typename T>
struct ComputeTypeSelector {
using type = std::conditional_t<
(sizeof(T) <= 2), // 半精度/int8 等小类型
float, // → 提升到 float 计算
T // → 原类型计算
>;
};
// ComputeTypeSelector<half>::type → float
// ComputeTypeSelector<float>::type → float
// ComputeTypeSelector<double>::type → double
5.2.4. common_type
// 求多个类型的"公共类型" (类似三目运算符的推导)
std::common_type<int, double>::type // → double
std::common_type<int, long, float>::type // → float
std::common_type<int, unsigned int>::type // → unsigned int
// --- 实用示例: 通用 max ---
template <typename T, typename U>
std::common_type_t<T, U> generic_max(T a, U b) {
return a > b ? a : b;
}
generic_max(1, 2.5); // 返回 double
generic_max(1L, 2); // 返回 long
5.3. SFINAE 相关工具
5.3.1. std::enable_if
// 原理:
// enable_if<true, T>::type = T
// enable_if<false, T>::type = 不存在 → SFINAE
// ---- 用法1: 放在返回类型 ----
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T>
bit_count(T val) {
return __builtin_popcount(val);
}
// ---- 用法2: 放在模板参数默认值 (更简洁) ----
template <typename T,
std::enable_if_t<std::is_integral_v<T>, int> = 0>
T bit_count_v2(T val) {
return __builtin_popcount(val);
}
// ---- 用法3: 放在函数参数 ----
template <typename T>
T bit_count_v3(T val,
std::enable_if_t<std::is_integral_v<T>>* = nullptr) {
return __builtin_popcount(val);
}
5.3.2. std::void_t
// void_t<Ts...> = void (无论 Ts 是什么)
// 但如果 Ts 中有无效类型 → 替换失败 → SFINAE
// 检测是否有 .size() 方法
template <typename T, typename = void>
struct has_size : std::false_type {};
template <typename T>
struct has_size<T, std::void_t<decltype(std::declval<T>().size())>>
: std::true_type {};
// 检测是否有嵌套 value_type
template <typename T, typename = void>
struct has_value_type : std::false_type {};
template <typename T>
struct has_value_type<T, std::void_t<typename T::value_type>>
: std::true_type {};
// 检测是否支持 << 输出
template <typename T, typename = void>
struct is_printable : std::false_type {};
template <typename T>
struct is_printable<T, std::void_t<
decltype(std::declval<std::ostream&>() << std::declval<T>())
>> : std::true_type {};
5.3.3. std::declval
// std::declval<T>() 在不构造对象的情况下,"假装"有一个 T 类型的值
// 只能在 decltype / sizeof 等不求值上下文中使用
// 用途: 探测表达式是否合法
template <typename T, typename U>
using add_result_t = decltype(std::declval<T>() + std::declval<U>());
// add_result_t<int, double> → double
// add_result_t<string, string> → string
// add_result_t<int, string> → 替换失败
// 也可以检测成员函数返回类型
template <typename Container>
using iterator_t = decltype(std::declval<Container>().begin());
// iterator_t<std::vector<int>> → std::vector<int>::iterator
5.4. 编译期计算工具
5.4.1. std::integral_constant
// integral_constant: 将编译期常量包装成类型
std::integral_constant<int, 42> // ::value = 42, ::type = 自身
// 最常见的两个特化:
std::true_type // = integral_constant<bool, true>
std::false_type // = integral_constant<bool, false>
// --- 实用: 自定义 trait 继承它们 ---
template <typename T>
struct is_gpu_type : std::false_type {}; // 默认 false
template <>
struct is_gpu_type<__half> : std::true_type {}; // 特化为 true
template <>
struct is_gpu_type<__nv_bfloat16> : std::true_type {};
// 使用:
static_assert(is_gpu_type<__half>::value);
if constexpr (is_gpu_type<T>::value) { /* ... */ }
5.4.2. std::integer_sequence
// 生成编译期整数序列
std::integer_sequence<int, 0, 1, 2, 3, 4>
std::index_sequence<0, 1, 2, 3, 4> // 简写 (size_t)
std::make_index_sequence<5> // 自动生成 0,1,2,3,4
// --- 实用: 展开 tuple ---
template <typename Tuple, size_t... Is>
void print_tuple_impl(const Tuple& t, std::index_sequence<Is...>) {
((std::cout << std::get<Is>(t) << " "), ...); // fold expression
}
template <typename... Ts>
void print_tuple(const std::tuple<Ts...>& t) {
print_tuple_impl(t, std::make_index_sequence<sizeof...(Ts)>{});
}
print_tuple(std::make_tuple(1, 3.14, "hello"));
// 输出: 1 3.14 hello
// --- 实用: 编译期循环 ---
template <size_t... Is>
void unrolled_copy(float* dst, const float* src, std::index_sequence<Is...>) {
((dst[Is] = src[Is]), ...); // 展开成 dst[0]=src[0]; dst[1]=src[1]; ...
}
5.4.3. if constexpr (C++17)
// 编译期 if — 不满足的分支完全不编译 — 替代大量 enable_if
template <typename T>
auto convert(T val) {
if constexpr (std::is_same_v<T, std::string>) {
return std::stoi(val); // string → int
} else if constexpr (std::is_floating_point_v<T>) {
return static_cast<int>(val); // float → int
} else if constexpr (std::is_integral_v<T>) {
return val; // int → int
} else {
static_assert(always_false<T>, "Unsupported type");
}
}
// 辅助: 永远为 false 的依赖模板 (防止 static_assert 直接触发)
template <typename> constexpr bool always_false = false;
5.6. 常用场景速查表
| 我想做什么 | 用什么 |
|---|---|
| 判断是不是整数/浮点 | is_integral_v, is_floating_point_v |
| 判断两个类型是否相同 | is_same_v<T, U> |
| 去掉 const/引用 | remove_cv_t<remove_reference_t<T>> 或 decay_t |
| 根据条件选类型 | conditional_t<cond, A, B> |
| 求公共类型 | common_type_t<T, U> |
| 按条件启用/禁用重载 | enable_if_t (C++11) / if constexpr (C++17) / concept (C++20) |
| 检测类型有没有某成员 | void_t + 偏特化 / is_detected |
| 检测表达式是否合法 | decltype(expr) + declval |
| 判断能不能用 memcpy | is_trivially_copyable_v |
| 判断是不是空基类(EBO) | is_empty_v |
| 判断能不能调用 | is_invocable_v<F, Args...> |
| 编译期循环/展开 | index_sequence + fold expression |
| 自定义 bool trait | 继承 true_type / false_type |
| 有符号↔无符号 | make_signed_t / make_unsigned_t |
| 组合多个条件 | conjunction_v / disjunction_v (C++17) |
A. 资料
Enjoy Reading This Article?
Here are some more articles you might like to read next: