Using quantile in Flux (Julia) in loss function
Q

I am trying to use quantile in a loss function passed to train! (for some robustness, like least trimmed squares), but quantile mutates the array and Zygote throws the error "Mutating arrays is not supported", coming from sort!. Below is a simple example (the content does not make sense, of course):

using Flux, StatsBase
xdata = randn(2, 100)   
ydata = randn(100)

model = Chain(Dense(2,10), Dense(10, 1))


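function trimmedLoss(x, y; trimFrac=0.05f0)      # 0.05f0, not 0.f05 (which parses as 0.0f0, i.e. no trimming)
    yhat = model(x) |> vec                          # 1×100 output -> length-100 vector, so yhat .- y stays a vector
    absRes = abs.(yhat .- y)
    trimVal = quantile(absRes, 1.f0 - trimFrac)     # fails under Zygote: quantile calls sort!
    s = sum(ifelse.(absRes .> trimVal, 0.f0, absRes)) / (length(absRes) * (1.f0 - trimFrac))
    #s = sum(absRes)/length(absRes)   # using this and commenting out the two above works (no surprise)
end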

println(trimmedLoss(xdata, ydata)) #works ok

Flux.train!(trimmedLoss, params(model), zip([xdata], [ydata]), ADAM())

println(trimmedLoss(xdata, ydata)) #changed loss?
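For reference, the trimmed loss above computes the following, writing r_i = yhat_i - y_i for the n residuals, f = trimFrac and q for the (1 - f) quantile of the |r_i|. For generic data (no ties at q) the indicator is locally constant in the r_i, so almost everywhere q itself receives no gradient and only the kept residuals contribute; the kinks where a residual crosses q are what make the loss only piecewise differentiable.

```latex
q = Q_{1-f}\left(|r_1|, \dots, |r_n|\right), \qquad
s = \frac{1}{n(1-f)} \sum_{i=1}^{n} |r_i| \, \mathbf{1}\!\left[\,|r_i| \le q\,\right],
\qquad
\frac{\partial s}{\partial r_k} = \frac{\operatorname{sign}(r_k)\, \mathbf{1}\!\left[\,|r_k| \le q\,\right]}{n(1-f)} \quad \text{(a.e.)}
```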

This is all in Flux 0.10 with Julia 1.2.

Thanks in advance for any hints or workarounds!

Quadricycle asked on 16/1/2020 at 20:33. Comments (3):
Isn't this a textbook example of a non-differentiable loss function? If so, I have seen some hacky code that trains Flux NNs with gradient-free optimizers, but I don't have any handy. (Cato)
Perhaps you can decompose the problem into a differentiable loss and a projection onto a convex set, and then implement a projected gradient method? On the other hand, I'm not sure whether that's sensible, given that the loss will not be convex anyway... (Tellurite)
Yes, strictly speaking it is not differentiable. But heuristically it works quite well with other optimizers (Levenberg-Marquardt etc.), in particular when the trimmed fraction is small. With stochastic gradient descent it should matter even less, IMHO. It worked in an earlier version of Flux, before Zygote. (Quadricycle)
A

Ideally, we'd define a custom adjoint for quantile so that this works out of the box. (Feel free to open an issue to remind us to do this.)
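To make the idea of a custom adjoint more concrete: assuming Statistics' default interpolation (alpha = beta = 1, Hyndman-Fan type 7), once the order of the data is fixed the quantile is a weighted average of two neighbouring sorted values, so its gradient is just those two weights routed back through sortperm. The sketch below computes that gradient with the standard library only; quantile_grad is a hypothetical helper, not part of Statistics or Zygote, and a real adjoint would also have to handle edge cases such as length-1 inputs or NaN.

```julia
using Statistics

# Hypothetical helper (not in Statistics or Zygote): gradient of quantile(x, p)
# with respect to x, assuming the default type-7 interpolation. Needs length(x) >= 2.
function quantile_grad(x::AbstractVector{<:Real}, p::Real)
    n = length(x)
    n >= 2 || throw(ArgumentError("need at least two elements"))
    perm = sortperm(x)                  # k-th smallest value is x[perm[k]]
    h = (n - 1) * p + 1                 # fractional position in the sorted vector
    j = clamp(floor(Int, h), 1, n - 1)  # index of the lower neighbour
    γ = clamp(h - j, 0, 1)              # interpolation weight of the upper neighbour
    g = zeros(float(eltype(x)), n)
    # q = (1 - γ) * v[j] + γ * v[j + 1] with v = x[perm]; send the weights back to x
    g[perm[j]] += 1 - γ
    g[perm[j + 1]] += γ
    return g
end

x = [1, 2, 3, 3]
g = quantile_grad(x, 0.5)
println(g)                        # [0.0, 0.5, 0.5, 0.0]
println(sum(g .* x) ≈ median(x))  # true: with the order fixed, q is linear in x
```

For [1, 2, 3, 3] and p = 0.5 this gives the same gradient that Zygote reports in the REPL session further down.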

In the meantime there's a quick workaround. It's actually the sorting that causes trouble here, so if you call quantile(xs, p, sorted=true) it'll work. Obviously this requires xs to be sorted to get correct results, so you might need to use quantile(sort(xs), ...), with the non-mutating sort rather than sort!.
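A minimal check of why this helps, using only the Statistics stdlib: sort returns a new sorted copy and leaves its input untouched, whereas quantile without sorted=true sorts a private copy in place (the sort! that the error in the question points to). With sorted=true the sorting step is skipped, and on pre-sorted input the result is the same. The data values are arbitrary.

```julia
using Statistics

xs = [0.7, -1.2, 3.4, 0.1, 2.2]
p = 0.9

ys = sort(xs)                            # new sorted array; xs is not modified
q_direct = quantile(xs, p)               # sorts an internal copy in place
q_sorted = quantile(ys, p; sorted=true)  # assumes ys is already sorted, no sorting

println(q_direct ≈ q_sorted)  # true
println(xs)                   # [0.7, -1.2, 3.4, 0.1, 2.2] (unchanged)
```

Passing unsorted data together with sorted=true is not an error, it just silently returns a wrong value, which is why the sort(xs) call matters.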

Depending on your Zygote version you might also need an adjoint for sort. That one's pretty easy:

julia> using Zygote, Statistics   # gradient and quantile

julia> using Zygote: @adjoint

julia> @adjoint function sort(x)
         p = sortperm(x)
         x[p], x̄ -> (x̄[invperm(p)],)
       end

julia> gradient(x -> quantile(sort(x), 0.5, sorted=true), [1, 2, 3, 3])
([0.0, 0.5, 0.5, 0.0],)

We'll make that built-in in the next Zygote release, but for now if you add that to your script it'll get your code working.
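The pullback in the @adjoint above follows from indexing: sort(x) == x[p] with p = sortperm(x), so y[i] = x[p[i]] and the cotangent of y[i] belongs at position p[i] of x, which is exactly ybar[invperm(p)]. The sketch below checks this against an explicit loop using only Base; sort_pullback and sort_pullback_loop are hypothetical names, not Zygote functions.

```julia
# Hypothetical: the same pullback as in the @adjoint, as a plain function.
sort_pullback(x, ybar) = ybar[invperm(sortperm(x))]

# Reference version: y[i] = x[p[i]], so ybar[i] flows back to x[p[i]].
function sort_pullback_loop(x, ybar)
    p = sortperm(x)
    xbar = zero(ybar)
    for i in eachindex(p)
        xbar[p[i]] += ybar[i]
    end
    return xbar
end

x = [3.0, 1.0, 2.0]
ybar = [10.0, 20.0, 30.0]             # cotangent w.r.t. sort(x) == [1.0, 2.0, 3.0]
println(sort_pullback(x, ybar))       # [30.0, 10.0, 20.0]
println(sort_pullback_loop(x, ybar))  # [30.0, 10.0, 20.0]
```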

Atlanta answered on 20/1/2020 at 11:38. Comments (1):
You, sir, are my hero and saviour. (Twofold)
