How to use the merge and switch functions of TensorFlow?
The merge and switch functions may not be meant for general users, so I searched the source code.

The docstring of merge says:

Returns the value of an available element of inputs.

What does "available" mean here? Is it the value returned by switch? This is a demo:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

x_0, x_1 = control_flow_ops.switch(tf.constant(2), False)
x_2, x_3 = control_flow_ops.switch(tf.constant(7), True)
y = control_flow_ops.merge([x_0, x_1, x_2, x_3])
with tf.Session() as sess:
    print(sess.run(y))
Detrude answered 4/11, 2017 at 6:5 Comment(0)

switch

Let's start by examining the control_flow_ops.switch function:

x_0, x_1 = control_flow_ops.switch(tf.constant(2), False)
x_2, x_3 = control_flow_ops.switch(tf.constant(7), True)
with tf.Session() as sess:
  print(sess.run(x_0))    # prints 2
  print(sess.run(x_3))    # prints 7

control_flow_ops.switch returns a tuple of two tensors (output_false, output_true), but only one of them will have a value, depending on the predicate argument. In the example above, these are x_0 = 2 from the first switch and x_3 = 7 from the second. An attempt to evaluate x_1 or x_2 fails with a "Retval does not have value" error:

  sess.run(x_1)  # FAILS!
  sess.run(x_2)  # FAILS!

In other words, x_0 and x_3 are available, while x_1 and x_2 aren't.
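
To make "available" concrete, here is a minimal pure-Python sketch of the behavior described above. It is a toy model, not TensorFlow's implementation: the names DEAD and toy_switch are hypothetical, and a sentinel object stands in for an output that never receives a value.

```python
# Toy model of switch semantics; DEAD and toy_switch are hypothetical names,
# not part of TensorFlow.
DEAD = object()  # stands in for an output that never receives a value


def toy_switch(data, pred):
    """Return (output_false, output_true); only one of them carries data."""
    return (DEAD, data) if pred else (data, DEAD)


x_0, x_1 = toy_switch(2, False)
x_2, x_3 = toy_switch(7, True)

print(x_0 is not DEAD, x_1 is not DEAD)  # True False
print(x_2 is not DEAD, x_3 is not DEAD)  # False True
```

In this model, "evaluating an unavailable tensor" corresponds to reading a DEAD value, which is the situation TensorFlow reports as an error.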

merge

control_flow_ops.merge performs the inverse operation: given a list of tensors, it selects the available one. More precisely, it returns a named tuple (output, value_index) holding the tensor that has a value and its index in the input list. According to the current documentation, the input should contain exactly one available tensor, which means that your demo is, strictly speaking, unsupported and leads to undefined behavior. Here's an example:

merge = control_flow_ops.merge  # short alias used below
with tf.Session() as sess:
  print(sess.run(merge([x_0, x_1])))       # Merge(output=2, value_index=0)
  print(sess.run(merge([x_1, x_0])))       # Merge(output=2, value_index=1)
  print(sess.run(merge([x_2, x_3])))       # Merge(output=7, value_index=1)
  print(sess.run(merge([x_3, x_2])))       # Merge(output=7, value_index=0)
  print(sess.run(merge([x_0, x_1, x_2])))  # Merge(output=2, value_index=0)
  print(sess.run(merge([x_1, x_2, x_3])))  # Merge(output=7, value_index=2)
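
The selection rule can also be imitated without TensorFlow. The following is only a rough sketch: DEAD, Merge and toy_merge are hypothetical names, and returning the first live input is an assumption of this model for the case the documentation allows (exactly one available input).

```python
from collections import namedtuple

# Toy model of merge semantics; DEAD, Merge and toy_merge are hypothetical
# names that mimic the printed results above, not TensorFlow code.
DEAD = object()  # marks an input that has no value
Merge = namedtuple("Merge", ["output", "value_index"])


def toy_merge(inputs):
    """Return the first live input together with its position."""
    for index, value in enumerate(inputs):
        if value is not DEAD:
            return Merge(output=value, value_index=index)
    raise ValueError("no input has a value")


print(toy_merge([2, DEAD]))        # Merge(output=2, value_index=0)
print(toy_merge([DEAD, DEAD, 7]))  # Merge(output=7, value_index=2)
```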

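Both of these functions can be handy for controlling computation flow. For example, the gradients of the control-flow ops are built from these same primitives: the gradient of switch is computed with a merge op, and the gradient of merge with a switch op (see the TensorFlow source code).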
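
One way to see how the two ops steer computation is to assemble a conditional from them: route the input through a switch, apply a different function on each branch, and merge the results. The sketch below does this in plain Python. DEAD, toy_switch, toy_merge, apply_if_live and toy_cond are all hypothetical names, and the model only imitates the dataflow idea, not TensorFlow's actual graph execution.

```python
DEAD = object()  # marks a value that was never produced


def toy_switch(data, pred):
    """Return (output_false, output_true); only one side carries data."""
    return (DEAD, data) if pred else (data, DEAD)


def toy_merge(inputs):
    """Return the live input (the first live one in this sketch)."""
    return next(value for value in inputs if value is not DEAD)


def apply_if_live(fn, value):
    """An op fed a dead input yields a dead output instead of running."""
    return DEAD if value is DEAD else fn(value)


def toy_cond(pred, true_fn, false_fn, data):
    """Run exactly one branch on data, chosen by pred."""
    on_false, on_true = toy_switch(data, pred)
    return toy_merge([apply_if_live(false_fn, on_false),
                      apply_if_live(true_fn, on_true)])


print(toy_cond(True, lambda v: v + 1, lambda v: v * 10, 7))   # 8
print(toy_cond(False, lambda v: v + 1, lambda v: v * 10, 7))  # 70
```

Only the branch that receives a live value does any work, which is the point of pairing the two ops.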

Rajab answered 4/11, 2017 at 12:23 Comment(2)
If more than one tensor is available, the doc says it will raise an error, not that it is unsupported. And do you think these ops need backpropagation? I think they are just control flow, like a C++ case block. - Detrude
I wrote "your demo is unsupported", which is true. In 1.3 it doesn't raise an error; it probably will in 1.4. Every node participates in backprop, and switch is no exception. - Rajab

Maybe you can try this demo:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

x_0, x_1 = control_flow_ops.switch(tf.constant(2), False)
x_2, x_3 = control_flow_ops.switch(tf.constant(7), True)
with tf.Session() as sess:
  print("anchor, output:{}".format(sess.run(x_0)))    # prints anchor, output:2
  print("anchor, output:{}".format(sess.run(x_3)))    # prints anchor, output:7

# x_0 is available and x_2 is not, so merge picks x_0 at index 0
merge_0 = control_flow_ops.merge([x_0, x_2])
with tf.Session() as sess:
  print("anchor, output:{}".format(sess.run(merge_0)))    # prints anchor, output:Merge(output=2, value_index=0)
Excreta answered 18/8, 2020 at 7:30 Comment(0)
