Inheritance and std::shared_ptr in Cython

Suppose I have the following simple example of C++ inheritance in file.h:

class Base {};
class Derived : public Base {};

Then, the following code compiles; that is, I can assign std::shared_ptr<Derived> to std::shared_ptr<Base>:

Derived* foo = new Derived();
std::shared_ptr<Derived> shared_foo = std::make_shared<Derived>(*foo);
std::shared_ptr<Base> bar = shared_foo;

Let's also say I've added the types to a decl.pxd:

cdef extern from "file.h":
    cdef cppclass Base:
        pass
    cdef cppclass Derived(Base):
        pass

Then, what I'm trying to do is mimic the above C++ assignment in Cython in a file.pyx:

cimport decl
from libcpp.memory cimport make_shared, shared_ptr

def do_stuff():
    cdef decl.Derived* foo = new decl.Derived()
    cdef shared_ptr[decl.Derived] shared_foo = make_shared[decl.Derived](foo[0])
    cdef shared_ptr[decl.Base] bar = shared_foo

Unlike the C++ case, this now fails with the following error (using Cython 3.0a6):

cdef shared_ptr[decl.Base] bar = shared_foo
                                ^
---------------------------------------------------------------

 Cannot assign type 'shared_ptr[Derived]' to 'shared_ptr[Base]'

Should I expect this behavior? Is there any way to mimic what the C++ example does with Cython?

Edit: As noted in the comments on the accepted answer below, the relevant functionality has been added to Cython and is available as of version 3.0a7.

Snobbery asked 20/5, 2021 at 18:39


It should work for Cython >= 3.0 (available since 3.0a7), because @fuglede made this PR fixing the issue described below. The issue is still present for Cython < 3.0.


The issue is that Cython's wrapper of std::shared_ptr lacks the converting assignment operator

template <class U> shared_ptr& operator= (const shared_ptr<U>& x) noexcept;

of the std::shared_ptr class.

If you patch the wrapper like this:

cdef extern from "<memory>" namespace "std" nogil:
    cdef cppclass shared_ptr[T]:
        ...
        shared_ptr[T]& operator=[Y](const shared_ptr[Y]& ptr)
        # a templated constructor, shared_ptr[Y](shared_ptr[Y]&), isn't accepted by Cython

your code will compile.

You might ask why operator= is needed here rather than the converting constructor shared_ptr[Y], since

...
cdef shared_ptr[decl.Base] bar = shared_foo

looks like copy-initialization through the non-explicit converting constructor (template <class U> shared_ptr (const shared_ptr<U>& x) noexcept;). But this is one of Cython's quirks with C++: the above code is translated to

std::shared_ptr<Base> __pyx_v_bar;
...
__pyx_v_bar = __pyx_v_shared_foo;

and not

std::shared_ptr<Base> __pyx_v_bar = __pyx_v_shared_foo;

Thus Cython checks for the existence of operator= (luckily for us, since Cython does not seem to support templated constructors, but does support templated operators).
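The difference between the two translations can be reproduced in plain C++. This is a minimal sketch assuming the trivial `Base`/`Derived` pair from file.h; the function names `assign_like_cython` and `init_like_cpp` are hypothetical and only mimic the shape of the generated code, they are not actual Cython output. Both compile, because `std::shared_ptr` provides both the converting constructor and the converting `operator=`; as explained above, Cython's own type check only consults the declared `operator=` for the first form.

```cpp
#include <cassert>
#include <memory>

// Trivial hierarchy mirroring file.h.
class Base {};
class Derived : public Base {};

// Shape of the code Cython emits for `cdef shared_ptr[Base] bar = shared_foo`:
// default construction, then the templated converting assignment
// template <class U> shared_ptr& operator=(const shared_ptr<U>&).
std::shared_ptr<Base> assign_like_cython(const std::shared_ptr<Derived>& shared_foo) {
    std::shared_ptr<Base> bar;
    bar = shared_foo;
    return bar;
}

// Single-statement copy-initialization, which instead uses the non-explicit
// converting constructor template <class U> shared_ptr(const shared_ptr<U>&).
std::shared_ptr<Base> init_like_cpp(const std::shared_ptr<Derived>& shared_foo) {
    std::shared_ptr<Base> bar = shared_foo;
    return bar;
}
```

Calling both on the same `std::make_shared<Derived>()` result gives two `shared_ptr<Base>` handles that point at, and share ownership of, the same `Derived` object.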


If you want to distribute your module to systems without a patched memory.pxd as well, you have two options:

  1. wrap std::shared_ptr correctly yourself
  2. write a small utility function, for example:
%%cython
...
cdef extern from *:
    """
    #include <memory>
    template<typename T1, typename T2>
    void assign_shared_ptr(std::shared_ptr<T1>& lhs, const std::shared_ptr<T2>& rhs){
         lhs = rhs;
    }
    """
    # Cython accepts any T1, T2 here; the C++ compiler checks that lhs = rhs is valid
    void assign_shared_ptr[T1, T2](shared_ptr[T1]& lhs, shared_ptr[T2]& rhs)

...
cdef shared_ptr[Derived] shared_foo
# cdef shared_ptr[Base] bar = shared_foo
# must be replaced with:
cdef shared_ptr[Base] bar
assign_shared_ptr(bar, shared_foo)
...

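Both options have drawbacks (option 1 means maintaining your own copy of the declarations, option 2 changes the syntax of every such assignment), so depending on your scenario you might prefer one or the other.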
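The C++ half of option 2 can be exercised on its own. This is a minimal sketch assuming the trivial `Base`/`Derived` hierarchy from file.h; the helper is the same `assign_shared_ptr` template as in the verbatim block. Like any `shared_ptr` assignment, calling it releases whatever `lhs` owned before and makes `lhs` share ownership with `rhs`.

```cpp
#include <cassert>
#include <memory>

// Trivial hierarchy mirroring file.h.
class Base {};
class Derived : public Base {};

// Same helper as in the verbatim block: it forwards to std::shared_ptr's
// converting operator=, so the C++ compiler (not Cython) decides whether
// a shared_ptr<T2> may be assigned to a shared_ptr<T1>.
template<typename T1, typename T2>
void assign_shared_ptr(std::shared_ptr<T1>& lhs, const std::shared_ptr<T2>& rhs) {
    lhs = rhs;
}
```

For example, if `bar` is the last owner of some `Base`, `assign_shared_ptr(bar, shared_foo)` destroys that `Base` and leaves `bar` pointing at the `Derived` owned by `shared_foo`.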

Pawsner answered 21/5, 2021 at 5:23
The relevant declaration has been added in this pull request and should be available at some point in the future. (Snobbery)
@Snobbery thanks for your PR! Sorry I messed up the signature; not sure how that happened, or why it worked with Cython... (Pawsner)
Curious indeed that it worked either way. In any case, the PR has been merged and the functionality is available in Cython 3.0a7, so thanks for the answer and the pointers! (Snobbery)

I have not tried Cython, but the standard library provides std::static_pointer_cast for std::shared_ptr. I think this will work:

std::shared_ptr<Base> bar = std::static_pointer_cast<Base>(shared_foo);

In Cython, that would be something like:

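cimport decl
from libcpp.memory cimport make_shared, shared_ptr, static_pointer_cast

def do_stuff():
    cdef decl.Derived* foo = new decl.Derived()
    cdef shared_ptr[decl.Derived] shared_foo = make_shared[decl.Derived](foo[0])
    cdef shared_ptr[decl.Base] bar = static_pointer_cast[decl.Base, decl.Derived](shared_foo)
    del foo  # make_shared copied *foo, so the original must be freed separately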
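What std::static_pointer_cast does can be checked in plain C++. This is a minimal sketch assuming the trivial `Base`/`Derived` pair from file.h; `upcast` and `downcast` are hypothetical wrapper names used only for illustration. The cast result shares ownership with its source rather than copying the object, and the reverse cast performs no runtime check, so it is only safe when the object is known to be a `Derived`.

```cpp
#include <cassert>
#include <memory>

// Trivial hierarchy mirroring file.h.
class Base {};
class Derived : public Base {};

// Upcast: the returned shared_ptr<Base> shares ownership with shared_foo.
std::shared_ptr<Base> upcast(const std::shared_ptr<Derived>& shared_foo) {
    return std::static_pointer_cast<Base>(shared_foo);
}

// Downcast: only valid if bar really points at a Derived;
// static_pointer_cast does not check this at runtime.
std::shared_ptr<Derived> downcast(const std::shared_ptr<Base>& bar) {
    return std::static_pointer_cast<Derived>(bar);
}
```

Each successful cast adds one owner to the same object, so after an upcast and a downcast of one `make_shared<Derived>()` result, all three handles report the same `get()` and a `use_count()` of 3.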

As a side note:

The way you create shared_foo is probably not what you want. Here you first create a dynamically allocated Derived, then a second, dynamically allocated shared Derived that is a copy of the original; foo itself is never freed.

// this allocates one Derived
Derived* foo = new Derived(); 
// This allocates a new copy, it does not take ownership of foo
std::shared_ptr<Derived> shared_foo = std::make_shared<Derived>(*foo); 

What you probably want is either:

Derived* foo = new Derived();
std::shared_ptr<Derived> shared_foo(foo); // This now takes ownership of foo

Or just:

// This makes a default constructed shared Derived
auto shared_foo = std::make_shared<Derived>(); 
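The difference between these three patterns can be made visible by counting live objects. This is a minimal sketch in which `Derived` is given a hypothetical static counter `live` (not part of the original file.h) that is incremented by its constructors and decremented by its destructor; each helper returns the count while its `shared_ptr` is still alive.

```cpp
#include <cassert>
#include <memory>

class Base {};

// Derived instrumented with a live-instance counter (for illustration only).
class Derived : public Base {
public:
    static int live;
    Derived() { ++live; }
    Derived(const Derived&) { ++live; }
    ~Derived() { --live; }
};
int Derived::live = 0;

// Pattern from the question: make_shared copies *foo into a second object.
int live_after_make_shared_copy() {
    Derived* foo = new Derived();
    std::shared_ptr<Derived> shared_foo = std::make_shared<Derived>(*foo);
    int count = Derived::live;  // *foo plus its copy
    delete foo;                 // foo is not owned by shared_foo
    return count;
}

// shared_ptr takes ownership of foo: only one object exists.
int live_after_taking_ownership() {
    Derived* foo = new Derived();
    std::shared_ptr<Derived> shared_foo(foo);
    return Derived::live;
}

// make_shared default-constructs a single shared Derived.
int live_after_make_shared_default() {
    auto shared_foo = std::make_shared<Derived>();
    return Derived::live;
}
```

Starting from zero live objects, the first helper returns 2 and the other two return 1; in every case the count drops back to 0 once the function returns.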
Cuddle answered 20/5, 2021 at 20:49
Good point. To my (limited) knowledge, there's no equivalent of shared_foo(foo); in Cython, meaning that the generated code might run into the same issue(?) (Snobbery)
And you're right, static_pointer_cast[T, U] is a thing, so indeed, static_pointer_cast[decl.Base, decl.Derived](shared_foo) seems to do what I want it to! (Snobbery)

© 2022 - 2024 — McMap. All rights reserved.