Keras: combining two losses with adjustable weights
[Image: weighted total loss, L_total = alpha * L1 + (1 - alpha) * L2]

So here is the detailed description. I have a Keras functional model with two output layers, x1 and x2:

x1 = Dense(1, activation='relu')(prev_inp1)

x2 = Dense(2, activation='relu')(prev_inp2)

I need to take these x1 and x2, merge/add them, and come up with a weighted loss function like the one in the attached image, then propagate the same total loss into both branches. Alpha should be free to vary across iterations.
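For reference, a minimal sketch of the setup described above might look like the following (the input shapes and the layers producing prev_inp1 and prev_inp2 are assumptions, added only to make the snippet self-contained):

from keras.layers import Input, Dense
from keras.models import Model

# Assumed inputs and intermediate layers standing in for the real branches
inp1 = Input(shape=(10,))
inp2 = Input(shape=(20,))
prev_inp1 = Dense(8, activation='relu')(inp1)
prev_inp2 = Dense(8, activation='relu')(inp2)

# The two branch outputs from the question
x1 = Dense(1, activation='relu')(prev_inp1)
x2 = Dense(2, activation='relu')(prev_inp2)

# Two-output functional model; the losses on x1 and x2 get combined later
model = Model(inputs=[inp1, inp2], outputs=[x1, x2])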

Afford answered 11/9, 2017 at 20:32

It seems that propagating the "same loss" into both branches will not have the intended effect unless alpha depends on both branches. If alpha does not depend on both branches, then part of the total loss is just a constant with respect to one of the branches.

So, in this case, just compile the model with the two losses kept separate and pass the weights to the compile method:

model.compile(optimizer='someOptimizer', loss=[loss1, loss2], loss_weights=[alpha, 1-alpha])

Compile again when you need alpha to change.
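As a rough illustration of this first approach, one way to "compile again when alpha changes" is a small training loop that recompiles between stages of fit(); loss1, loss2, the alpha schedule and the data arrays below are placeholders, not part of the original answer:

# Minimal sketch: recompile with a new alpha every few epochs.
# Note that each compile() call also resets the optimizer state.
for alpha in [0.9, 0.7, 0.5]:                      # assumed alpha schedule
    model.compile(optimizer='adam',
                  loss=[loss1, loss2],
                  loss_weights=[alpha, 1 - alpha])
    model.fit([in1_data, in2_data], [y1_data, y2_data],
              epochs=5)                            # a few epochs per stage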


But if alpha does indeed depend on both branches, then you need to concatenate the results and calculate alpha's value inside the loss:

singleOut = Concatenate()([x1,x2])

And a custom loss function:

def weightedLoss(yTrue, yPred):
    # yTrue and yPred have shape (batch, 3): one column from x1, two from x2
    x1True = yTrue[:, :1]
    x2True = yTrue[:, 1:]

    x1Pred = yPred[:, :1]
    x2Pred = yPred[:, 1:]

    #calculate alpha somehow with keras backend functions

    return alpha*someLoss(x1True, x1Pred) + (1-alpha)*someLoss(x2True, x2Pred)

Compile with this function:

model.compile(loss=weightedLoss, optimizer=....)
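An end-to-end sketch of this second approach might look like the following; the way alpha is computed here (from the mean predictions of the two branches) is only an illustrative placeholder, as are someLoss being mean squared error, the input shapes and the data arrays:

import numpy as np
import keras.backend as K
from keras.layers import Input, Dense, Concatenate
from keras.models import Model

inp1 = Input(shape=(10,))                          # assumed input shapes
inp2 = Input(shape=(20,))
x1 = Dense(1, activation='relu')(inp1)
x2 = Dense(2, activation='relu')(inp2)
singleOut = Concatenate()([x1, x2])                # shape (batch, 3)

def weightedLoss(yTrue, yPred):
    x1True, x2True = yTrue[:, :1], yTrue[:, 1:]
    x1Pred, x2Pred = yPred[:, :1], yPred[:, 1:]
    # Illustrative only: make alpha depend on both branches via backend ops
    alpha = K.sigmoid(K.mean(x1Pred) - K.mean(x2Pred))
    loss1 = K.mean(K.square(x1True - x1Pred), axis=-1)   # someLoss = MSE here
    loss2 = K.mean(K.square(x2True - x2Pred), axis=-1)
    return alpha * loss1 + (1 - alpha) * loss2

model = Model(inputs=[inp1, inp2], outputs=singleOut)
model.compile(optimizer='adam', loss=weightedLoss)

# Targets must be concatenated the same way as the outputs:
# y1_data has shape (n, 1) and y2_data has shape (n, 2) (assumed arrays)
y = np.concatenate([y1_data, y2_data], axis=-1)
model.fit([in1_data, in2_data], y, epochs=10)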
Marvamarve answered 12/9, 2017 at 0:22
I need to propagate L_total into both branches, so the second approach might work. Also, in weightedLoss can we have the iteration as an input? I need to vary alpha according to the iteration, or do I need to write a routine in callbacks? – Afford
Probably a routine in callbacks... I'm not aware of any way to hack the loss function to access more than yTrue and yPred. – Ballplayer
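One common pattern for that follow-up question (a sketch of an alternative, not something from the answer above): keep alpha in a backend variable that the loss closes over, and update it from a callback, so the model never needs to be recompiled. The decay schedule below is an assumption:

import keras
import keras.backend as K

alpha = K.variable(0.5)   # shared weight, updated during training

def weightedLoss(yTrue, yPred):
    x1True, x2True = yTrue[:, :1], yTrue[:, 1:]
    x1Pred, x2Pred = yPred[:, :1], yPred[:, 1:]
    loss1 = K.mean(K.square(x1True - x1Pred), axis=-1)
    loss2 = K.mean(K.square(x2True - x2Pred), axis=-1)
    return alpha * loss1 + (1 - alpha) * loss2

class AlphaScheduler(keras.callbacks.Callback):
    """Moves alpha from 1.0 towards 0.0 over the first 100 epochs."""
    def on_epoch_end(self, epoch, logs=None):
        K.set_value(alpha, max(0.0, 1.0 - epoch / 100.0))

# model.compile(optimizer='adam', loss=weightedLoss)
# model.fit(..., callbacks=[AlphaScheduler()])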
