Is it ok to call `tape.watch(x)` when `x` is already a `tf.Variable` in TensorFlow?

Consider the following function:

import tensorflow as tf

def foo(x):
  with tf.GradientTape() as tape:
    tape.watch(x)

    y = x**2 + x + 4

  return tape.gradient(y, x)

The call to tape.watch(x) is necessary if the function is called with a tensor, say foo(tf.constant(3.14)), but not when a variable is passed in directly, such as foo(tf.Variable(3.14)).
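To see why the watch call matters for plain tensors, here is a minimal sketch, assuming TensorFlow 2.x in eager mode, of the same computation with the watch call removed. The helper name grad_without_watch is hypothetical and used only for illustration.

```python
import tensorflow as tf


def grad_without_watch(x):
    # Same computation as foo, but without tape.watch(x).
    with tf.GradientTape() as tape:
        y = x**2 + x + 4
    return tape.gradient(y, x)


# A constant is not watched automatically, so nothing links y to x: the result is None.
print(grad_without_watch(tf.constant(3.14)))
# A trainable variable is watched automatically when it is read inside the tape,
# so the gradient 2*x + 1 (about 7.28) is returned.
print(grad_without_watch(tf.Variable(3.14)))
```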

Now my question is: is the call to tape.watch(x) safe even when a tf.Variable is passed in directly? Or will something strange happen because the variable is already watched automatically and then watched manually again? What is the correct way to write general functions like this that accept both tf.Tensor and tf.Variable?

Zeeba answered 1/2, 2019 at 12:42

It should be safe. For one thing, the documentation of tf.GradientTape.watch says:

Ensures that tensor is being traced by this tape.

"Ensures" implies that it will make sure the tensor is traced if it is not already. Moreover, the documentation gives no indication that calling it twice on the same object should be a problem (although it wouldn't hurt if they made that explicit).

In any case, we can dig into the source code to check. Ultimately, calling watch on a variable comes down to the WatchVariable method of a GradientTape class in C++ (for a non-variable tensor the code path differs slightly, but the conclusion is the same):

void WatchVariable(PyObject* v) {
  tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
  if (handle == nullptr) {
    return;
  }
  tensorflow::int64 id = FastTensorId(handle.get());

  if (!PyErr_Occurred()) {
    this->Watch(id);
  }

  tensorflow::mutex_lock l(watched_variables_mu_);
  auto insert_result = watched_variables_.emplace(id, v);

  if (insert_result.second) {
    // Only increment the reference count if we aren't already watching this
    // variable.
    Py_INCREF(v);
  }
}

The second half of the method shows that the watched variable is added to watched_variables_, which is a std::set, so adding the same variable again does nothing. The result of the insertion is also checked, so that the Python reference count of the variable is only incremented the first time it is watched. The first half basically calls Watch:

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id) {
  tensor_tape_.emplace(tensor_id, -1);
}

tensor_tape_ is a map (specifically a tensorflow::gtl::FlatMap, a hash map that behaves much like std::unordered_map), so if tensor_id is already present, emplace has no effect.
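The idempotence argument can be mirrored in a few lines of plain Python. This is a toy model, not TensorFlow's API: the class ToyTape and its attribute names are hypothetical, a dict stands in for tensor_tape_, a set for watched_variables_, and a counter for the Py_INCREF calls in the C++ code quoted above.

```python
class ToyTape:
    """Hypothetical model of the bookkeeping in the C++ GradientTape shown above."""

    def __init__(self):
        self.tensor_tape = {}           # stands in for tensor_tape_ (id -> -1 placeholder)
        self.watched_variables = set()  # stands in for watched_variables_ (keyed by id)
        self.incref_calls = {}          # counts the simulated Py_INCREF calls per id

    def watch(self, tensor_id):
        # Like tensor_tape_.emplace(tensor_id, -1): inserts only if the key is absent.
        self.tensor_tape.setdefault(tensor_id, -1)

    def watch_variable(self, var_id):
        self.watch(var_id)
        if var_id not in self.watched_variables:
            # Only the first insertion "increments the reference count".
            self.watched_variables.add(var_id)
            self.incref_calls[var_id] = self.incref_calls.get(var_id, 0) + 1

    def state(self):
        return (dict(self.tensor_tape), set(self.watched_variables), dict(self.incref_calls))


tape = ToyTape()
tape.watch_variable(7)
once = tape.state()
tape.watch_variable(7)  # watching the same variable a second time
twice = tape.state()
print(once == twice)  # True: the second call changes nothing
```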

So, even though it is not explicitly stated, everything suggests there should be no issues with it.
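As an empirical complement to the source-code reading, a minimal sketch assuming TensorFlow 2.x in eager mode: the question's foo gives the same gradient for a constant and for a variable, and a variable watched twice is still registered only once. foo_watch_twice is a hypothetical helper added only for this check; tape.watched_variables() is an existing GradientTape method.

```python
import tensorflow as tf


def foo(x):
    with tf.GradientTape() as tape:
        # Required for plain tensors; for trainable variables it repeats
        # what the tape already does automatically, which is harmless.
        tape.watch(x)
        y = x**2 + x + 4
    return tape.gradient(y, x)


def foo_watch_twice(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        tape.watch(x)  # deliberately repeated
        y = x**2 + x + 4
        # Watched variables are kept in a set keyed by id, so a variable
        # shows up once no matter how often it is watched or read.
        n_watched = len(tape.watched_variables())
    return tape.gradient(y, x), n_watched


print(foo(tf.constant(3.14)).numpy())  # about 7.28, i.e. 2*x + 1
print(foo(tf.Variable(3.14)).numpy())  # same value
print(foo_watch_twice(tf.Variable(3.14)))
```

In other words, the general function from the question can simply keep the unconditional tape.watch(x) call.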

Brokerage answered 1/2, 2019 at 14:11
Thank you. I thought that all variables in TF are trainable by default. So, if we leave a variable trainable, what is the point of watching it? – Moray
@Moray When you create a tf.GradientTape, you can pass watch_accessed_variables to decide whether or not accessed variables should be watched automatically. If you are only going to update a subset of the variables used in the model, you can set it to False and then call watch manually on just the variables you are interested in. You can also use watch on non-variable tensors in order to compute other gradients. – Brokerage

It is designed to be used with variables. From the docs:

By default GradientTape will automatically watch any trainable variables that are accessed inside the context. If you want fine grained control over which variables are watched you can disable automatic tracking by passing watch_accessed_variables=False to the tape constructor:

with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(variable_a)
  y = variable_a ** 2  # Gradients will be available for `variable_a`.
  z = variable_b ** 3  # No gradients will be available since `variable_b` is
                       # not being watched.
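A self-contained version of the documentation snippet, assuming TensorFlow 2.x in eager mode; the initial values 3.0 and 2.0 are arbitrary choices for illustration. It shows what "no gradients will be available" means in practice: tape.gradient returns None for the unwatched variable.

```python
import tensorflow as tf

variable_a = tf.Variable(3.0)
variable_b = tf.Variable(2.0)

with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(variable_a)  # only variable_a is watched
    y = variable_a ** 2
    z = variable_b ** 3     # variable_b is read but not watched, so this op is not recorded

# The targets [y, z] are summed; each source gets its own gradient.
grad_a, grad_b = tape.gradient([y, z], [variable_a, variable_b])
print(grad_a.numpy())  # 6.0, i.e. 2 * variable_a
print(grad_b)          # None
```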
Peba answered 7/4, 2020 at 15:19

© 2022 - 2024 — McMap. All rights reserved.