How to properly raise exception in TensorFlow graph

I want to raise a tf.errors.InvalidArgumentError exception depending on the value of an input tensor in graph mode (in TensorFlow Serving).

Currently I use tf.debugging.assert_all_finite, and this works fine. But since I'm not asserting to catch bugs, but rejecting invalid input, it would be better to raise an explicit exception.

My question boils down to:

  • How to conditionally execute code that doesn't return a tensor
  • How to raise a tf.errors exception

What is the proper way of doing this?

Edit: Some more detail. I would like to recreate the following logic without using tf.debugging (unless that is in fact the correct way to do it).

Currently I am checking that there aren't NaN values like this:

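import tensorflow as tf

# Returns input_data unchanged, but raises InvalidArgumentError at run time
# if it contains any NaN or Inf values
assert_op = tf.debugging.assert_all_finite(
    input_data,
    "Can't have NaNs at beginning or end"
)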
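For TensorFlow Serving, a check like this usually lives inside a tf.function exported with an input signature. The following is a minimal sketch, assuming TF 2, a 1-D float32 input, and a hypothetical serve function whose "model" is just a multiplication by 2. It uses the tensor returned by assert_all_finite downstream, so the check stays on the data path.

```python
import tensorflow as tf

# Hypothetical serving function; a real model would replace "* 2.0".
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def serve(input_data):
    # assert_all_finite returns input_data unchanged, or raises
    # tf.errors.InvalidArgumentError at run time if it contains NaN/Inf.
    checked = tf.debugging.assert_all_finite(
        input_data, "Can't have NaNs at beginning or end")
    # Using `checked` (not `input_data`) keeps the check on the data path.
    return checked * 2.0


print(serve(tf.constant([1.0, 2.0])).numpy())
try:
    serve(tf.constant([1.0, float("nan")]))
except tf.errors.InvalidArgumentError as e:
    print("rejected:", type(e).__name__)
```

The caller (here plain Python, in production the Serving client) sees the InvalidArgumentError; the graph itself only decides whether to fail.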
Panthea answered 29/4, 2020 at 13:42 Comment(3)
You cannot raise or catch an exception "inside" the graph. The graph might be run on any device (e.g. a GPU) which doesn't even know what an exception is. However, the graph can contain control flow structures such as tf.cond, which can be used to emulate that behavior. Note that tf.errors.InvalidArgumentError is raised at compile time, not at runtime. (Scarper)
@a_guest: You can raise an exception at runtime, e.g. via tf.Assert. You cannot catch it (within the graph), though. I have expanded upon that in my answer. (Aesthetic)
@Aesthetic Yes, you can have an exception raised at runtime, but that doesn't happen inside the graph. tf.Assert uses tf.cond under the hood to create the logical branches. The exception is then raised in the "controller" space. Maybe I misunderstood the question, but it appeared that the goal was to raise an exception at runtime. This is quite different from having TensorFlow raise an exception; not with respect to the result, but regarding the implementation. Nevertheless, I agree that tf.Assert is exactly the right option here. (Scarper)

As you wrote to me by email, this might be related to this TF issue about catching exceptions within graph execution, and to this related SO question. However, I'm not sure this is really relevant for you: that TF issue and SO question were about how to dynamically catch an exception, i.e. basically implementing try: ... except: ... in the TF graph.

Other TF functions that introduce control flow are:

  • tf.while_loop
  • tf.cond

tf.cond answers your question of how to conditionally execute code: it executes one of two branch functions depending on a condition, i.e. a scalar bool tensor. But maybe that is not really your question, and it is rather how to formulate the condition?
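A minimal sketch of tf.cond in TF 2, assuming a hypothetical safe_normalize helper: the scalar bool tensor total > 0.0 selects which branch function runs, and both branches must return tensors with matching structure and dtype.

```python
import tensorflow as tf

@tf.function
def safe_normalize(x):
    total = tf.reduce_sum(x)
    # The predicate must be a scalar bool tensor; only the selected
    # branch's result is returned.
    return tf.cond(
        total > 0.0,
        lambda: x / total,         # normal case
        lambda: tf.zeros_like(x),  # fallback when the sum is not positive
    )


print(safe_normalize(tf.constant([1.0, 3.0])).numpy())
print(safe_normalize(tf.constant([0.0, 0.0])).numpy())
```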

There is tf.check_numerics (tf.debugging.check_numerics in TF 2), which checks a tensor for Inf/NaN values and raises an exception if it finds any.
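A short sketch, assuming TF 2 and a hypothetical checked helper: tf.debugging.check_numerics passes the tensor through unchanged and raises tf.errors.InvalidArgumentError when it sees NaN or Inf, so it rejects infinities as well as NaNs.

```python
import tensorflow as tf

@tf.function
def checked(x):
    # Returns x unchanged, or raises InvalidArgumentError if x has NaN/Inf.
    return tf.debugging.check_numerics(x, message="input_data")


for values in ([1.0, 2.0], [1.0, float("inf")], [float("nan"), 0.0]):
    try:
        print(values, "->", checked(tf.constant(values)).numpy())
    except tf.errors.InvalidArgumentError:
        print(values, "-> rejected")
```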

If you want that check as a boolean condition instead, you could use this code:

is_finite = tf.reduce_all(tf.math.is_finite(x))
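To see what that condition looks like, here is a sketch wrapping it in a hypothetical all_finite function (TF 2 names assumed): tf.math.is_finite works element-wise, and tf.reduce_all collapses the result into the single scalar bool that tf.cond and tf.Assert expect.

```python
import tensorflow as tf

@tf.function
def all_finite(x):
    # Element-wise finiteness test, reduced to one scalar bool tensor.
    return tf.reduce_all(tf.math.is_finite(x))


print(all_finite(tf.constant([1.0, 2.0])).numpy())           # True
print(all_finite(tf.constant([1.0, float("nan")])).numpy())  # False
```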

If you want to raise an exception when some condition is false, you can do:

check_op = tf.Assert(is_finite, ["Tensor had inf or nan values:", x])

In graph mode nothing else depends on check_op, so you might want to wrap the downstream computation in tf.control_dependencies([check_op]) to make sure the check actually gets executed.
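Putting the pieces together, a minimal TF 2 sketch, assuming a hypothetical checked_double function: tf.Assert raises tf.errors.InvalidArgumentError at run time when is_finite is false, and the tf.control_dependencies block makes the multiplication wait for the check.

```python
import tensorflow as tf

@tf.function
def checked_double(x):
    is_finite = tf.reduce_all(tf.math.is_finite(x))
    # Fails with InvalidArgumentError (the listed data is included in the
    # error message) whenever is_finite is False.
    check_op = tf.Assert(is_finite, ["Tensor had inf or nan values:", x])
    # Ops created in this block only run after check_op has run.
    with tf.control_dependencies([check_op]):
        return x * 2.0


print(checked_double(tf.constant([1.0, 2.0])).numpy())
try:
    checked_double(tf.constant([1.0, float("inf")]))
except tf.errors.InvalidArgumentError as e:
    print("caught in Python:", type(e).__name__)
```

The exception surfaces in the calling Python code, not inside the graph, which is consistent with the point made in the comments that it cannot be caught within the graph itself.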

Aesthetic answered 30/4, 2020 at 10:53 Comment(1)
Thanks for your help. In TensorFlow 2, tf.debugging.assert_all_finite and tf.debugging.check_numerics both seem to do the job. I was confused by the word "assert"; I thought there might be a way to explicitly raise specific exceptions. Thanks for your quick and thorough answer. (Panthea)

© 2022 - 2024 — McMap. All rights reserved.