If you don't use `Params`, then `grads` just needs to be the gradients themselves. The only requirement is that `θ` and `grads` are the same size. For example, `map((x, g) -> update!(opt, x, g), θ, grads)` where `θ == [b, c, U, W]` and `grads = [gradients(b), gradients(c), gradients(U), gradients(W)]` (I'm not really sure what `gradients` expects as inputs in your case).
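For concreteness, here is a minimal, runnable sketch of that map-based loop (the sizes, the `ADAM` step size, and the zero placeholder gradients are stand-ins for whatever you actually have):

```julia
using Flux
using Flux.Optimise: update!

# stand-in sizes and parameters; substitute your own
N, Nh = 4, 3
W, U, b, c = rand(Nh, N), rand(N, Nh), rand(N), rand(Nh)

θ = [b, c, U, W]
opt = ADAM(0.01)

# placeholder gradients, one array per parameter, in the same order as θ
grads = [zero(b), zero(c), zero(U), zero(W)]

# apply update! pairwise over each parameter and its gradient
map((x, g) -> update!(opt, x, g), θ, grads)
```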
UPDATE: But to answer your original question, `gradients` needs to return a `Grads` object, defined here: https://github.com/FluxML/Zygote.jl/blob/359e586766129878ca0e56121037ed80afda6289/src/compiler/interface.jl#L88

So something like:
```julia
# within the gradient function body, assuming gb is the gradient w.r.t. b
g = Zygote.Grads(IdDict())
g.grads[θ[1]] = gb # assuming θ[1] == b
```
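If it helps, here is a hedged sketch of how a `gradients` function might build up such a `Grads` object for all four parameters (the `zero` calls are placeholders for your analytical gradients, and the `Grads(IdDict())` constructor is the one from the linked source, so it may look different in other Zygote versions):

```julia
using Zygote

# hypothetical sketch: build a Grads covering every parameter in θ
function gradients(x, θ)
    # placeholders for your analytical gradients w.r.t. b, c, U, W
    gb, gc, gU, gW = zero(θ[1]), zero(θ[2]), zero(θ[3]), zero(θ[4])
    g = Zygote.Grads(IdDict())
    for (p, dp) in zip(θ, (gb, gc, gU, gW))
        g.grads[p] = dp
    end
    return g
end
```

The result can then be indexed per parameter (`g[b]`, etc.) just like the `Grads` that `Zygote.gradient` returns.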
But not using `Params` is probably simpler to debug. The only issue is that there isn't an `update!` that works over a collection of parameters, but you can easily define your own:
```julia
using Flux
using Flux.Optimise: update!

function Flux.Optimise.update!(opt, xs::Tuple, gs)
    for (x, g) in zip(xs, gs)
        update!(opt, x, g)
    end
end

# use it like this
W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)
θ = (b, c, U, W)
opt = ADAM(0.01)
x = # generate input to gradients
grads = gradients(x) # returns the tuple (gb, gc, gU, gW)
update!(opt, θ, grads)
```
UPDATE 2: Another option is to still use Zygote to take the gradients, so that it automatically sets up the `Grads` object for you, but to use a custom adjoint so that Zygote uses your analytical function to compute the adjoint. Let's assume your ML model is defined as a function called `f`, so that `f(x)` returns the output of your model for input `x`. Let's also assume that `gradients(x)` returns the analytical gradients w.r.t. `x`, like you mentioned in your question. Then the following code will still use Zygote's AD, which will populate the `Grads` object correctly, but it will use your definition to compute the gradients of `f`:
```julia
using Flux, Zygote
using Flux.Optimise: update!

W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)
θ = Flux.Params([b, c, U, W])

f(x) = # define your model
gradients(x) = # define your analytical gradient

# set up the custom adjoint
Zygote.@adjoint f(x) = f(x), Δ -> (gradients(x),)

opt = ADAM(0.01)
x = # generate input to the model
y = # the target output for x
grads = Zygote.gradient(() -> Flux.mse(f(x), y), θ)
update!(opt, θ, grads)
```
Notice that I used `Flux.mse` as an example loss above. One downside to this approach is that Zygote's `gradient` function wants a scalar output. If your model is being passed into some loss that outputs a scalar error value, then `@adjoint` is the best approach. This is appropriate for the situation where you are doing standard ML and the only change is that you want Zygote to compute the gradient of `f` analytically using your function.
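As a quick illustration of that scalar-output requirement, here is a toy sketch (the `f`, the per-element `gradients`, and the `sum` stand-in loss are all made up for the example, and the `Δ .*` is just the usual chain-rule scaling for this element-wise toy case):

```julia
using Zygote

f(x) = 2 .* x                      # toy model
gradients(x) = fill(2.0, size(x))  # its analytical per-element gradient

# custom adjoint: scale the analytical gradient by the incoming sensitivity Δ
Zygote.@adjoint f(x) = f(x), Δ -> (Δ .* gradients(x),)

x = rand(3)

# Zygote.gradient needs the closure to return a scalar, so wrap f in a scalar loss
Zygote.gradient(x -> sum(f(x)), x)[1]  # uses the custom adjoint; gives fill(2.0, 3)

# Zygote.gradient(x -> f(x), x)        # errors: the output is an array, not a scalar
```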
If you are doing something more complex and cannot use `Zygote.gradient`, then the first approach (not using `Params`) is most appropriate. `Params` really only exists for backwards compatibility with Flux's old AD, so it is best to avoid it if possible.