I have two TMP versions. Which one is better depends on the data types, I guess:
Solution A:
First, let's find a good offset for the split point (powers of two seem great):
// Compile-time split point for a range of `diff` elements.
//
// Starting from V (2 by default), the candidate is doubled for as long as
// twice the candidate is still smaller than diff - 1. With the default V this
// yields the largest power of two below diff - 1, but never less than 2 and
// never more than 1 << 16.
template<std::ptrdiff_t diff, std::ptrdiff_t V = 2>
struct offset
{
private:
    static constexpr std::ptrdiff_t doubled = V * 2;
    static constexpr bool keep_doubling = doubled < diff - 1;

public:
    static constexpr std::ptrdiff_t value =
        keep_doubling ? offset<diff, doubled>::value : V;
};

// Recursion terminator: the conditional above names offset<diff, 2 * V> even
// when that arm is not selected, so the chain needs a fixed upper end.
template<std::ptrdiff_t diff>
struct offset<diff, (1 << 16)>
{
    static constexpr std::ptrdiff_t value = 1 << 16;
};

// Ranges of 0, 1 or 2 elements are never split: their offset is 0.
template<>
struct offset<0, 2> { static constexpr std::ptrdiff_t value = 0; };

template<>
struct offset<1, 2> { static constexpr std::ptrdiff_t value = 0; };

template<>
struct offset<2, 2> { static constexpr std::ptrdiff_t value = 0; };
With this, we can create a recursive TMP version:
// Folds the range [begin, end) with `op` as a binary tree whose shape is fixed
// at compile time by `diff`, the number of elements in the range.
//
// diff  - element count; must equal end - begin (checked with assert).
// begin - start of the range; iterator arithmetic (begin + n, end - begin)
//         must be supported.
// end   - one past the last element.
// op    - binary operation; it is applied both to elements and to partial
//         results, and operand order is preserved (left before right).
//
// Returns the folded value. An empty range is a precondition violation: it
// asserts and, when asserts are compiled out, yields a value-initialized
// result.
template <std::ptrdiff_t diff, class It, class Func>
auto binary_fold_tmp(It begin, It end, Func op)
    -> decltype(op(*begin, *end))
{
    using result_type = decltype(op(*begin, *end));

    assert(end - begin == diff);
    // `diff` is a compile-time constant, so only one branch can ever run; the
    // other branches merely have to compile.
    switch (diff)
    {
    case 0:
        assert(false);
        // Never reached for valid input. Value-initialization (instead of a
        // literal 0) keeps this branch well-formed for result types that are
        // not convertible from int.
        return result_type();
    case 1:
        return *begin;
    case 2:
        return op(*begin, *(begin + 1));
    default:
    {
        // Split at the power-of-two offset chosen for this length, fold both
        // halves recursively, then combine them.
        constexpr std::ptrdiff_t split = offset<diff>::value;
        const It mid = begin + split;
        auto left = binary_fold_tmp<split>(begin, mid, op);
        auto right = binary_fold_tmp<diff - split>(mid, end, op);
        return op(left, right);
    }
    }
}
This can be combined with a non-TMP version like this, for instance:
// Runtime entry point: folds the non-empty range [begin, end) with `op` by
// dispatching to the compile-time-sized binary_fold_tmp.
//
// Lengths 1..8, 16 and 32 are handled by a single fixed-size fold. Any other
// length peels off a leading chunk of 8, 16 or 32 elements and recurses on the
// remainder, which is guaranteed to be non-empty.
//
// begin, end - the range; iterator arithmetic must be supported and
//              end - begin must be positive (checked with assert).
// op         - binary operation applied to elements and partial results.
template <class It, class Func>
auto binary_fold(It begin, It end, Func op)
    -> decltype(op(*begin, *end))
{
    const auto diff = end - begin;
    assert(diff > 0);
    switch (diff)
    {
    case 1:
        return binary_fold_tmp<1>(begin, end, op);
    case 2:
        return binary_fold_tmp<2>(begin, end, op);
    case 3:
        return binary_fold_tmp<3>(begin, end, op);
    case 4:
        return binary_fold_tmp<4>(begin, end, op);
    case 5:
        return binary_fold_tmp<5>(begin, end, op);
    case 6:
        return binary_fold_tmp<6>(begin, end, op);
    case 7:
        return binary_fold_tmp<7>(begin, end, op);
    case 8:
        return binary_fold_tmp<8>(begin, end, op);
    // Exactly 16 or 32 elements must not go through the chunk-plus-remainder
    // path below: the remainder would be empty and the recursive call would
    // violate the diff > 0 precondition.
    case 16:
        return binary_fold_tmp<16>(begin, end, op);
    case 32:
        return binary_fold_tmp<32>(begin, end, op);
    default:
        if (diff < 16) // 9..15 elements: 8 + (1..7)
            return op(binary_fold_tmp<8>(begin, begin + 8, op),
                      binary_fold(begin + 8, end, op));
        else if (diff < 32) // 17..31 elements: 16 + (1..15)
            return op(binary_fold_tmp<16>(begin, begin + 16, op),
                      binary_fold(begin + 16, end, op));
        else // 33 or more elements: 32 + (1..)
            return op(binary_fold_tmp<32>(begin, begin + 32, op),
                      binary_fold(begin + 32, end, op));
    }
}
Solution B:
This calculates the pair-wise results, stores them in a buffer, and then calls itself with the buffer:
// One level of the pair-wise fold.
//
// Combines neighbouring elements (0,1), (2,3), ... of the `diff` elements
// starting at `begin` into a local buffer of diff / 2 partial results, then
// folds that buffer the same way until a single value is left.
//
// diff  - number of elements read from `begin`; must be a power of two >= 2.
// begin - start of the input; begin + n must be supported.
// end   - only used to deduce the return type; the element count comes from
//         `diff`.
// op    - binary operation; its results are stored as the (decayed) element
//         type of the input.
// Is... - the indices 0 .. diff / 2 - 1, one per pair.
template <std::ptrdiff_t diff, class It, class Func, std::size_t... Is>
auto binary_fold_pairs_impl(It begin,
                            It end,
                            Func op,
                            const std::index_sequence<Is...>&)
    -> decltype(op(*begin, *end))
{
    static_assert(diff >= 2 && (diff & (diff - 1)) == 0,
                  "binary_fold_pairs requires diff to be a power of two >= 2");
    (void)end;

    using value_type = std::decay_t<decltype(*begin)>;
    constexpr std::ptrdiff_t half = diff / 2;

    value_type pairs[half] = {
        op(*(begin + 2 * Is), *(begin + 2 * Is + 1))...};
    if (diff == 2)
        return pairs[0];

    // The call below is still instantiated when diff == 2, although it can
    // never run in that case. Clamping the next size to 2 keeps that dead
    // instantiation well-formed: without it the helper would be instantiated
    // with diff 1 and 0 and would declare zero-length arrays.
    constexpr std::ptrdiff_t next = (half >= 2) ? half : 2;
    return binary_fold_pairs_impl<next>(
        &pairs[0],
        &pairs[0] + half,
        op,
        std::make_index_sequence<next / 2>{});
}

// Folds exactly `diff` elements starting at `begin` by repeatedly combining
// neighbouring pairs. `diff` must be a power of two >= 2 (enforced at compile
// time by the helper). `end` only takes part in deducing the return type.
template <std::ptrdiff_t diff, class It, class Func>
auto binary_fold_pairs(It begin, It end, Func op) -> decltype(op(*begin, *end))
{
    return binary_fold_pairs_impl<diff>(
        begin, end, op, std::make_index_sequence<diff / 2>{});
}
This template function requires `diff` to be a power of 2.
But of course, you can again combine it with a non-template version:
// Runtime entry point for the pair-wise fold: folds the non-empty range
// [begin, end) with `op`.
//
// binary_fold_pairs only handles power-of-two lengths, so other lengths are
// decomposed into power-of-two chunks (plus single trailing elements) whose
// results are combined left to right.
//
// begin, end - the range; iterator arithmetic must be supported and
//              end - begin must be positive (checked with assert).
// op         - binary operation applied to elements and partial results.
template <class It, class Func>
auto binary_fold_mix(It begin, It end, Func op) -> decltype(op(*begin, *end))
{
    const auto diff = end - begin;
    assert(diff > 0);
    switch (diff)
    {
    case 1:
        return *begin;
    case 2:
        return binary_fold_pairs<2>(begin, end, op);
    case 3: // 2 + 1
        return op(binary_fold_pairs<2>(begin, begin + 2, op),
                  *(begin + 2));
    case 4:
        return binary_fold_pairs<4>(begin, end, op);
    case 5: // 4 + 1
        return op(binary_fold_pairs<4>(begin, begin + 4, op),
                  *(begin + 4));
    case 6: // 4 + 2: the tail has only two elements, so it must be folded
            // with size 2; size 4 would read past the end of the range.
        return op(binary_fold_pairs<4>(begin, begin + 4, op),
                  binary_fold_pairs<2>(begin + 4, begin + 6, op));
    case 7: // 4 + 3
        return op(binary_fold_pairs<4>(begin, begin + 4, op),
                  binary_fold_mix(begin + 4, begin + 7, op));
    case 8:
        return binary_fold_pairs<8>(begin, end, op);
    default:
        if (diff <= 16) // 9..16 elements: 8 + (1..8)
            return op(binary_fold_pairs<8>(begin, begin + 8, op),
                      binary_fold_mix(begin + 8, end, op));
        else if (diff <= 32) // 17..32 elements: 16 + (1..16)
            return op(binary_fold_pairs<16>(begin, begin + 16, op),
                      binary_fold_mix(begin + 16, end, op));
        else // 33 or more elements: 32 + (1..)
            return op(binary_fold_pairs<32>(begin, begin + 32, op),
                      binary_fold_mix(begin + 32, end, op));
    }
}
I measured with the same program as MtRoad. On my machine, the differences are not as big as MtRoad reported. With `-O3`,
solutions A and B seem to be slightly faster than MtRoad's version, but in reality, you need to test with your types and data.
Remark: I did not test my versions too rigorously.