Custom eval_metric_ops in Estimator in Tensorflow

I am trying to add R-squared to the eval_metric_ops of my Estimator like this:

def model_fn(features, labels, mode, params):
    predict = prediction(features, params, mode)
    loss = my_loss_fn
    eval_metric_ops = { 
        'rsquared': tf.subtract(1.0, tf.div(tf.reduce_sum(tf.squared_difference(labels, predict)),
                                            tf.reduce_sum(tf.squared_difference(labels, tf.reduce_mean(labels)))),
                                name = 'rsquared')
        }

    train_op = tf.contrib.layers.optimize_loss(
        loss = loss,
        global_step = global_step,
        learning_rate = 0.1,
        optimizer = "Adam"
    )

    predictions = {"predictions": predict}

    return tf.estimator.EstimatorSpec(
        mode = mode,
        predictions = predictions,
        loss = loss,
        train_op = train_op,
        eval_metric_ops = eval_metric_ops
    )

but I get the following error:

TypeError: Values of eval_metric_ops must be (metric_value, update_op) tuples, given: Tensor("rsquared:0", shape=(), dtype=float32) for key: rsquared

I also tried without the name argument, but that does not change anything. Do you know how to create this eval_metric_ops entry?

Plaint answered 11/8, 2017 at 21:9

eval_metric_ops needs a dict of metric results keyed by name. Each value must be the (metric_value, update_op) tuple returned by a metric function, not a plain tensor. In your case, the metric function can be built from tf.metrics:

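def metric_fn(labels, predictions):
    # Broadcast the batch mean to the shape of labels so both arguments match.
    mean_labels = tf.ones_like(labels) * tf.reduce_mean(labels)
    # Streaming MSE against the mean (proportional to SST) and against the predictions (SSE).
    SST, update_op1 = tf.metrics.mean_squared_error(labels, mean_labels)
    SSE, update_op2 = tf.metrics.mean_squared_error(labels, predictions)
    return tf.subtract(1.0, tf.div(SSE, SST)), tf.group(update_op1, update_op2)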
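
To see why eval_metric_ops wants a (metric_value, update_op) pair instead of a single tensor, it may help to model the pattern outside TensorFlow. The following is a minimal plain-Python sketch, not TensorFlow API: the class StreamingRSquared and its methods are hypothetical names chosen for illustration. update() plays the role of update_op (fold one batch into running totals) and value() plays the role of metric_value (read the result from those totals).

```python
class StreamingRSquared:
    """Hypothetical plain-Python analogue of a (metric_value, update_op) pair."""

    def __init__(self):
        # Running statistics, analogous to a metric's local variables.
        self.n = 0
        self.sum_y = 0.0
        self.sum_y2 = 0.0
        self.sum_sq_err = 0.0

    def update(self, labels, predictions):
        """Plays the role of update_op: fold one batch into the running totals."""
        for y, p in zip(labels, predictions):
            self.n += 1
            self.sum_y += y
            self.sum_y2 += y * y
            self.sum_sq_err += (y - p) ** 2

    def value(self):
        """Plays the role of metric_value: R^2 over everything seen so far."""
        # Total sum of squares around the mean of all labels seen so far.
        sst = self.sum_y2 - self.sum_y ** 2 / self.n
        return 1.0 - self.sum_sq_err / sst


metric = StreamingRSquared()
metric.update([1.0, 2.0], [1.5, 1.5])  # first evaluation batch
metric.update([3.0, 4.0], [3.5, 3.5])  # second evaluation batch
r2 = metric.value()
```

Note that this sketch measures the total sum of squares around the mean of all labels seen so far, whereas metric_fn above uses tf.reduce_mean(labels), which is the mean of the current batch only. With small or unrepresentative batches the two can give different values, so treat the batch-mean version as an approximation.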
Thibodeaux answered 12/8, 2017 at 20:42

I tried the accepted answer, but it didn't work in TF 1.14, so I wrote my own. You can adapt the source code examples here to your own metric by changing the functions whose names start with compute_ and the related variables.

Pierce answered 1/7, 2021 at 12:17