PyTorch custom activation functions?
I'm having issues implementing custom activation functions, such as Swish, in PyTorch. How should I go about implementing and using custom activation functions in PyTorch?

Loo answered 19/4, 2019 at 17:0 Comment(0)
There are four possibilities depending on what you are looking for. You will need to ask yourself two questions:

Q1) Will your activation function have learnable parameters?

If yes, you have no choice but to create your activation function as an nn.Module class because you need to store those weights.

If no, you are free to simply create a normal function, or a class, depending on what is convenient for you.

Q2) Can your activation function be expressed as a combination of existing PyTorch functions?

If yes, you can simply write it as a combination of existing PyTorch functions and won't need to create a backward function that defines the gradient.

If no, you will need to write the gradient by hand.

Example 1: SiLU function

The SiLU function f(x) = x * sigmoid(x) does not have any learned weights and can be written entirely with existing PyTorch functions, thus you can simply define it as a function:

import torch

def silu(x):
    return x * torch.sigmoid(x)

and then simply use it as you would torch.relu or any other activation function.
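For example, here is a minimal sketch of dropping it into a small network (the layer sizes are just an illustrative assumption):

import torch
import torch.nn as nn

def silu(x):
    # the same silu function as above
    return x * torch.sigmoid(x)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        # silu is called exactly like torch.relu would be
        return self.fc2(silu(self.fc1(x)))

out = Net()(torch.randn(3, 4))  # output shape (3, 2)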

Example 2: SiLU with learned slope

In this case you have one learnable parameter, the slope, so you need to make a class for it.

class LearnedSiLU(nn.Module):
    def __init__(self, slope=1):
        super().__init__()
        # register the slope as a learnable parameter so it gets trained
        self.slope = nn.Parameter(slope * torch.ones(1))

    def forward(self, x):
        return self.slope * x * torch.sigmoid(x)
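As a quick sanity check (a minimal sketch, with an arbitrary input shape), the slope now shows up in .parameters() and will therefore be picked up by the optimizer:

layer = LearnedSiLU(slope=2)
print(list(layer.parameters()))  # one learnable tensor holding the slope
out = layer(torch.randn(5, 3))   # output has the same shape as the input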

Example 3: with backward

If you have something for which you need to create your own gradient function, you can look at this example: Pytorch: define custom function
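For completeness, here is a minimal sketch of that pattern using torch.autograd.Function, again with SiLU as the example; the hand-written backward just reproduces what autograd would compute, so it is purely illustrative:

import torch

class SiLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # save the input so the backward pass can reuse it
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        sig = torch.sigmoid(x)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * sig * (1 + x * (1 - sig))

silu_custom = SiLUFunction.apply  # use it like an ordinary function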

Sectionalism answered 12/7, 2019 at 19:31 Comment(2)
Did you mean: "...If yes, you have no choice BUT to create your activation function as an nn.Module class because you need to store those weights."Indignation
yes that is what i meantSectionalism
You can write a custom activation function like the one below (e.g. a weighted tanh).

import torch
import torch.nn as nn

class weightedTanh(nn.Module):
    def __init__(self, weights=1):
        super().__init__()
        self.weights = weights

    def forward(self, input):
        # equivalent to torch.tanh(self.weights * input)
        ex = torch.exp(2 * self.weights * input)
        return (ex - 1) / (ex + 1)

Don’t worry about backpropagation as long as you use autograd-compatible operations; autograd computes the gradients for you.
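A minimal usage sketch (the surrounding linear layers are just assumptions); note that nn.Sequential needs an instance of the module, not the class itself:

model = nn.Sequential(
    nn.Linear(10, 10),
    weightedTanh(weights=2),  # pass an instance, not the class
    nn.Linear(10, 1),
)
out = model(torch.randn(4, 10))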

Agriculture answered 19/4, 2019 at 17:43 Comment(3)
Thanks, but it gives me a __main__.Swish is not a Module subclass error.Loo
@Loo please share your code by updating your question.Agriculture
Issue has been solved over on discuss.pytorch.org/t/custom-activation-functions/43055Loo

I wrote the following SinActivation sub-class of nn.Module to implement the sin activation function.

import torch

class SinActivation(torch.nn.Module):
    def __init__(self):
        super(SinActivation, self).__init__()

    def forward(self, x):
        return torch.sin(x)
Merganser answered 28/12, 2022 at 8:52 Comment(0)
