《产生式元编程》第七章 巧活用折叠表达式
Introduction
模板是第一阶段元编程最核心的工具,中篇以两章五星难度的内容开头,深入纵览其核心技术与诸般妙诀。本章要讨论的元编程工具——Fold Expressions,依旧处于第一阶段,是 C++17 引入的一个跨越性产生式特性。
该特性从更高层面抽象了参数包拆解方式,消除了传统递归所带来的复杂性,是非常有用的编译期遍历方式。因此,这个特性是产生式元编程中的关键部分,这也是单独分配一章深度讨论的原因。
这个特性也并不完善,C++26/29 依旧会增加与其相关的特性,增加哪些?为什么加?痛点在哪儿?这些问题也将在本章给出答案。
Fold
Fold 这个概念来自函数式编程,本身指的是一类高阶函数,这些函数使用给定的组合操作来分析递归数据结构,并通过递归处理其组成部分的结果来重新组合,最终构建出一个返回值。简而言之,Fold 能够递归地遍历数据结构中的每个元素,并通过一个组合函数将这些元素的值合并为一个结果。这个组合函数通常定义了如何将两个值合并在一起,以及如何处理基础情况(如空数据结构)。
这种从递归数据结构中提取信息的数学概念称为 Catamorphism(源自古希腊语:κατά "向下" 和 μορφή "形式,形状"),表示从一个初始代数到其他代数的唯一同态映射。编程中,目标代数通常就是一个单一的值或结果。若从词的本意来理解,其实指的就是将一复杂数据结构向下拆解成简单数据结构的过程,这个过程递归地利用一个组合函数来完成。Catamorphism 通过抽象化递归数据结构的处理方式,提供了一种通用的模式来遍历和处理递归数据结构,使得代码更具有可读性和可维护性,适用于各种场景,如求和、乘积、查找、过滤等。
Fold 只是 Catamorphism 的一个具体实现,主要是指对列表和序列的处理。例如,对列表 [1, 2, 3, 4]
求和,可以这样表示:
foldl (+) 0 [1, 2, 3, 4] = (((0 + 1) + 2) + 3) + 4 = 10
foldr (+) 0 [1, 2, 3, 4] = 1 + (2 + (3 + (4 + 0))) = 10
想必大家也不陌生,foldl
是从左向右应用函数的方式,称为左折叠,foldr
与之相反,称为右折叠。若是操作类型满足交换律,不论选择哪种方式,结果都一样。由此结构,也可以看到 Fold 通常包含的三个参数:
- 二元函数:
+
,定义了如何将列表的两个元素合并成一个值; - 初始值:
0
,折叠操作的开始值,它与列表的第一个元素一起传递给二元函数; - 列表:
[1, 2, 3, 4]
,折叠操作的元素集合。
当然,倘是一元函数,则不需要初始值,例如,拼接字符串:
foldl (++) ["a", "b", "c", "d"] = (((("a" ++ "b") ++ "c") ++ "d")) = "abcd"
foldr (++) ["a", "b", "c", "d"] = "a" ++ ("b" ++ ("c" ++ "d")) = "abcd"
正是这种更高一级的抽象方式,Fold 这个概念才能够简化常规的递归方式,以一种更加易于人类理解的方式表达拆解逻辑,增加代码的可读性的同时,也简化了编码效率。
C++ Fold
std::accumulate
就是C++ 提供的一个 Fold 函数,包含前面所说的三个参数。基本形式如下:
// left fold
std::accumulate(begin, end, initval, func)
// right fold
std::accumulate(rbegin, rend, initval, func)
下面是一个例子:
// fr. https://en.cppreference.com/w/cpp/algorithm/accumulate
#include <functional>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>
int main()
{
std::vector<int> v{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
int sum = std::accumulate(v.begin(), v.end(), 0);
int product = std::accumulate(v.begin(), v.end(), 1, std::multiplies<int>());
auto dash_fold = [](std::string a, int b)
{
return std::move(a) + '-' + std::to_string(b);
};
std::string s = std::accumulate(std::next(v.begin()), v.end(),
std::to_string(v[0]), // start with first element
dash_fold);
// Right fold using reverse iterators
std::string rs = std::accumulate(std::next(v.rbegin()), v.rend(),
std::to_string(v.back()), // start with last element
dash_fold);
std::cout << "sum: " << sum << '\n'
<< "product: " << product << '\n'
<< "dash-separated string: " << s << '\n'
<< "dash-separated string (right-folded): " << rs << '\n';
}
输出为:
sum: 55
product: 3628800
dash-separated string: 1-2-3-4-5-6-7-8-9-10
dash-separated string (right-folded): 10-9-8-7-6-5-4-3-2-1
函数刚好接受一个列表、一个初始值、一个二元定制函数作为参数,所以类似的需求皆可摆脱传统的遍历方式,表达起来更加简单。但是 std::accumulate
只支持二元函数,且表意不够广泛,因而 C++23 又增加了 Ranges fold 系列算法。基本形式如下:
std::ranges::fold_left(range, initval, func)
std::ranges::fold_right(range, initval, func)
std::ranges::fold_left_first(range, func)
std::ranges::fold_right_last(range, func)
新的系列算法更加顾名思义,同时支持一元函数和二元函数,进一步简化了折叠方式。
同样提供一个例子:
// fold algorithms
int xs[] = { 1, 2, 3, 4, 5 };
auto concatl = [](std::string s, int i) { return s + std::to_string(i); };
auto concatr = [](int i, std::string s) { return s + std::to_string(i); };
auto fold_left = ranges::fold_left(xs, std::string(), concatl);
fmt::print("fold left: {}\n", fold_left);
auto fold_right = ranges::fold_right(xs, std::string(), concatr);
fmt::print("fold right: {}\n", fold_right);
// Output:
// fold left: 12345
// fold right: 54321
C++ Fold Expressions
Fold 函数主要适用于列表或其他线性数据结构,而 C++17 Fold Expressions 则适用于可变模板参数包。不同的是,前者是 Library 级别的特性,而后者却是 Language 级别的特性,可用性更强。
Fold Expressions 同样支持一元及二元的左折叠和右折叠,形式如下:
( pack op ... ) // Unary right fold
( ... op pack ) // Unary left fold
( pack op ... op init ) // Binary right fold
( init op ... op pack ) // Binary left fold
逻辑其实都一样,只是语法形式稍异而已。...
在参数包的左边,就属于左折叠,在右边,就属于右折叠。但是,在二元折叠中,不能同时包含参数包,例如:
// fr. C++20 standard §7.5.6 (ISO/IEC 14882:2020)
template<typename ...Args>
bool f(Args ...args) {
return (true && ... && args); // OK
}
template<typename ...Args>
bool f(Args ...args) {
return (args + ... + args); // error: both operands contain unexpanded packs
}
而在一元折叠中,参数包若为空,只有以下三个操作符具有合法的默认值:
Operator | Value when parameter pack is empty |
---|---|
&& |
true |
\|\| |
false |
, |
void() |
这三个操作符也恰恰是各种高级技巧的基石。
Smart Tricks with Fold Expressions
Fold 是更加抽象化的遍历方式,Fold Expressions 的核心作用就是替代传统的递归遍历方式。但是,要精细化控制这种遍历方式的各种细节,例如条件、中断、中间值、下标等,便需要各种高级技巧了。
本节便分别展示各种小技巧,将它们分布在各个算法当中。
Conditions and Counting
根据一个 Predicate 函数,计算符合条件的元素个数。
all_of
计算是否所有元素都满足条件,any_of
计算是否任一元素满足条件,count_of
计算满足条件的个数。实现如下:
// Check whether all elements matches a predicate.
auto all_of(auto F, auto... args) -> bool {
return (F(args) && ...);
}
// Check whether any elements matches a predicate.
auto any_of(auto F, auto... args) -> bool {
return (F(args) || ...);
}
// Count the elements matches a predicate.
auto count_of(auto Pred, auto... args) -> int {
return (Pred(args) + ...);
}
标准中存在类似的算法给容器使用,实现都采用 std::find_if
之类的算法,查找算法内部又都涉及传统的循环遍历。对于参数包,若是用传统的递归来完成此类操作,相较也会麻烦,而 Fold Expressions 这种更高一级的遍历方式则显得灵活而简洁。
本技巧主要利用了两个特性,一个是 &&
和 ||
所具有的 short-circuit 评估能力,可以用来实现条件和中断,另一个是 bool
到 int
之间所存在的隐式转换,可以用来计数。
Random Access
任意访问参数包某个索引指向的元素。
首先,若是参数包的元素属于同构类型,可以通过以下方式访问其首尾元素。
// Find the first element.
auto first_of(auto... args) -> std::common_type_t<decltype(args)...> {
std::common_type_t<decltype(args)...> result;
((result = args, true) || ...);
return result;
}
// Find the last element.
auto last_of(auto... args) -> std::common_type_t<decltype(args)...> {
std::common_type_t<decltype(args)...> result;
(result = (args, ...));
return result;
}
本处技巧主要是利用 ||
和 ,
的特性,保存中间值的过程中,决定是否继续往下走。同时,还用到 =
让右边的表达式先计算,从而保存最终结果。
其次,若是参数包是异构类型,可以借助 std::tuple
的索引式访问算法来实现,返回值只能依靠自动推导。
auto generic_first_of(auto... args) {
auto values = std::forward_as_tuple(args...);
return std::get<1>(values);
}
auto generic_last_of(auto... args) {
auto values = std::forward_as_tuple(args...);
return std::get<sizeof...(args)-1>(values);
}
这种方式本质就是利用已有的 std::tuple
算法来达到目的,虽说复杂度降低,但却要构造一个额外的对象。
最后,若是参数包的元素属于同构类型,不借助 std::tuple
,可以通过以下方式实现索引式访问。
// Find the nth element.
template <std::size_t I>
auto nth_of(auto... args) -> std::common_type_t<decltype(args)...> {
std::common_type_t<decltype(args)...> result;
std::size_t n{};
((n++ == I ? (result = args, true) : false) || ...);
return result;
}
手法结合了前面几个技巧,再以三目运算符作为条件,分发逻辑,false
时继续往下遍历,true
时结束遍历。
Maximum and Minimum
取同构元素列表的最大最小值。
同样需要保存中间结果,但不需要中断遍历流程,实现如下:
// Find the minimum element.
auto min_of(auto... args) -> std::common_type_t<decltype(args)...> {
auto min = (args, ...);
((min > args ? min = args : 0), ...);
return min;
}
// Find the maximum element.
auto max_of(auto... args) -> std::common_type_t<decltype(args)...> {
auto max = (args, ...);
((max < args ? max = args : 0), ...);
return max;
}
不中断的情况下,使用 ,
展开更加便捷。
Reverse Packs
逆转列表元素位置,返回转换后的列表。
因为需要返回一个列表,所以只得借助 std::tuple
,再反转 std::tuple
。实现如下:
auto reverse_of(auto... args) {
auto tuple = std::make_tuple(args...);
return [&tuple]<auto... I>(std::index_sequence<I...>) {
return std::make_tuple(std::get<sizeof...(args)-1-I>(std::forward<decltype(tuple)>(tuple))...);
}(std::index_sequence_for<decltype(args)...>{});
}
此处便出现了第六章介绍的 Compile-time for,这种技巧可以在编译期遍历并操作 std::tuple
,本质就是创造一个索引参数包,再以 Fold Expressions 遍历。
Overload Pattern
与 Fold Expressions 相关联的另一个技术称为 Overload pattern,这项技术通过展开参数包实现多继承,以在视觉层面模拟 Lambda 重载。代码只有两行:
template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
短短两行代码中,除了使用参数包展开实现多继承,还借助 Using-declration 来绕开重载合并规则,避免重载歧义。这些特性和 Fold Expressions 类似,都能展开参数包,生成重复代码。
Lambda 重载是一种灵活的定制方式,以前展示过对象工厂和抽象工厂的应用,下面看一个与 Fold Expressions 一起使用的新例子。
Implement any_visit with Fold Expressions
std::any
采用类型擦除技术,实现了异构类型表示,可以和容器类型结合起来构成异构容器。std::variant
也能够起到类似的作用,标准提供 std::visit
来访问元素,因此可以直接使用已有算法来遍历:
using value_type = std::variant<int, double, std::string>;
std::vector<value_type> container;
container.push_back(5);
container.push_back(0.42);
container.push_back("hello");
// Iterate the heterogeneous container
std::ranges::for_each(container, [](const value_type& value) {
std::visit([](const auto& x){ std::print("{} ", x); }, value);
});
而 std::any
没有对应的 visit
访问函数,只能通过下面这种 type-switch 方式访问:
for (const auto& a : container) {
if (a.type() == typeid(int)) {
const auto& value = std::any_cast<int>(a);
} else if (a.type() == typeid(const char*)) {
const auto& value = std::any_cast<const char*>(a);
} else if (a.type() == typeid(bool)) {
const auto& value = std::any_cast<bool>(a);
}
}
重复、变化、繁琐……于是实现一个 any_visit
的想法顿时出现,而这个遍历就可以用 Fold Expressions 来抽象得更高一级。实现为:
template <class... Ts>
void any_visit(auto f, const std::any& a) {
((std::type_index(a.type()) == std::type_index(typeid(Ts))
&& (f(std::any_cast<Ts>(a)), true)) || ...);
}
可以细品一下这个实现是如何借助 &&
、||
和 ,
消除 for
和 if
的,技巧都是前面讲过的内容。有了这个工具,便可以像 std::visit
那样,借助算法来迭代 std::any
构成的异构容器。例子:
std::vector<std::any> container { 5, 0.42, "hello", false };
// Output: 5 0.42 hello boolean: false
std::ranges::for_each(container, [](const auto& a) {
any_visit<int, double, const char*>([](const auto& x) { std::print("{} ", x); }, a);
any_visit<bool>([](const auto& x) { std::print("boolean: {} ", x); }, a);
});
简单是简单,但由于 Fold Expressions 要借助参数包展开,模板参数的变化性依旧没有消除,而这些信息其实可以表现到 Lambda 参数之中,以 Overload pattern 封装这部分变化。即目标用法变成这样:
std::ranges::for_each(container, [](const auto& a) {
any_visit(overloaded {
[](int x) { std::print("int: {} ", x); },
[](double x) { std::print("double: {} ", x); },
[](std::string_view x) { std::print("string: {} ", x); },
[](bool x) { std::print("bool: {} ", x); }
}, a);
});
如此一来,不仅可以精确处理每一个异构类型,还不用重复调用 any_visit
,同时也消除了显式模板参数。overloaded
在上一节已然介绍,那么现在只剩下一个关键问题——如何获取 Lambda 的参数类型?解决了这个问题,实现便也水到渠成了。
Lambda 是一个可以携带状态的函数,其实现是一个含有 operator()
重载的匿名类,捕获的参数作为匿名类的数据成员直接初始化。Lambda 使用时调用的便是这个重载的 operator()
,返回的类型就是匿名类的类型,称为 closure type。因此,问题进一步转化为如何获取成员函数 operator()
的参数类型,通过第五、第六章的高级模板内容,获取起来犹如探囊取物。
方法就是把想要的类型,通过模板参数,显式写出来:
// For a function pointer
template<typename R, typename Arg, typename... Rest>
Arg extract_first_arg(R(*) (Arg, Rest...));
// For a member function pointer without a qualifier
template<typename R, typename F, typename Arg, typename... Rest>
Arg extract_first_arg(R(F::*) (Arg, Rest...));
// For a const-qualified member function pointer
template<typename R, typename F, typename Arg, typename... Rest>
Arg extract_first_arg(R(F::*) (Arg, Rest...) const);
// ...
这里只写了支持函数指针和带基本修饰的 Lambda 函数,更多修饰可以接着往下写。这些函数都属于稻草人函数,不实际使用,只提取类型,故无需实现。接着,通过 decltype()
将函数模板的返回类型提取出来:
template<typename L>
using lambda_arg_t = decltype(extract_first_arg(&L::operator()));
至此,最复杂的问题便解决了。
最后,利用 Overload pattern 和 Fold expressions 访问 std::any
构成的异构容器。代码为:
template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
template<typename... Lam>
void any_visit(const overloaded<Lam...>& f, auto&& any)
{
((std::type_index(any.type()) == std::type_index(typeid(lambda_arg_t<Lam>))
&& (f(std::any_cast<lambda_arg_t<Lam>>(any)), true)) || ...);
}
寥寥数行代码,便解决了一个相对复杂的问题,这就是 Fold expressions 的妙用。
Related Features and Discussions
Fold expressions 是专门针对参数包的 Fold 特性,好用是好用,但是前提得先有参数包,否则也是巧妇难为无米之炊。
比如我们要循环输出 5 次 hello fold expressions
,那前提是先得有五个模板参数,在哪儿凭空产生固定个数的模板参数呢?这其实就是第六章所介绍的 Compile-time for 技术,采用该技术,可以这样实现:
[]<auto... Is>(std::index_sequence<Is...>) {
((std::println("hello fold expressions"), Is), ...);
}(std::make_index_sequence<5>{});
这能达到预期效果,但显得复杂,更本质的问题在于 C++ 缺少直接创建参数包的能力,std::index_sequence
也只是扬汤止沸的产物,没有解决根本问题。而 Circle 便支持直接创建参数包,于是可以简洁地完成以下功能:
struct obj_t {
int x;
double y;
std::string s;
};
int main() {
auto f = [](const char* name, auto... i) {
std::cout<< name<< ":\n";
std::cout<< " "<< i<< "\n" ...;
};
// Just expand a pack into an argument list.
f("integers", int...(5) ...);
obj_t obj { 100, 3.14, "A string" };
f("object", obj...);
}
// Compiler Explorer: https://godbolt.org/z/4K61vM9hf
int...(5) ...
可以直接创建一个 int
型参数包并展开,obj...
可以直接将结构体的成员转换为参数包。有了这种创建参数包的能力,不但能够简化代码,而且可以直接折叠结构体,极大增强操纵模板参数的能力。
不过,随着 C++26 Pack structure bindings 和 Pack indexing 的加入,在一定程度上能够改善这部分问题。那时 Compile-time for 便无需借助 Lambda 充当辅助函数了,可以直接这样写:
auto [...Is] = std::make_index_sequence<5>{};
((std::println("hello fold expressions"), Is), ...);
这才是最直接的方式,Fold expressions 如今主要就受限于参数包的特性不足。
Ternary Right Fold Expression
Fold expressions 在遍历时,还缺少一种处理错误的方式,看如下例子:
template<std::size_t... Is>
auto test_impl(std::size_t j, std::index_sequence<Is...>)
{
return ((j == Is ? (std::println("found"), true) : 0) || ...);
}
template<std::size_t N>
auto test(std::size_t j)
{
return test_impl(j, std::make_index_sequence<N>{});
}
int main() {
test<5>(5);
}
当遍历查找失败时,Fold expressions 无法处理错误。而 Ternary Right Fold Expression 就是解决这个问题的,代码变成:
template<std::size_t... Is>
auto test_impl(std::size_t j, std::index_sequence<Is...>)
{
return ((j == Is ? std::println("found")
: ... : throw std::range_error("Out of range"));
}
可以看到,这个特性可以消除 ||
和 ,
,进一步简化代码,并且在查找失败时,可以处理错误。
总结一下,它的语法格式是这样的:
( C ? E : ... : D )
展开变为:
( C(1) ? E(1) : ( ... ( C(N-1) ? E(N-1) : ( C(N) ? E(N) : D ) ) ) )
这个特性可能会入 C++26。
Conclusion
本章更为全面而深入地介绍了 Fold Expressions,这是元编程的编译期遍历方式,能够消除传统递归和循环,简化代码,减少重复。
Fold 这个概念从函数式编程而来,是一种抽象等级更高的遍历函数,因此不同于传统面向对象和面向过程范式中的递归和循环遍历,表达起来更加自然。示例代码依旧非常丰富,展示了各种高级技巧,这些技巧用于提供条件、中断、中间值等能力。
同时,Fold Expressions 也有不足之处,主要不足源于参数包特性和错误处理模块的缺失,这将在 C++26 进一步得到解决。