How to define a class with overloaded methods for each std::variant alternative?

I have several std::variant types, each with several alternatives. I would like to define a visitor class template that takes a variant as its template parameter and automatically declares a pure virtual void operator()(T const&) const for each alternative T in that variant. Subclasses that inherit from an instantiation of this template would then be forced to override every one of these methods.

e.g.

#include <variant>

using VarA = std::variant<A1, A2, /* ... more alternatives ... */>;
using VarB = std::variant<B1, B2, /* ... more alternatives ... */>;

struct VarAVisitor : Visitor<VarA>
{
    // Must override 'void operator()(T const&) const' for each alternative type 'T' in VarA
};

struct VarBVisitor : Visitor<VarB>
{
    // Must override 'void operator()(T const&) const' for each alternative type 'T' in VarB
};

In short, how would I implement the Visitor class template in the example above?

Accessary answered 5/6, 2022 at 18:52 Comment(3)
Is there a reason why you are designing a custom visitor instead of using std::visit()? Then you could use the overload trick to execute a lambda for each alternative. - Keratitis
The visitor will be passed as the first argument to std::visit(). I want these visitor subclasses to be forced to override the pure virtual methods of their base class templates, so that the compiler produces more readable error messages if I forget a particular alternative type. - Accessary
What you are looking for is probably a variation of the "overload" variant trick. My template kung-fu is weak when it comes to parameter packs, but I don't think I've ever seen any syntax that lets you pass a variant type as a visitor's template argument and break it apart into its individual alternative types, so that secondary templates can generate the operator() overloads you are looking for. It only works if you pass the alternatives themselves directly as the visitor's template arguments. - Keratitis
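
For readers unfamiliar with the "overload" trick mentioned in the comments, here is a minimal sketch of it. The Circle and Square types, the Shape alias, and the name_of function are hypothetical names chosen only for illustration; they do not appear in the question.

```cpp
#include <string>
#include <variant>

// Hypothetical alternatives, used only for illustration.
struct Circle { double radius; };
struct Square { double side; };

using Shape = std::variant<Circle, Square>;

// The "overload" helper: inherits operator() from every callable it is built from.
template <typename... Fs>
struct overloaded : Fs... { using Fs::operator()...; };

// Deduction guide; required in C++17, no longer needed in C++20.
template <typename... Fs>
overloaded(Fs...) -> overloaded<Fs...>;

// Visits the variant with one lambda per alternative.
std::string name_of(Shape const& shape)
{
    return std::visit(overloaded{
        [](Circle const&) { return std::string{"circle"}; },
        [](Square const&) { return std::string{"square"}; },
    }, shape);
}
```

If one of the lambdas is left out, std::visit() fails to compile, but the diagnostic tends to point deep into the standard library rather than at the missing alternative, which is the readability concern raised in the second comment.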

After some googling and lots of trial and error, I managed to come up with something that does what I want. I'm sharing the solution here for anyone else who comes across the same issue.

Here is a proof of concept.

#include <iostream>
#include <variant>


template <typename> class Test { };

using Foo = std::variant<
    Test<struct A>,
    Test<struct B>,
    Test<struct C>,
    Test<struct D>
    >;

using Bar = std::variant<
    Test<struct E>,
    Test<struct F>,
    Test<struct G>,
    Test<struct H>,
    Test<struct I>,
    Test<struct J>,
    Test<struct K>,
    Test<struct L>
    >;


template <typename T>
struct DefineVirtualFunctor
{
    virtual int operator()(T const&) const = 0;
};

template <template <typename> typename Modifier, typename... Rest>
struct ForEach { };
template <template <typename> typename Modifier, typename T, typename... Rest>
struct ForEach<Modifier, T, Rest...> : Modifier<T>, ForEach<Modifier, Rest...> { };

template <typename Variant>
struct Visitor;
template <typename... Alts>
struct Visitor<std::variant<Alts...>> : ForEach<DefineVirtualFunctor, Alts...> { };


struct FooVisitor final : Visitor<Foo>
{
    int operator()(Test<A> const&) const override { return  0; }
    int operator()(Test<B> const&) const override { return  1; }
    int operator()(Test<C> const&) const override { return  2; }
    int operator()(Test<D> const&) const override { return  3; }
};

struct BarVisitor final : Visitor<Bar>
{
    int operator()(Test<E> const&) const override { return  4; }
    int operator()(Test<F> const&) const override { return  5; }
    int operator()(Test<G> const&) const override { return  6; }
    int operator()(Test<H> const&) const override { return  7; }
    int operator()(Test<I> const&) const override { return  8; }
    int operator()(Test<J> const&) const override { return  9; }
    int operator()(Test<K> const&) const override { return 10; }
    int operator()(Test<L> const&) const override { return 11; }
};


int main(int argc, char const* argv[])
{
    Foo foo;
    Bar bar;
    
    switch (argc) {
    case  0: foo = Foo{ std::in_place_index<0> }; break;
    case  1: foo = Foo{ std::in_place_index<1> }; break;
    case  2: foo = Foo{ std::in_place_index<2> }; break;
    default: foo = Foo{ std::in_place_index<3> }; break;
    }
    switch (argc) {
    case  0: bar = Bar{ std::in_place_index<0> }; break;
    case  1: bar = Bar{ std::in_place_index<1> }; break;
    case  2: bar = Bar{ std::in_place_index<2> }; break;
    case  3: bar = Bar{ std::in_place_index<3> }; break;
    case  4: bar = Bar{ std::in_place_index<4> }; break;
    case  5: bar = Bar{ std::in_place_index<5> }; break;
    case  6: bar = Bar{ std::in_place_index<6> }; break;
    default: bar = Bar{ std::in_place_index<7> }; break;
    }
    
    std::cout << std::visit(FooVisitor{ }, foo) << "\n";
    std::cout << std::visit(BarVisitor{ }, bar) << "\n";

    return 0;
}

As you can see, the Visitor class template accepts a std::variant type as its template parameter and derives from it an interface with one pure virtual operator() per alternative. Any child class that inherits from an instantiation of the template must implement that interface. If a child class fails to override one of the pure virtual methods, you will get an error like the following.

$ g++ -std=c++17 -o example example.cc
example.cc: In function ‘int main(int, const char**)’:
example.cc:87:41: error: invalid cast to abstract class type ‘BarVisitor’
   87 |     std::cout << std::visit(BarVisitor{ }, bar) << "\n";
      |                                         ^
example.cc:51:8: note:   because the following virtual functions are pure within ‘BarVisitor’:
   51 | struct BarVisitor final : Visitor<Bar>
      |        ^~~~~~~~~~
example.cc:29:17: note:     ‘int DefineVirtualFunctor<T>::operator()(const T&) const [with T = Test<J>]’
   29 |     virtual int operator()(T const&) const = 0;
      |                 ^~~~~~~~

This is much easier to understand than the error messages the compiler usually generates when a visitor passed to std::visit() is missing an overload for one of the alternatives.
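
As a possible variation, the recursive ForEach helper can be replaced by a C++17 pack expansion in the base-class list. The sketch below also adds a using-declaration that merges the inherited overloads into a single overload set. As far as I can tell, without it a call made directly through a Visitor<...> reference is ambiguous, because operator() is found in several distinct base classes. The Cat, Dog, Pet, PetVisitor, and visit_pet names are hypothetical and used only for illustration.

```cpp
#include <variant>

// Hypothetical alternatives, used only for illustration.
struct Cat { };
struct Dog { };
using Pet = std::variant<Cat, Dog>;

// One pure virtual call operator per alternative, as in the proof of concept.
template <typename T>
struct DefineVirtualFunctor
{
    virtual int operator()(T const&) const = 0;
};

template <typename Variant>
struct Visitor;

// The pack expansion takes the place of the recursive ForEach helper.
// The using-declaration brings every inherited operator() into one
// overload set, so overload resolution works on Visitor<...> itself.
template <typename... Alts>
struct Visitor<std::variant<Alts...>> : DefineVirtualFunctor<Alts>...
{
    using DefineVirtualFunctor<Alts>::operator()...;
};

struct PetVisitor final : Visitor<Pet>
{
    int operator()(Cat const&) const override { return 1; }
    int operator()(Dog const&) const override { return 2; }
};

// Accepts any concrete visitor for Pet through the abstract base class.
int visit_pet(Visitor<Pet> const& visitor, Pet const& pet)
{
    return std::visit(visitor, pet);
}
```

With this shape, code that only knows about Visitor<Pet> can still dispatch to whichever concrete visitor it is handed, and forgetting an override in PetVisitor still leaves the class abstract, so the "pure within" diagnostic is preserved.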

Accessary answered 5/6, 2022 at 21:3 Comment(0)