Clean Conversion from Run-time Values to Compile-time Constants in C++
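This post grew out of a question from a member of a chat group. The requirement is a little unusual, but it makes a very good example of generative metaprogramming, so I decided to write it up.
The Original Problem
A cleaned-up version of the original problem looks like this: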
#include <iostream>
struct base {
virtual void foo() = 0;
};
template<int I, int J, int K, int L>
struct derived : public base {
void foo() override final {
std::cout << "derived<" << I << ',' << J << ',' << K << ',' << L << ">::foo()\n";
}
};
template<int I, int J, int K, int L>
base* create_instance_impl() {
return new derived<I, J, K, L>();
}
base* create_instance(int i, int j, int k, int l) {
// Error, i, j, k and l are not constant expressions.
return create_instance_impl<i,j,k,l>();
}
int main() {
create_instance(0,0,0,0)->foo();
create_instance(0,1,2,3)->foo();
}
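In short, the original class design requires compile-time values as template arguments. This might be because of a legacy project or an external dependency. The inputs, however, are now only available at run time (read from a file, received over the network, ...). The original classes cannot be changed, so the only way to call them is to convert these run-time values into compile-time values.
How can a run-time value be converted into a compile-time one?
One exists at run time and the other at compile time, so at first glance this seems impossible. At its core, however, this is a value-to-type problem. I explored its solutions systematically when implementing the Factory pattern. The key idea is mapping: for every possible value, provide an explicit value-to-type correspondence.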
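For a single value, such a mapping is just a hand-written table. The sketch below assumes a small class family parameterised by one int. The names shape, polygon and make_polygon are hypothetical and do not appear in the original problem:

```cpp
#include <cassert>
#include <memory>

// Hypothetical class family parameterised by a compile-time int.
struct shape {
    virtual ~shape() = default;
    virtual int sides() const = 0;
};

template<int N>
struct polygon : shape {
    int sides() const override { return N; }
};

// Value-to-type mapping: each admissible run-time value is paired,
// by hand, with exactly one template instantiation.
std::unique_ptr<shape> make_polygon(int n) {
    switch (n) {
        case 3: return std::make_unique<polygon<3>>();
        case 4: return std::make_unique<polygon<4>>();
        case 5: return std::make_unique<polygon<5>>();
        default: return nullptr;  // no mapping exists for this value
    }
}
```

With one value, the table grows linearly. With several interdependent values, a table like this would need one entry per combination.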
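The problem also hides a second difficulty: the numbers combine with one another. Suppose i/j/k/l can take $$M_i/M_j/M_k/M_l$$ distinct values respectively. Then there are $$M_i \times M_j \times M_k \times M_l$$ possible combinations.
Mapping them the naive way gives: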
base* create_instance(int i, int j, int k, int l) {
if (i == 0 && j == 0 && k == 0 && l == 0)
return create_instance_impl<0, 0, 0, 0>();
else if (i == 0 && j == 0 && k == 0 && l == 1)
return create_instance_impl<0, 0, 0, 1>();
else if (i == 0 && j == 0 && k == 0 && l == 2)
return create_instance_impl<0, 0, 0, 2>();
// Add more cases for other (i, j, k, l) combinations as needed
return nullptr;  // no mapping for this combination
}
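There are clearly far too many combinations to write out every mapping. The key to the problem is therefore reducing the number of mappings.
Dimension Reduction: A Manual Solution
In mathematics, multiplication lifts sparse information into higher dimensions, whereas addition brings it back down. If $$M_i \times M_j \times M_k \times M_l$$ could be turned into $$M_i + M_j + M_k + M_l$$, the number of mappings would shrink dramatically.
So how do we turn the multiplication into an addition?
The key is to turn a multidimensional structure into a linear one. At present, i/j/k/l are intertwined in a high-dimensional structure. We need to pull them apart and bring them down from many dimensions to one.
Mapping with this dimension-reduction idea looks like this: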
template<int I, int J, int K>
base* create_instance_of_k(int l) {
if (l == 0)
return create_instance_impl<I, J, K, 0>();
else if (l == 1)
return create_instance_impl<I, J, K, 1>();
else if (l == 2)
return create_instance_impl<I, J, K, 2>();
else if (l == 3)
return create_instance_impl<I, J, K, 3>();
else
return nullptr;
}
template<int I, int J>
base* create_instance_of_j(int k, int l) {
if (k == 0)
return create_instance_of_k<I, J, 0>(l);
else if (k == 1)
return create_instance_of_k<I, J, 1>(l);
else if (k == 2)
return create_instance_of_k<I, J, 2>(l);
else
return nullptr;
}
template<int I>
base* create_instance_of_i(int j, int k, int l) {
if (j == 0)
return create_instance_of_j<I, 0>(k, l);
else if (j == 1)
return create_instance_of_j<I, 1>(k, l);
else if (j == 2)
return create_instance_of_j<I, 2>(k, l);
else
return nullptr;
}
base* create_instance(int i, int j, int k, int l) {
if (i == 0)
return create_instance_of_i<0>(j, k, l);
else if (i == 1)
return create_instance_of_i<1>(j, k, l);
else if (i == 2)
return create_instance_of_i<2>(j, k, l);
else
return nullptr;
}
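Here i, j and k each range over 0 to 2, and l ranges over 0 to 3, giving 3/3/3/4 possible values. Without dimension reduction there would be $$3 \times 3 \times 3 \times 4 = 108$$ mappings; with it, there are only $$3 + 3 + 3 + 4 = 13$$.
Dimension reduction thus shrinks an unmanageable set of mappings down to a number that can actually be written out, which solves the problem.
Be clear, though, that dimension reduction only reduces the number of mappings you have to write by hand. It does not change the number of combinations. Value-to-type must still be one-to-one, so 108 template instantiations are still generated.
The problem size must therefore be bounded from the start. Otherwise the number of template instantiations explodes, and compile time explodes with it.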
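The nested helpers in the manual version all follow one pattern. Each level turns one run-time value into a constant and passes the remaining values along. That pattern can be written once with recursion.

The following is a minimal C++20 sketch under our own assumptions: values are compared as non-negative std::size_t, and the callback returns void. The names dispatch and dispatch_values are hypothetical; this is not the article's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Base case: every run-time value has become a compile-time constant.
template<std::size_t... Fixed, class F>
bool dispatch(std::index_sequence<Fixed...> fixed, std::index_sequence<>, F& f) {
    f(fixed);
    return true;
}

// Recursive case: map the first run-time value v (admissible range [0, First))
// to a constant I, append I to Fixed, and recurse on the remaining values.
template<std::size_t... Fixed, std::size_t First, std::size_t... Rest, class F, class... Vals>
bool dispatch(std::index_sequence<Fixed...>, std::index_sequence<First, Rest...>, F& f,
              std::size_t v, Vals... vs) {
    return [&]<std::size_t... I>(std::index_sequence<I...>) {
        // Short-circuiting ||: evaluation stops at the first I equal to v.
        return ((v == I && dispatch(std::index_sequence<Fixed..., I>{},
                                    std::index_sequence<Rest...>{}, f, vs...)) || ...);
    }(std::make_index_sequence<First>{});
}

// Entry point: Limits... are exclusive upper bounds, one per value.
// Returns false, without calling f, if any value is out of range;
// negative inputs wrap to large std::size_t values and are rejected too.
template<std::size_t... Limits, class F, class... Vals>
bool dispatch_values(F&& f, Vals... vals) {
    static_assert(sizeof...(Limits) == sizeof...(Vals), "one limit per value");
    return dispatch(std::index_sequence<>{}, std::index_sequence<Limits...>{}, f,
                    static_cast<std::size_t>(vals)...);
}
```

In this design the source contains one mapping per dimension. The callback is still instantiated once per combination, for example 3 × 3 × 4 = 36 times for dispatch_values<3, 3, 4>.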
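Dimension Reduction: An Automatic Solution
Adding mappings to the manual version takes a lot of code. The techniques from 《产生式元编程》 (Generative Metaprogramming) can generate the mappings automatically instead.
The core implementation: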
/**
* Converts runtime values to compile-time constants and invokes the callback function `call`.
*
* Template Parameters:
* - ReturnType: The return type of the callback function `call`.
* - T: The type of the runtime variables to be converted into compile-time constants.
* - Is...: A pack of compile-time constants derived from the runtime values.
* - F: The type of the callback function that will be invoked with the compile-time values.
* - Args...: The types of the remaining runtime variables to be converted into compile-time constants.
* - FirstLimit: Exclusive upper bound for the first runtime variable (admissible values 0 .. FirstLimit-1).
* - Limits...: Exclusive upper bounds for the remaining runtime variables.
*
* Function Parameters:
* - call: A callable object (e.g., lambda or function) that will be invoked with the compile-time constants.
* - first: The first runtime variable, which will be converted to a compile-time constant.
* - args...: The remaining runtime variables, each of which will also be converted into compile-time constants.
*
* Constraints:
* - The number of runtime variables (Args...) must equal the number of their respective upper bounds (Limits...).
* - If any value is out of range, a value-initialized ReturnType is returned.
*/
template<class ReturnType, class T, T... Is, class F, class... Args, size_t FirstLimit, size_t... Limits>
requires (sizeof...(Args) == sizeof...(Limits))
ReturnType to_compile_time_values_impl(F&& call, std::index_sequence<FirstLimit, Limits...>,
T first, Args&&... args) {
if constexpr (sizeof...(args) == 0) {
return [first, call = std::forward<F>(call)]<auto... I>(std::index_sequence<I...>) {
if constexpr (std::is_same_v<ReturnType, void>) {
((first == I ? (call(std::index_sequence<Is..., I>{}), true) : false) || ...);
} else {
ReturnType result{};
((first == I ? (result = call(std::index_sequence<Is..., I>{}), true) : false) || ...);
return result;
}
}(std::make_index_sequence<FirstLimit>{});
} else {
if constexpr (std::is_same_v<ReturnType, void>) {
[first, call = std::forward<F>(call), ...args = std::forward<Args>(args)]<auto... I>
(std::index_sequence<I...>) {
((first == I ? (to_compile_time_values_impl<ReturnType, T, Is..., I>(std::move(call), std::index_sequence<Limits...>{}, args...), true) : false) || ...);
}(std::make_index_sequence<FirstLimit>{});
} else {
return [first, call = std::forward<F>(call), ...args = std::forward<Args>(args)]<auto... I>
(std::index_sequence<I...>) {
ReturnType result{};
((first == I ? (result = to_compile_time_values_impl<ReturnType, T, Is..., I>(std::move(call), std::index_sequence<Limits...>{}, args...), true) : false) || ...);
return result;
}(std::make_index_sequence<FirstLimit>{});
}
}
}
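Besides meeting the requirement, the implementation also handles return values. call is a generic function or generic lambda, so its return type cannot be known until it is actually invoked. The return type therefore has to be passed explicitly as a template argument. Because no result variable can be declared with type void, the two cases are handled separately.

If a value lies outside its bound, no branch of the fold matches, and a value-initialized ReturnType (nullptr, for a pointer) is returned. Set aside these special cases, which exist only for generality, and the core logic is fewer than 20 lines.

For convenience, the public interface also handles several cases: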
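The fold expression at the heart of the implementation is easiest to see with a single value. The sketch below is not the article's code, and dispatch_one is a hypothetical name.

Instead of passing the return type explicitly, it assumes the callback returns the same type for every index. That type can then be obtained with std::invoke_result_t by probing index 0. A missed value is reported through std::optional, or as false for void callbacks, rather than as a value-initialized result.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>

// Hypothetical single-value dispatcher. Assumption: N > 0 and f returns the
// same (non-reference) type for every index, so probing index 0 gives it.
template<std::size_t N, class F>
auto dispatch_one(std::size_t value, F&& f) {
    static_assert(N > 0, "need at least one admissible value");
    using R = std::invoke_result_t<F&, std::integral_constant<std::size_t, 0>>;
    return [&]<std::size_t... I>(std::index_sequence<I...>) {
        if constexpr (std::is_void_v<R>) {
            // || short-circuits, so f runs at most once; report whether it ran.
            return ((value == I ? (f(std::integral_constant<std::size_t, I>{}), true) : false) || ...);
        } else {
            std::optional<R> result;  // stays empty when value is out of range
            ((value == I ? (result.emplace(f(std::integral_constant<std::size_t, I>{})), true) : false) || ...);
            return result;
        }
    }(std::make_index_sequence<N>{});
}
```

Probing with index 0 only works when every branch yields the same type. That is why the article's version, which makes no such assumption, takes ReturnType explicitly.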
/**
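* Limits... are the exclusive upper bounds (numbers of admissible values) for the corresponding
* runtime variables: one per variable, a single bound shared by all, or none (each defaults to 5).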
* If these values are too large, the number of combinations will increase significantly,
* potentially causing a template instantiation explosion and greatly extending compilation time.
*/
template<class ReturnType, size_t... Limits, class F, class... Ints>
requires(sizeof...(Ints) > 0 && (sizeof...(Limits) == sizeof...(Ints) || sizeof...(Limits) <= 1))
ReturnType to_compile_time_values(F&& call, Ints&&... ints) {
if constexpr (sizeof...(Limits) == sizeof...(Ints)) {
return to_compile_time_values_impl<ReturnType>(std::forward<F>(call),
std::index_sequence<Limits...>{},
std::forward<Ints>(ints)...);
} else if constexpr (sizeof...(Limits) == 1) {
constexpr std::size_t V = std::get<0>(std::tuple(Limits...));
return to_compile_time_values_impl<ReturnType>(std::forward<F>(call),
make_repeat_index_sequence<sizeof...(Ints), V>{},
std::forward<Ints>(ints)...);
} else {
return to_compile_time_values_impl<ReturnType>(std::forward<F>(call),
make_repeat_index_sequence<sizeof...(Ints), 5>{},
std::forward<Ints>(ints)...);
}
}
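Limits... gives the upper bound for each variable. If the bounds differ, specify each one. If they are all the same, a single value suffices. If none is given, every bound defaults to 5. Each bound is exclusive: a bound of 5 admits the values 0 to 4.
As noted earlier, if Limits... are numerous or large, the number of mapped combinations grows large as well. This leads to template explosion and greatly increases compile time. "Greatly" is no exaggeration. Take four digits that can each take 10 values: there are $$10 \times 10 \times 10 \times 10 = 10000$$ combinations, and the compiler has to instantiate that many classes. The compile-time cost is easy to imagine.
In addition, make_repeat_index_sequence generates an index sequence that repeats the same value a given number of times. Its implementation: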
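Such an explosion can be caught before it happens by computing the number of combinations at compile time. A minimal sketch follows. The names combination_count, mapping_count and within_budget, and the budget of 1000, are hypothetical choices and not part of the article.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical budget: the most leaf instantiations accepted for one dispatch.
inline constexpr std::size_t max_instantiations = 1000;

// Leaf instantiations generated by a dispatch over the given exclusive bounds:
// the product of the bounds.
template<std::size_t... Limits>
inline constexpr std::size_t combination_count = (std::size_t{1} * ... * Limits);

// Mappings written per dimension after dimension reduction: the sum of the bounds.
template<std::size_t... Limits>
inline constexpr std::size_t mapping_count = (std::size_t{0} + ... + Limits);

template<std::size_t... Limits>
constexpr bool within_budget() {
    return combination_count<Limits...> <= max_instantiations;
}

// Example guard: bounds 4/3/3/4 give 144 instantiations, within the budget.
static_assert(within_budget<4, 3, 3, 4>(), "too many template instantiations");
```

Placing such a static_assert in a dispatch wrapper could turn an accidental explosion into an immediate, readable compile error.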
template<size_t N, size_t V, size_t... Vs>
struct make_repeat_index_sequence_impl
: make_repeat_index_sequence_impl<N-1, V, V, Vs...>
{};
template<size_t V, size_t... Vs>
struct make_repeat_index_sequence_impl<1, V, Vs...>
: std::type_identity<std::index_sequence<Vs..., V>>
{};
template<size_t N, size_t V>
using make_repeat_index_sequence = make_repeat_index_sequence_impl<N, V>::type;
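It is a simple helper that needs no further explanation.
Examples: Using the Automatic Generic Solution
The generic implementation is fairly flexible and converts run-time values into compile-time values directly.
Let's start with a simple example: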
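The recursive helper instantiates one class per repetition. A non-recursive variant can instead expand std::make_index_sequence<N> and replace each index with V.

This is an alternative sketch, not the article's helper. repeat_index_sequence is a hypothetical name, and unlike the recursive version it also accepts N = 0.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>

// Declaration only: used in an unevaluated context to compute a type.
// Each index I is discarded by the comma operator and replaced by V.
template<std::size_t V, std::size_t... I>
std::index_sequence<((void)I, V)...> repeat_impl(std::index_sequence<I...>);

// repeat_index_sequence<N, V> is std::index_sequence<V, V, ..., V> (N times).
template<std::size_t N, std::size_t V>
using repeat_index_sequence = decltype(repeat_impl<V>(std::make_index_sequence<N>{}));

static_assert(std::is_same_v<repeat_index_sequence<3, 5>, std::index_sequence<5, 5, 5>>);
static_assert(std::is_same_v<repeat_index_sequence<0, 7>, std::index_sequence<>>);
```

For N ≥ 1 both helpers produce the same type. Whether avoiding the N nested class instantiations measurably helps depends on the compiler.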
int i = 1;
int j = 2;
int k = 3;
// Output: 1 2 3
to_compile_time_values<void>([]<auto... Is>(std::index_sequence<Is...>) {
((std::cout << Is << " "), ...);
}, i, j, k);
With it, the original problem can be solved like this:
base* create_instance(int i, int j, int k, int l) {
return to_compile_time_values<base*>([]<auto... Is>(std::index_sequence<Is...>) {
return create_instance_impl<Is...>();
}, i, j, k, l);
}
int main() {
// Output: derived<0,0,0,0>::foo()
create_instance(0,0,0,0)->foo();
// Output: derived<0,1,2,3>::foo()
create_instance(0,1,2,3)->foo();
}
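To reduce the number of instantiations, and with it compile time, it is best to specify an upper bound for each run-time value, like this: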
/**
* Creates an instance of `base*` by converting the runtime values i, j, k, l
* into compile-time constants.
*
* The exclusive upper bounds for the parameters are:
* - i: 4 (values 0-3)
* - j: 3 (values 0-2)
* - k: 3 (values 0-2)
* - l: 4 (values 0-3)
*
* These limits correspond to the template arguments of `to_compile_time_values<base*, 4, 3, 3, 4>`.
* If any value is out of range, no mapping matches and nullptr is returned.
*/
base* create_instance(int i, int j, int k, int l) {
return to_compile_time_values<base*, 4, 3, 3, 4>([]<auto... Is>(std::index_sequence<Is...>) {
return create_instance_impl<Is...>();
}, i, j, k, l);
}
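If i/j/k/l exceed 3/2/2/3 (indices start at 0, so each bound is reduced by 1), no mapping matches. The function then returns a value-initialized result, which here is nullptr: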
int main() {
// result is nullptr
base* result = create_instance(0,1,2,4);
if (result) {
result->foo();
}
}
If i/j/k/l all share the same bound 4 (values 0 to 3), there is no need to write it out four times; a single value suffices:
base* create_instance(int i, int j, int k, int l) {
return to_compile_time_values<base*, 4>([]<auto... Is>(std::index_sequence<Is...>) {
return create_instance_impl<Is...>();
}, i, j, k, l);
}
int main() {
// OK
create_instance(0,0,0,0)->foo();
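create_instance(0,1,2,3)->foo();
}
This flexibility is why the code grew. However small a feature is, making it general raises its complexity noticeably.
Summary
This article is something of a best-practice case study. It uses a number of advanced generative metaprogramming techniques, all covered in 《产生式元编程》 (Generative Metaprogramming), so their implementation details are not explained at length here.
Many techniques are involved, and the exercise is also an excellent example of template explosion: with too many combinations, the compiler simply times out.
The article also introduced the idea of using multiplication and addition to raise and lower dimensionality. That idea provides the theoretical basis for the implementation.
I hope this article offers a concrete application of generative metaprogramming and helps consolidate what you have learned before.
Complete implementation and example: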
// https://godbolt.org/z/47E6fbqGr
#include <concepts>
#include <cstddef>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <utility>
struct base {
virtual void foo() = 0;
};
template<int I, int J, int K, int L>
struct derived : public base {
void foo() override final {
std::cout << "derived<" << I << ',' << J << ',' << K << ',' << L << ">::foo()\n";
}
};
template<int I, int J, int K, int L>
base* create_instance_impl() {
return new derived<I, J, K, L>();
}
template<size_t N, size_t V, size_t... Vs>
struct make_repeat_index_sequence_impl
: make_repeat_index_sequence_impl<N-1, V, V, Vs...>
{};
template<size_t V, size_t... Vs>
struct make_repeat_index_sequence_impl<1, V, Vs...>
: std::type_identity<std::index_sequence<Vs..., V>>
{};
template<size_t N, size_t V>
using make_repeat_index_sequence = make_repeat_index_sequence_impl<N, V>::type;
/**
* Converts runtime values to compile-time constants and invokes the callback function `call`.
*
* Template Parameters:
* - ReturnType: The return type of the callback function `call`.
* - T: The type of the runtime variables to be converted into compile-time constants.
* - Is...: A pack of compile-time constants derived from the runtime values.
* - F: The type of the callback function that will be invoked with the compile-time values.
* - Args...: The types of the remaining runtime variables to be converted into compile-time constants.
* - FirstLimit: Exclusive upper bound for the first runtime variable (admissible values 0 .. FirstLimit-1).
* - Limits...: Exclusive upper bounds for the remaining runtime variables.
*
* Function Parameters:
* - call: A callable object (e.g., lambda or function) that will be invoked with the compile-time constants.
* - first: The first runtime variable, which will be converted to a compile-time constant.
* - args...: The remaining runtime variables, each of which will also be converted into compile-time constants.
*
* Constraints:
* - The number of runtime variables (Args...) must equal the number of their respective upper bounds (Limits...).
* - If any value is out of range, a value-initialized ReturnType is returned.
*/
template<class ReturnType, class T, T... Is, class F, class... Args, size_t FirstLimit, size_t... Limits>
requires (sizeof...(Args) == sizeof...(Limits))
ReturnType to_compile_time_values_impl(F&& call, std::index_sequence<FirstLimit, Limits...>,
T first, Args&&... args) {
if constexpr (sizeof...(args) == 0) {
return [first, call = std::forward<F>(call)]<auto... I>(std::index_sequence<I...>) {
if constexpr (std::is_same_v<ReturnType, void>) {
((first == I ? (call(std::index_sequence<Is..., I>{}), true) : false) || ...);
} else {
ReturnType result{};
((first == I ? (result = call(std::index_sequence<Is..., I>{}), true) : false) || ...);
return result;
}
}(std::make_index_sequence<FirstLimit>{});
} else {
if constexpr (std::is_same_v<ReturnType, void>) {
[first, call = std::forward<F>(call), ...args = std::forward<Args>(args)]<auto... I>
(std::index_sequence<I...>) {
((first == I ? (to_compile_time_values_impl<ReturnType, T, Is..., I>(std::move(call), std::index_sequence<Limits...>{}, args...), true) : false) || ...);
}(std::make_index_sequence<FirstLimit>{});
} else {
return [first, call = std::forward<F>(call), ...args = std::forward<Args>(args)]<auto... I>
(std::index_sequence<I...>) {
ReturnType result{};
((first == I ? (result = to_compile_time_values_impl<ReturnType, T, Is..., I>(std::move(call), std::index_sequence<Limits...>{}, args...), true) : false) || ...);
return result;
}(std::make_index_sequence<FirstLimit>{});
}
}
}
/**
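* Limits... are the exclusive upper bounds (numbers of admissible values) for the corresponding
* runtime variables: one per variable, a single bound shared by all, or none (each defaults to 5).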
* If these values are too large, the number of combinations will increase significantly,
* potentially causing a template instantiation explosion and greatly extending compilation time.
*/
template<class ReturnType, size_t... Limits, class F, class... Ints>
requires(sizeof...(Ints) > 0 && (sizeof...(Limits) == sizeof...(Ints) || sizeof...(Limits) <= 1))
ReturnType to_compile_time_values(F&& call, Ints&&... ints) {
if constexpr (sizeof...(Limits) == sizeof...(Ints)) {
return to_compile_time_values_impl<ReturnType>(std::forward<F>(call),
std::index_sequence<Limits...>{},
std::forward<Ints>(ints)...);
} else if constexpr (sizeof...(Limits) == 1) {
constexpr std::size_t V = std::get<0>(std::tuple(Limits...));
return to_compile_time_values_impl<ReturnType>(std::forward<F>(call),
make_repeat_index_sequence<sizeof...(Ints), V>{},
std::forward<Ints>(ints)...);
} else {
return to_compile_time_values_impl<ReturnType>(std::forward<F>(call),
make_repeat_index_sequence<sizeof...(Ints), 5>{},
std::forward<Ints>(ints)...);
}
}
/**
* Creates an instance of `base*` by converting the runtime values i, j, k, l
* into compile-time constants.
*
* The exclusive upper bounds for the parameters are:
* - i: 4 (values 0-3)
* - j: 3 (values 0-2)
* - k: 3 (values 0-2)
* - l: 4 (values 0-3)
*
* These limits correspond to the template arguments of `to_compile_time_values<base*, 4, 3, 3, 4>`.
* If any value is out of range, no mapping matches and nullptr is returned.
*/
base* create_instance(int i, int j, int k, int l) {
return to_compile_time_values<base*, 4, 3, 3, 4>([]<auto... Is>(std::index_sequence<Is...>) {
return create_instance_impl<Is...>();
}, i, j, k, l);
}
int main() {
create_instance(0,0,0,0)->foo();
create_instance(0,1,2,3)->foo();
base* result = create_instance(0,1,2,4);
if (result) {
result->foo();
}
int i = 1;
int j = 2;
int k = 3;
to_compile_time_values<void>([]<auto... Is>(std::index_sequence<Is...>) {
((std::cout << Is << " "), ...);
}, i, j, k);
}