Custom gradients for Flux rather than using Zygote AD

I have a machine learning model where the gradients for the model parameters are analytic and there is no need for automatic differentiation. However, I still want to be able to take advantage of the different optimizers in Flux without having to rely on Zygote for differentiation. Here are some snippets of my code.

W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

θ = Flux.Params([b, c, U, W])

opt = ADAM(0.01)

I then have a function that computes analytic gradients of my model parameters, θ.

function gradients(x) # x = one input data point or a batch of input data points
    # stuff to calculate gradients of each parameter
    # returns the gradients of each parameter
end
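
For concreteness, here is a toy illustration of the kind of analytic gradient I mean (using, just for this example, a plain linear model; my actual model and gradients are more involved):

# toy illustration only: for a linear model f(x) = W*x with squared-error
# loss 0.5*sum((W*x .- y).^2), the analytic gradient w.r.t. W is (W*x .- y)*x'
function toy_gradient_W(x, y, W)
    err = W * x .- y
    return err * x'   # same shape as W
end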

I then want to be able to do something like the following.

grads = gradients(x)
update!(opt, θ, grads)

My question is: What form/type does my gradients(x) function need to return in order to do update!(opt, θ, grads), and how do I do this?

Frauenfeld answered 16/4, 2020 at 16:31

If you don't use Params, then grads just needs to be the gradients themselves. The only requirement is that θ and grads are the same size.

For example, map((x, g) -> update!(opt, x, g), θ, grads) where θ == [b, c, U, W] and grads = [gradients(b), gradients(c), gradients(U), gradients(W)] (not really sure what gradients expects as inputs for you).
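
Spelled out as a runnable sketch (the gradient arrays below are just random stand-ins for whatever your gradients function returns; only the shapes matter):

using Flux
using Flux.Optimise: update!

Nh, N = 4, 3
W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

opt = ADAM(0.01)

θ = [b, c, U, W]

# stand-ins for your analytic gradients; each must match its parameter's shape
grads = [rand(N), rand(Nh), rand(N, Nh), rand(Nh, N)]

# update each parameter in place with its matching gradient
map((x, g) -> update!(opt, x, g), θ, grads)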

UPDATE: But to answer your original question, gradients needs to return a Grads object found here: https://github.com/FluxML/Zygote.jl/blob/359e586766129878ca0e56121037ed80afda6289/src/compiler/interface.jl#L88

So something like

# within gradient function body assuming gb is the gradient w.r.t b
g = Zygote.Grads(IdDict())
g.grads[θ[1]] = gb # assuming θ[1] == b
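
For example, filling in all four parameters (assuming gb, gc, gU, and gW are your analytic gradient arrays, and using the Grads constructor from the Zygote version linked above):

g = Zygote.Grads(IdDict())
g.grads[b] = gb   # gradient w.r.t. b
g.grads[c] = gc   # gradient w.r.t. c
g.grads[U] = gU   # gradient w.r.t. U
g.grads[W] = gW   # gradient w.r.t. W

# Grads is indexed by the parameter arrays themselves, so the usual call works
update!(opt, θ, g)   # θ = Flux.Params([b, c, U, W])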

But not using Params is probably simpler to debug. The only issue is that there isn't an update! that works over an array of parameters, but you can easily define your own:

function Flux.Optimise.update!(opt, xs::Tuple, gs)
    for (x, g) in zip(xs, gs)
        update!(opt, x, g)
    end
end

# use it like this
W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

θ = (b, c, U, W)

opt = ADAM(0.01)
x = # generate input to gradients
grads = gradients(x) # return tuple (gb, gc, gU, gW)
update!(opt, θ, grads)

UPDATE 2:

Another option is to still use Zygote to take the gradients, so that it automatically builds the Grads object for you, but with a custom adjoint so that your analytical function is used to compute the adjoint. Let's assume your ML model is defined as a function called f, so that f(x) returns the output of your model for input x. Let's also assume that gradients(x) returns the analytical gradients w.r.t. x, as you mentioned in your question. Then the following code will still use Zygote's AD, which will populate the Grads object correctly, but it will use your definition of the gradients for your function f:

W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

θ = Flux.Params([b, c, U, W])

f(x) = # define your model
gradients(x) = # define your analytical gradient

# set up the custom adjoint
Zygote.@adjoint f(x) = f(x), Δ -> (gradients(x),)

opt = ADAM(0.01)
x = # generate input to model
y = # output of model
grads = Zygote.gradient(() -> Flux.mse(f(x), y), θ)
update!(opt, θ, grads)

Notice that I used Flux.mse as an example loss above. One downside of this approach is that Zygote's gradient function wants a scalar output. If your model is being passed into some loss that outputs a scalar error value, then @adjoint is the best approach. This is appropriate for the situation where you are doing standard ML and the only change is that you want Zygote to compute the gradient of f analytically using your function.

If you are doing something more complex and cannot use Zygote.gradient, then the first approach (not using Params) is most appropriate. Params really only exists for backwards compatibility with Flux's old AD, so it is best to avoid it if possible.

Sundaysundberg answered 16/4, 2020 at 17:03
I used your first recommendation and it works like a charm. Thanks so much! – Frauenfeld
