I'm trying to find examples of a PyTorch `nn.Module` class written in Cython for speed, but haven't found any. Suppose I have the class below, written in Python: what would the best Cython translation be?
class Actor(nn.Module):
    """MLP policy producing a diagonal Normal over actions.

    Four ``nn.Linear`` layers (ReLU after the first three) map a state to the
    action mean ``mu``. The standard deviation is ``exp(log_std)``, where
    ``log_std`` is a learned parameter that does not depend on the input.

    NOTE(review): almost all of the work here happens inside ``nn.Linear`` and
    ``torch.relu``, which already dispatch to compiled kernels. A Cython port
    of this thin Python glue is therefore unlikely to speed anything up.
    If the Python overhead matters, consider ``torch.compile`` /
    TorchScript, or batching more states per call.
    """

    def __init__(
        self,
        state_size: int,
        action_size: int,
        hidden_size: int,
        log_std_init: float = -0.5,
    ) -> None:
        """Build the network.

        Args:
            state_size: Number of input features per state.
            action_size: Number of action dimensions (output width).
            hidden_size: Width of each of the three hidden layers.
            log_std_init: Initial value of every entry of ``log_std``. The
                default (-0.5) matches the previously hard-coded value.
        """
        super().__init__()
        # Attribute names are kept as l1..l4 / log_std so state_dict keys
        # (and any saved checkpoints) stay compatible.
        self.l1 = nn.Linear(state_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, hidden_size)
        self.l4 = nn.Linear(hidden_size, action_size)
        self.log_std = nn.Parameter(
            torch.full((action_size,), float(log_std_init), dtype=torch.float32)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the action mean ``mu`` for the state(s) ``x``.

        ``x``'s last dimension must equal ``state_size``. Leading dimensions
        are passed through unchanged by ``nn.Linear``.
        """
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        x = torch.relu(self.l3(x))
        return self.l4(x)

    def dist(self, mu: torch.Tensor) -> "Normal":
        """Return ``Normal(mu, exp(log_std))``.

        ``log_std`` broadcasts against ``mu`` along the last dimension.
        """
        return Normal(mu, torch.exp(self.log_std))

    def log_prob(self, pi: "Normal", action: torch.Tensor) -> torch.Tensor:
        """Return the joint log-density of ``action`` under ``pi``.

        Per-dimension log-probs are summed over the last (action) dimension,
        treating action dimensions as independent.
        """
        return pi.log_prob(action).sum(dim=-1)