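This post grew out of a question from a member of a chat group. The requirement is a little unusual, but it makes a very good example of generative metaprogramming, so I am writing it down.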

The Original Problem

The original problem, after some cleanup, looks like this:

struct base {
    virtual void foo() = 0;
};

template<int I, int J, int K, int L>
struct derived : public base {
    void foo() override final {
        std::cout << "derived<" << I << ',' << J << ',' << K << ',' << L << ">::foo()\n";
    }
};

template<int I, int J, int K, int L>
base* create_instance_impl() {
    return new derived<I, J, K, L>();
}

base* create_instance(int i, int j, int k, int l) {
    // Error, i, j, k and l are not constant expressions.
    return create_instance_impl<i,j,k,l>();
}

int main() {
    create_instance(0,0,0,0)->foo();
    create_instance(0,1,2,3)->foo();
}

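In short, for some reason (a legacy project, perhaps, or an external dependency) a class was designed to take compile-time values as template parameters, yet the inputs are now only available at run time (read from a file, received over the network, and so on). If the existing class cannot be changed, the only way to call it is to convert those runtime values into compile-time values.

How can a runtime value be converted into a compile-time one?

One exists at run time and the other at compile time, so at first glance there seems to be no way. In essence, though, this is a value-to-type problem. I explored solutions to it systematically when implementing the Factory pattern, and the core idea is mapping: for every value, provide a one-to-one value-to-type correspondence.

There is a second problem hidden inside this one: the numbers implicitly form combinations. Suppose i/j/k/l can take $$M_i/M_j/M_k/M_l$$ distinct values respectively; then there are $$M_i \times M_j \times M_k \times M_l$$ combinations in total.

With a plain mapping, the implementation looks like this: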

base* create_instance(int i, int j, int k, int l) {
    if (i == 0 && j == 0 && k == 0 && l == 0)
        return create_instance_impl<0, 0, 0, 0>();
    else if (i == 0 && j == 0 && k == 0 && l == 1)
        return create_instance_impl<0, 0, 0, 1>();
    else if (i == 0 && j == 0 && k == 0 && l == 2)
        return create_instance_impl<0, 0, 0, 2>();
    // Add more cases for other (i, j, k, l) combinations as needed
    else
        return nullptr; // no mapping for this combination
}

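Clearly, there are far too many combinations to write out every mapping by hand. A key part of the problem is therefore to reduce the number of mappings.

Dimension Reduction: The Manual Solution

In mathematics, multiplication can be seen as a tool that lifts sparse information into higher dimensions, while addition is a tool that reduces dimensions. If $$M_i \times M_j \times M_k \times M_l$$ can be turned into $$M_i + M_j + M_k + M_l$$, the number of mappings drops dramatically.

How do we turn multiplication into addition?

The key is to convert the multidimensional structure into a linear one. Right now i/j/k/l are interwoven into a high-dimensional structure; we need to pull them apart and handle one dimension at a time.

The mapping under this dimension-reduction approach looks like this: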

template<int I, int J, int K>
base* create_instance_of_k(int l) {
    if (l == 0)
        return create_instance_impl<I, J, K, 0>();
    else if (l == 1)
        return create_instance_impl<I, J, K, 1>();
    else if (l == 2)
        return create_instance_impl<I, J, K, 2>();
    else if (l == 3)
        return create_instance_impl<I, J, K, 3>();
    else
        return nullptr;
}

template<int I, int J>
base* create_instance_of_j(int k, int l) {
    if (k == 0)
        return create_instance_of_k<I, J, 0>(l);
    else if (k == 1)
        return create_instance_of_k<I, J, 1>(l);
    else if (k == 2)
        return create_instance_of_k<I, J, 2>(l);
    else
        return nullptr;
}

template<int I>
base* create_instance_of_i(int j, int k, int l) {
    if (j == 0)
        return create_instance_of_j<I, 0>(k, l);
    else if (j == 1)
        return create_instance_of_j<I, 1>(k, l);
    else if (j == 2)
        return create_instance_of_j<I, 2>(k, l);
    else
        return nullptr;
}

base* create_instance(int i, int j, int k, int l) {
    if (i == 0)
        return create_instance_of_i<0>(j, k, l);
    else if (i == 1)
        return create_instance_of_i<1>(j, k, l);
    else if (i == 2)
        return create_instance_of_i<2>(j, k, l);
    else
        return nullptr;
}

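Here the maximum values of i/j/k/l are 2/2/2/3, so they can take 3/3/3/4 distinct values respectively. Before the reduction there are $$3 \times 3 \times 3 \times 4=108$$ mappings; after it there are only $$3 + 3 + 3 + 4=13$$.

Dimension reduction thus shrinks a mapping that could not be written out exhaustively to a size that can, which solves the problem.

Keep in mind, however, that dimension reduction only lowers the number of mappings written by hand; it does not change the number of combinations. Value-to-type must still be one-to-one, so 108 template specializations are still generated.

The problem size therefore needs an upper bound from the very start; otherwise template instantiations explode, and compile time explodes with them.

Dimension Reduction: The Automated Solution

In the manual version, adding mappings means writing a lot of code. The techniques from 《产生式元编程》 (Generative Metaprogramming) can generate the mappings automatically.

The core implementation is as follows: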
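The arithmetic behind the manual version can be checked at compile time. The following is a minimal sketch; `product_of`, `sum_of`, and `manual_counts` are hypothetical names introduced here for illustration and are not part of the original code.

```cpp
#include <array>
#include <cstddef>

// Hypothetical helper (not from the article): the number of combinations,
// i.e. how many derived<I, J, K, L> specializations end up instantiated.
template<std::size_t N>
constexpr std::size_t product_of(const std::array<std::size_t, N>& counts) {
    std::size_t result = 1;
    for (std::size_t c : counts) result *= c;
    return result;
}

// Hypothetical helper (not from the article): the number of branches that
// have to be written by hand after dimension reduction.
template<std::size_t N>
constexpr std::size_t sum_of(const std::array<std::size_t, N>& counts) {
    std::size_t result = 0;
    for (std::size_t c : counts) result += c;
    return result;
}

// In the manual version, i, j, k take 3 values each (0..2) and l takes 4 (0..3).
inline constexpr std::array<std::size_t, 4> manual_counts{3, 3, 3, 4};

static_assert(product_of(manual_counts) == 108); // instantiated combinations
static_assert(sum_of(manual_counts) == 13);      // hand-written branches
```

The product is what the compiler pays for, and the sum is what the programmer pays for. Raising every count to 10 keeps the hand-written part at 40 branches while the instantiation count grows to 10000.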

/**
 * Converts runtime values to compile-time constants and invokes the callback function `call`.
 * 
 * Template Parameters:
 * - ReturnType: The return type of the callback function `call`.
 * - T: The type of the runtime variables to be converted into compile-time constants.
 * - Is...: A pack of compile-time constants derived from the runtime values.
 * - F: The type of the callback function that will be invoked with the compile-time values.
 * - Args...: The types of the remaining runtime variables to be converted into compile-time constants.
 * - FirstLimit: The maximum value for the first runtime variable.
 * - Limits...: The maximum values for the remaining runtime variables.
 * 
 * Function Parameters:
 * - call: A callable object (e.g., lambda or function) that will be invoked with the compile-time constants.
 * - first: The first runtime variable, which will be converted to a compile-time constant.
 * - args...: The remaining runtime variables, each of which will also be converted into compile-time constants.
 * 
 * Constraints:
 * - The number of runtime variables (Args...) must equal the number of their respective maximum values (Limits...).
 */
template<class ReturnType, class T, T... Is, class F, class... Args, size_t FirstLimit, size_t... Limits>
    requires (sizeof...(Args) == sizeof...(Limits))
ReturnType to_compile_time_values_impl(F&& call, std::index_sequence<FirstLimit, Limits...>,
    T first, Args&&... args) {
    if constexpr (sizeof...(args) == 0) {
        return [first, call = std::forward<F>(call)]<auto... I>(std::index_sequence<I...>) {
            if constexpr (std::is_same_v<ReturnType, void>) {
                ((first == I ? (call(std::index_sequence<Is..., I>{}), true) : false) || ...);
            } else {
                ReturnType result{};
                ((first == I ? (result = call(std::index_sequence<Is..., I>{}), true) : false) || ...);
                return result;
            }
        }(std::make_index_sequence<FirstLimit>{});
    } else {
            if constexpr (std::is_same_v<ReturnType, void>) {
                [first, call = std::forward<F>(call), ...args = std::forward<Args>(args)]<auto... I>
                    (std::index_sequence<I...>) {
                    ((first == I ? (to_compile_time_values_impl<ReturnType, T, Is..., I>(std::move(call), std::index_sequence<Limits...>{}, args...), true) : false) || ...);
                }(std::make_index_sequence<FirstLimit>{});
            } else {
                return [first, call = std::forward<F>(call), ...args = std::forward<Args>(args)]<auto... I>
                    (std::index_sequence<I...>) {
                    ReturnType result{};
                    ((first == I ? (result = to_compile_time_values_impl<ReturnType, T, Is..., I>(std::move(call), std::index_sequence<Limits...>{}, args...), true) : false) || ...);
                    return result;
                }(std::make_index_sequence<FirstLimit>{});
            }
    }
}

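Besides meeting the requirement, the implementation also deals with the return value. call is a generic function or generic lambda, so its return type cannot be obtained before it is actually invoked and has to be passed explicitly as a template argument. A result variable of type void cannot be declared, so the void case is handled separately. Setting aside these accommodations for generality, the core logic is under 20 lines.

For convenience, the public interface also handles several cases: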
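To make the core logic easier to see, here is a stripped-down sketch that handles a single runtime value per call and ignores return values. It is not the article's implementation: `dispatch_one`, `encode`, and `encode_runtime` are hypothetical names, and the constant is passed as `std::integral_constant` rather than `std::index_sequence` purely to keep the example short.

```cpp
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>

// Hypothetical helper (not from the article): maps one runtime value in
// [0, Limit) to a compile-time constant and passes that constant to f.
// Returns true if a branch matched, false if value was out of range.
template<std::size_t Limit, class F>
bool dispatch_one(std::size_t value, F&& f) {
    return [&]<std::size_t... I>(std::index_sequence<I...>) {
        // Expands to (value == 0 ? ... : false) || (value == 1 ? ... : false) || ...
        // The || fold stops evaluating at the first index that matches.
        return ((value == I
                     ? (f(std::integral_constant<std::size_t, I>{}), true)
                     : false) || ...);
    }(std::make_index_sequence<Limit>{});
}

// A function that needs both of its arguments at compile time.
template<std::size_t I, std::size_t J>
constexpr std::size_t encode() { return I * 10 + J; }

// Two runtime values are handled by nesting two one-dimensional dispatches.
// The result stays empty when either value is out of range.
std::optional<std::size_t> encode_runtime(std::size_t i, std::size_t j) {
    std::optional<std::size_t> result;
    dispatch_one<3>(i, [&](auto ci) {
        dispatch_one<4>(j, [&](auto cj) {
            result = encode<decltype(ci)::value, decltype(cj)::value>();
        });
    });
    return result;
}
```

Each nesting level contributes its own short chain of comparisons (3 and 4 here), while the innermost lambda is instantiated once per combination (12 here). This is the same sum-versus-product trade-off as in the manual version.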

/**
 * Limits... represents the maximum values for the corresponding runtime variables.
 * If these values are too large, the number of combinations will increase significantly,
 * potentially causing a template instantiation explosion and greatly extending compilation time.
 */
template<class ReturnType, size_t... Limits, class F, class... Ints>
    requires(sizeof...(Limits) <= sizeof...(Ints))
ReturnType to_compile_time_values(F&& call, Ints&&... ints) {
    if constexpr (sizeof...(Limits) == sizeof...(Ints)) {
        return to_compile_time_values_impl<ReturnType>(std::forward<F>(call),
            std::index_sequence<Limits...>{},
            std::forward<Ints>(ints)...);
    } else if constexpr (sizeof...(Limits) == 1) {
        constexpr std::size_t V = std::get<0>(std::tuple(Limits...));
        return to_compile_time_values_impl<ReturnType>(std::forward<F>(call),
            make_repeat_index_sequence<sizeof...(Ints), V>{},
            std::forward<Ints>(ints)...);
    } else {
        return to_compile_time_values_impl<ReturnType>(std::forward<F>(call),
            make_repeat_index_sequence<sizeof...(Ints), 5>{},
            std::forward<Ints>(ints)...);
    }
}

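Limits... gives the upper bound of each variable. Each bound is exclusive: a limit of N accepts the values 0 through N - 1. If the bounds differ, specify each one; if they are all the same, a single value suffices; if none is specified, every bound defaults to 5.

As noted earlier, when Limits... has many entries or large values, the number of mapped combinations also becomes large, which causes template explosion and a sharp increase in compile time. "Sharp" is no exaggeration: with four variables, each with a limit of 10, there are $$10 \times 10 \times 10 \times 10$$ = 10000 combinations, and the compiler has to instantiate that many classes. The time cost is easy to imagine.

In addition, make_repeat_index_sequence generates an index sequence that repeats the same value a given number of times. It is implemented as: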

template<size_t N, size_t V, size_t... Vs>
struct make_repeat_index_sequence_impl
    : make_repeat_index_sequence_impl<N-1, V, V, Vs...>
{};

template<size_t V, size_t... Vs>
struct make_repeat_index_sequence_impl<1, V, Vs...>
    : std::type_identity<std::index_sequence<Vs..., V>>
{};

template<size_t N, size_t V>
using make_repeat_index_sequence = make_repeat_index_sequence_impl<N, V>::type;

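It is a simple helper and needs no lengthy explanation.

Examples: The Automated General Solution

The general implementation is fairly flexible and converts runtime values into compile-time values directly.

Start with a simple example: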

int i = 1;
int j = 2;
int k = 3;

// Output: 1 2 3
to_compile_time_values<void>([]<auto... Is>(std::index_sequence<Is...>) {
    ((std::cout << Is << " "), ...);
}, i, j, k);

For the original problem, the solution using this implementation is:

base* create_instance(int i, int j, int k, int l) {
    return to_compile_time_values<base*>([]<auto... Is>(std::index_sequence<Is...>) {
        return create_instance_impl<Is...>();
    }, i, j, k, l);
}

int main() {
    // Output: derived<0,0,0,0>::foo()
    create_instance(0,0,0,0)->foo();

    // Output: derived<0,1,2,3>::foo()
    create_instance(0,1,2,3)->foo();
}

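For better efficiency, specify the range of each runtime value. With the default limit of 5, four variables produce 5 × 5 × 5 × 5 = 625 combinations, whereas the limits below produce only 4 × 3 × 3 × 4 = 144: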

/**
 * Creates an instance of `base*` by converting the runtime values i, j, k, l
 * into compile-time constants.
 *
 * The limits (exclusive upper bounds) for the parameters are:
 *  - i: 4 (valid values 0..3)
 *  - j: 3 (valid values 0..2)
 *  - k: 3 (valid values 0..2)
 *  - l: 4 (valid values 0..3)
 *
 * These limits correspond to the template arguments of `to_compile_time_values<base*, 4, 3, 3, 4>`.
 * A value outside its range matches no branch, so the function returns nullptr.
 */
base* create_instance(int i, int j, int k, int l) {
    return to_compile_time_values<base*, 4, 3, 3, 4>([]<auto... Is>(std::index_sequence<Is...>) {
        return create_instance_impl<Is...>();
    }, i, j, k, l);
}

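If the values of i/j/k/l then exceed 3/2/2/3 (indices start at 0, hence the minus 1), no branch matches and the function returns the value-initialized result, which is nullptr here: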

int main() {
    // result is nullptr
    base* result = create_instance(0,1,2,4);
    if (result) {
        result->foo();
    }
}

If the limit for i/j/k/l is 4 in every case, there is no need to write them all out; one is enough:

base* create_instance(int i, int j, int k, int l) {
    return to_compile_time_values<base*, 4>([]<auto... Is>(std::index_sequence<Is...>) {
        return create_instance_impl<Is...>();
    }, i, j, k, l);
}

int main() {
    // OK
    create_instance(0,0,0,0)->foo();
    create_instance(0,1,2,3)->foo();
}

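This flexibility is exactly why the code grew: however small a feature is, complexity rises noticeably once generality is taken into account.

Summary

This post is something of a best-practice example. It uses a number of fairly advanced generative metaprogramming techniques, all of which were introduced in 《产生式元编程》 (Generative Metaprogramming), so the implementation details are not explained at length here.

Many techniques are involved, and the example is also an excellent illustration of template explosion: when the number of combinations is too large, the compiler simply times out.

The post also introduced the idea of using multiplication and addition to raise and lower dimensions, and it is this idea that gives the implementation its theoretical footing.

I hope this post shows a concrete application of generative metaprogramming and helps consolidate what you have learned so far.

Complete implementation and examples: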

// https://godbolt.org/z/47E6fbqGr
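#include <concepts>
#include <cstddef>      // size_t
#include <iostream>
#include <tuple>        // std::tuple, std::get
#include <type_traits>
#include <utility>      // std::index_sequence, std::forward, std::move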

struct base {
    virtual void foo() = 0;
};

template<int I, int J, int K, int L>
struct derived : public base {
    void foo() override final {
        std::cout << "derived<" << I << ',' << J << ',' << K << ',' << L << ">::foo()\n";
    }
};

template<int I, int J, int K, int L>
base* create_instance_impl() {
    return new derived<I, J, K, L>();
}

template<size_t N, size_t V, size_t... Vs>
struct make_repeat_index_sequence_impl
    : make_repeat_index_sequence_impl<N-1, V, V, Vs...>
{};

template<size_t V, size_t... Vs>
struct make_repeat_index_sequence_impl<1, V, Vs...>
    : std::type_identity<std::index_sequence<Vs..., V>>
{};

template<size_t N, size_t V>
using make_repeat_index_sequence = make_repeat_index_sequence_impl<N, V>::type;

/**
 * Converts runtime values to compile-time constants and invokes the callback function `call`.
 * 
 * Template Parameters:
 * - ReturnType: The return type of the callback function `call`.
 * - T: The type of the runtime variables to be converted into compile-time constants.
 * - Is...: A pack of compile-time constants derived from the runtime values.
 * - F: The type of the callback function that will be invoked with the compile-time values.
 * - Args...: The types of the remaining runtime variables to be converted into compile-time constants.
 * - FirstLimit: The maximum value for the first runtime variable.
 * - Limits...: The maximum values for the remaining runtime variables.
 * 
 * Function Parameters:
 * - call: A callable object (e.g., lambda or function) that will be invoked with the compile-time constants.
 * - first: The first runtime variable, which will be converted to a compile-time constant.
 * - args...: The remaining runtime variables, each of which will also be converted into compile-time constants.
 * 
 * Constraints:
 * - The number of runtime variables (Args...) must equal the number of their respective maximum values (Limits...).
 */
template<class ReturnType, class T, T... Is, class F, class... Args, size_t FirstLimit, size_t... Limits>
    requires (sizeof...(Args) == sizeof...(Limits))
ReturnType to_compile_time_values_impl(F&& call, std::index_sequence<FirstLimit, Limits...>,
    T first, Args&&... args) {
    if constexpr (sizeof...(args) == 0) {
        return [first, call = std::forward<F>(call)]<auto... I>(std::index_sequence<I...>) {
            if constexpr (std::is_same_v<ReturnType, void>) {
                ((first == I ? (call(std::index_sequence<Is..., I>{}), true) : false) || ...);
            } else {
                ReturnType result{};
                ((first == I ? (result = call(std::index_sequence<Is..., I>{}), true) : false) || ...);
                return result;
            }
        }(std::make_index_sequence<FirstLimit>{});
    } else {
            if constexpr (std::is_same_v<ReturnType, void>) {
                [first, call = std::forward<F>(call), ...args = std::forward<Args>(args)]<auto... I>
                    (std::index_sequence<I...>) {
                    ((first == I ? (to_compile_time_values_impl<ReturnType, T, Is..., I>(std::move(call), std::index_sequence<Limits...>{}, args...), true) : false) || ...);
                }(std::make_index_sequence<FirstLimit>{});
            } else {
                return [first, call = std::forward<F>(call), ...args = std::forward<Args>(args)]<auto... I>
                    (std::index_sequence<I...>) {
                    ReturnType result{};
                    ((first == I ? (result = to_compile_time_values_impl<ReturnType, T, Is..., I>(std::move(call), std::index_sequence<Limits...>{}, args...), true) : false) || ...);
                    return result;
                }(std::make_index_sequence<FirstLimit>{});
            }
    }
}

/**
 * Limits... represents the maximum values for the corresponding runtime variables.
 * If these values are too large, the number of combinations will increase significantly,
 * potentially causing a template instantiation explosion and greatly extending compilation time.
 */
template<class ReturnType, size_t... Limits, class F, class... Ints>
    requires(sizeof...(Limits) <= sizeof...(Ints))
ReturnType to_compile_time_values(F&& call, Ints&&... ints) {
    if constexpr (sizeof...(Limits) == sizeof...(Ints)) {
        return to_compile_time_values_impl<ReturnType>(std::forward<F>(call),
            std::index_sequence<Limits...>{},
            std::forward<Ints>(ints)...);
    } else if constexpr (sizeof...(Limits) == 1) {
        constexpr std::size_t V = std::get<0>(std::tuple(Limits...));
        return to_compile_time_values_impl<ReturnType>(std::forward<F>(call),
            make_repeat_index_sequence<sizeof...(Ints), V>{},
            std::forward<Ints>(ints)...);
    } else {
        return to_compile_time_values_impl<ReturnType>(std::forward<F>(call),
            make_repeat_index_sequence<sizeof...(Ints), 5>{},
            std::forward<Ints>(ints)...);
    }
}

/**
 * Creates an instance of `base*` by converting the runtime values i, j, k, l
 * into compile-time constants.
 *
 * The limits (exclusive upper bounds) for the parameters are:
 *  - i: 4 (valid values 0..3)
 *  - j: 3 (valid values 0..2)
 *  - k: 3 (valid values 0..2)
 *  - l: 4 (valid values 0..3)
 *
 * These limits correspond to the template arguments of `to_compile_time_values<base*, 4, 3, 3, 4>`.
 * A value outside its range matches no branch, so the function returns nullptr.
 */

base* create_instance(int i, int j, int k, int l) {
    return to_compile_time_values<base*, 4, 3, 3, 4>([]<auto... Is>(std::index_sequence<Is...>) {
        return create_instance_impl<Is...>();
    }, i, j, k, l);
}

int main() {
    create_instance(0,0,0,0)->foo();
    create_instance(0,1,2,3)->foo();

    base* result = create_instance(0,1,2,4);
    if (result) {
        result->foo();
    }

    int i = 1;
    int j = 2;
    int k = 3;
    to_compile_time_values<void>([]<auto... Is>(std::index_sequence<Is...>) {
        ((std::cout << Is << " "), ...);
    }, i, j, k);
}
