Writing PyTorch class in Cython

I'm trying to find examples of a PyTorch nn.Module class written in Cython for speed, but I haven't found any. Suppose I have the class below, written in Python: what would the best Cython translation be?

import torch
import torch.nn as nn
from torch.distributions import Normal

class Actor(nn.Module):
    def __init__(self, state_size, action_size, hidden_size):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, hidden_size)
        self.l4 = nn.Linear(hidden_size, action_size)
        self.log_std = nn.Parameter(-0.5 * torch.ones(action_size, dtype=torch.float32))

    def forward(self, x):
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        x = torch.relu(self.l3(x))
        mu = self.l4(x)
        return mu

    def dist(self, mu):
        pi = Normal(mu, torch.exp(self.log_std))
        return pi

    def log_prob(self, pi, action):
        return pi.log_prob(action).sum(dim=-1)
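For reference, `dist` builds an independent Normal per action dimension with standard deviation `exp(log_std)`, and `log_prob` sums the per-dimension log-densities over the last axis, which gives the log-density of a diagonal Gaussian. The stdlib-only sketch below spells out that arithmetic for a single sample. `diag_gaussian_log_prob` is a hypothetical helper written for illustration, not a PyTorch function; its result should agree with `Normal(mu, torch.exp(log_std)).log_prob(action).sum(-1)` up to floating-point rounding.

```python
import math


def diag_gaussian_log_prob(action, mu, log_std):
    """Log-density of `action` under a diagonal Gaussian (one sample).

    Per dimension: -(a - m)^2 / (2 * std^2) - log(std) - 0.5 * log(2 * pi),
    with std = exp(log_std); the per-dimension terms are then summed.
    """
    total = 0.0
    for a, m, ls in zip(action, mu, log_std):
        std = math.exp(ls)
        total += -((a - m) ** 2) / (2 * std ** 2) - ls - 0.5 * math.log(2 * math.pi)
    return total


# With the class's initial log_std of -0.5 and action == mu, each
# dimension contributes 0.5 - 0.5 * log(2 * pi), so two dimensions
# give 1 - log(2 * pi).
print(diag_gaussian_log_prob([0.1, -0.2], [0.1, -0.2], [-0.5, -0.5]))
```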
Cheng asked 10/3/2021 at 22:08. Comment:
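Cython mainly speeds up code that loops over arrays and indexes into them element by element, which is slow in pure Python. Your example class doesn't do any of that: each layer is a single PyTorch call that already runs in compiled C++ (or CUDA) code, so there's probably little value in writing it in Cython. – Ecphonesis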
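To make the comment's distinction concrete, here is a minimal pure-Python sketch of the kind of code Cython does accelerate: an element-by-element loop that indexes into a buffer. The function name `relu_inplace` is hypothetical, chosen only for illustration.

```python
from array import array


def relu_inplace(values: array) -> None:
    """Clamp negative entries to zero, one element at a time.

    Each iteration does a Python-level index, comparison and store;
    that per-element interpreter overhead is what a typed Cython
    loop removes.
    """
    for i in range(len(values)):
        if values[i] < 0.0:
            values[i] = 0.0


buf = array("d", [-1.5, 2.0, -0.25, 3.0])
relu_inplace(buf)
print(list(buf))  # [0.0, 2.0, 0.0, 3.0]
```

In a `.pyx` file, typing the argument as a memoryview (`double[:] values`) and the index as `cdef Py_ssize_t i` would let Cython compile this loop down to C. The `Actor` class never loops like this: `torch.relu(...)` applies the same operation to a whole tensor in one call to a compiled kernel, which is why rewriting the class in Cython would likely gain little.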

© 2022 - 2024 — McMap. All rights reserved.