Keras allows you to use any tensors from the global scope inside a loss or metric function. In fact, `y_true` and `y_pred` might not even be used, as in the metric below.
Your model can have multiple inputs (you can feed a dummy value into this extra input at inference, or load the weights into a model with a single input). Note that you still need it for validation.
import keras
from keras.layers import *
from keras import backend as K
import numpy as np
# Two inputs: the features and the per-element weights used by the loss.
inputs_x = Input(shape=(10,))
inputs_w = Input(shape=(10,))

# Only the feature input feeds the layer; the weight input is read by the
# loss and metric functions through the global scope.
y = Dense(units=10, kernel_initializer='glorot_uniform')(inputs_x)

model = keras.Model(
    inputs=[inputs_x, inputs_w],
    outputs=[y],
)
def my_loss(y_true, y_pred):
    """Return the absolute error scaled element-wise by ``inputs_w``.

    The weights are not an argument: ``inputs_w`` is the module-level input
    tensor, picked up from the global scope.

    Args:
        y_true: Tensor of target values.
        y_pred: Tensor of model predictions.

    Returns:
        The tensor ``|(y_true - y_pred) * inputs_w|``; no reduction is
        applied here.
    """
    return K.abs((y_true - y_pred) * inputs_w)
def my_metrics(y_true, y_pred):
    """Return the mean of the weight input ``inputs_w``.

    Both arguments are ignored; the function reads the module-level
    ``inputs_w`` tensor instead.

    Args:
        y_true: Unused.
        y_pred: Unused.

    Returns:
        ``K.mean(inputs_w)``.
    """
    # just to output something
    return K.mean(inputs_w)
# The custom loss and metric both read the weight input from the global scope.
model.compile(optimizer='adam', loss=[my_loss], metrics=[my_metrics])

# Random data, targets and per-element weights.
data = np.random.normal(size=(50000, 10))
labels = np.random.normal(size=(50000, 10))
weights = np.random.normal(size=(50000, 10))

# The weights are fed as the second model input, for validation as well.
model.fit(
    [data, weights],
    labels,
    batch_size=256,
    validation_data=([data[:100], weights[:100]], labels[:100]),
    epochs=100,
)
To validate without weights, you need to compile another version of the model with a different loss that does not use the weights.
UPD: Also note that Keras will sum up all the elements of your loss if it returns an array instead of a scalar.
UPD: For TensorFlow 2.1.0 things become more complicated, it seems. The way to go is in the direction @marco-cerliani pointed out (labels, weights and data are fed to the model, and the custom loss tensor is added via `.add_loss()`); however, his solution didn't work for me out of the box. The first problem is that the model does not want to work with a `None` loss, refusing to take both inputs and outputs. So I introduced an additional dummy loss function. The second problem appeared when the dataset size was not divisible by the batch size. In Keras and TF 1.x the last-batch problem was usually solved with the `steps_per_epoch` and `validation_steps` parameters, but here it starts to fail on the first batch of epoch 2. So I needed to write a simple custom data generator.
import tensorflow.keras as keras
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
import numpy as np
# Three inputs: features, per-element loss weights, and the labels.
# The labels are a model input here so the loss can be built from tensors.
inputs_x = Input(shape=(10,))
inputs_w = Input(shape=(10,))
inputs_l = Input(shape=(10,))

# Only the feature input feeds the layer.
y = Dense(units=10, kernel_initializer='glorot_uniform')(inputs_x)

model = keras.Model(
    inputs=[inputs_x, inputs_w, inputs_l],
    outputs=[y],
)
def my_loss(y_true, y_pred):
    """Return the absolute error scaled element-wise by ``inputs_w``.

    The weights are not an argument: ``inputs_w`` is the module-level input
    tensor, picked up from the global scope. The result is symmetric in the
    two arguments, so their order does not affect the value.

    Args:
        y_true: Tensor of target values.
        y_pred: Tensor of model predictions.

    Returns:
        The tensor ``|(y_true - y_pred) * inputs_w|``; no reduction is
        applied here.
    """
    return K.abs((y_true - y_pred) * inputs_w)
def my_metrics():
    """Return the mean of the module-level weight input ``inputs_w``.

    Takes no arguments; the tensor is read from the global scope.

    Returns:
        ``K.mean(inputs_w)``.
    """
    # just to output something
    return K.mean(inputs_w)
def dummy_loss(y_true, y_pred):
    """Return a constant zero loss, ignoring both arguments.

    Placeholder handed to ``model.compile`` so that a loss function is
    present; it contributes nothing to the optimised value.

    Args:
        y_true: Unused.
        y_pred: Unused.

    Returns:
        The float ``0.0``.
    """
    return 0.
# Build the real loss and metric directly from the model's own tensors.
loss = my_loss(y, inputs_l)
metric = my_metrics()

# Attach the custom loss tensor to the model. Without this call only
# dummy_loss (a constant zero) would be optimised.
model.add_loss(loss)
model.add_metric(metric, name='my_metric', aggregation='mean')

# compile() still gets a loss function; dummy_loss adds nothing to the total.
model.compile(optimizer='adam', loss=dummy_loss)

data = np.random.normal(size=(50000, 10))
labels = np.random.normal(size=(50000, 10))
weights = np.random.normal(size=(50000, 10))
dummy = np.zeros(shape=(50000, 10))  # or it can be labels, it does not matter now

# It looks like it does not like it when len(data) % batch_size != 0.
# If I set steps_per_epoch, it fails on the second epoch.
# So, I proceeded with a data generator.
class DataGenerator(keras.utils.Sequence):
    """Generates data for Keras.

    Serves batches of ``([x, w, y], y2, [None])`` where ``x``, ``w`` and ``y``
    are passed as model inputs and ``y2`` as the target. Only whole batches
    are produced: a trailing partial batch is dropped.
    """

    def __init__(self, x, w, y, y2, batch_size, shuffle=True):
        """Store the arrays and the batching configuration.

        Args:
            x: First model input; its length defines the number of samples.
            w: Second model input, indexed the same way as ``x``.
            y: Third model input, indexed the same way as ``x``.
            y2: Targets returned alongside the inputs.
            batch_size: Number of samples per batch.
            shuffle: Whether to reshuffle the sample order after each epoch.
        """
        super().__init__()
        self.x = x
        self.w = w
        self.y = y
        self.y2 = y2
        self.indices = list(range(len(self.x)))
        self.shuffle = shuffle
        self.batch_size = batch_size

    def __len__(self):
        """Denotes the number of batches per epoch."""
        # Floor division: samples that do not fill a whole batch are skipped.
        return len(self.indices) // self.batch_size

    def __getitem__(self, index):
        """Generate one batch of data.

        Args:
            index: Zero-based batch number.

        Returns:
            A tuple ``([x_batch, w_batch, y_batch], y2_batch, [None])``.
        """
        # Generate indexes of the batch
        ids = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        # the last None to remove weird warning
        return [self.x[ids], self.w[ids], self.y[ids]], self.y2[ids], [None]

    def on_epoch_end(self):
        """Updates indexes after each epoch."""
        if self.shuffle:
            # In-place shuffle, so the next epoch sees a new sample order.
            np.random.shuffle(self.indices)
batch_size = 256

train_generator = DataGenerator(data, weights, labels, dummy, batch_size=batch_size, shuffle=True)

# Validation uses the first two batches' worth of samples.
val_generator = DataGenerator(
    data[:2*batch_size],
    weights[:2*batch_size],
    labels[:2*batch_size],
    dummy[:2*batch_size],
    batch_size=batch_size,
    shuffle=True,
)

# The generators provide the batches, so no batch_size is passed to fit().
model.fit(train_generator, validation_data=val_generator, epochs=100)