How do I swap tensor's axes in TensorFlow?

I have a tensor of shape (30, 116, 10), and I want to swap the first two dimensions so that I end up with a tensor of shape (116, 30, 10).

I saw that NumPy has such a function (np.swapaxes), and I searched for something similar in TensorFlow but found nothing.

Do you have any idea?

Mystic asked 5/7, 2016 at 20:31

tf.transpose provides the same functionality as np.swapaxes in a more general form: instead of naming the two axes to swap, you pass the full permutation of axes. In your case, tf.transpose(orig_tensor, [1, 0, 2]) is equivalent to np.swapaxes(orig_np_array, 0, 1).
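A minimal sketch, assuming TensorFlow 2.x in eager mode and NumPy, that checks this equivalence on a tensor with the question's shape. The helper swap_axes is hypothetical (it is not part of TensorFlow); it builds the permutation that np.swapaxes would apply, so it only works when the tensor's rank is known statically.

```python
import numpy as np
import tensorflow as tf

# A random array/tensor pair with the shape from the question: (30, 116, 10).
orig_np_array = np.random.rand(30, 116, 10).astype(np.float32)
orig_tensor = tf.constant(orig_np_array)

# Swap axes 0 and 1 by listing the full permutation of axes.
swapped = tf.transpose(orig_tensor, [1, 0, 2])
print(swapped.shape)  # (116, 30, 10)

# Same values as NumPy's pairwise swap.
np.testing.assert_allclose(swapped.numpy(), np.swapaxes(orig_np_array, 0, 1))


def swap_axes(x, axis1, axis2):
    """Hypothetical helper mimicking np.swapaxes; needs a static (known) rank."""
    rank = x.shape.rank
    perm = list(range(rank))
    # Normalize negative indices, then exchange the two entries of the permutation.
    a, b = axis1 % rank, axis2 % rank
    perm[a], perm[b] = perm[b], perm[a]
    return tf.transpose(x, perm)


print(swap_axes(orig_tensor, 0, 1).shape)  # (116, 30, 10)
```

Because tf.transpose takes a whole permutation, a pairwise swap is just the identity permutation with two entries exchanged.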

Yasui answered 5/7, 2016 at 20:55
What if I don't know the number of dimensions of my input tensor, but I'm sure I want to swap the last two axes? For example, an input of shape (2, 3, 4, 5) should end up as (2, 3, 5, 4), and the same code should turn an input of shape (3, 4, 5, 6, 7) into (3, 4, 5, 7, 6). - Ba
@KonstantinosBairaktaris see my answer. - Tribune

You can use tf.einsum to swap axes even when the number of input dimensions is unknown: the ellipsis (...) in the equation stands for all the axes that are not named explicitly. For example:

  • tf.einsum("ij...->ji...", input) will swap the first two dimensions of input;
  • tf.einsum("...ij->...ji", input) will swap the last two dimensions;
  • tf.einsum("aij...->aji...", input) will swap the second and third dimensions;
  • tf.einsum("ijk...->kij...", input) will permute the first three dimensions (moving the third one to the front);

and so on.
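A minimal sketch, assuming TensorFlow 2.x in eager mode, that applies the same einsum strings to inputs of different rank and cross-checks the "swap the last two axes" case against an explicit tf.transpose permutation (built here from the static rank purely for comparison):

```python
import tensorflow as tf

# Two inputs whose ranks differ; the same einsum strings handle both.
x4 = tf.random.normal([2, 3, 4, 5])
x5 = tf.random.normal([3, 4, 5, 6, 7])

for x in (x4, x5):
    # "..." absorbs however many leading axes there are.
    last_two = tf.einsum("...ij->...ji", x)
    # Cross-check against an explicit permutation of the same axes.
    rank = x.shape.rank
    perm = list(range(rank - 2)) + [rank - 1, rank - 2]
    tf.debugging.assert_near(last_two, tf.transpose(x, perm))
    print(x.shape, "->", last_two.shape)

# First two axes swapped, remaining axes untouched.
first_two = tf.einsum("ij...->ji...", x5)
print(first_two.shape)  # (4, 3, 5, 6, 7)
```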

Spitler answered 20/4, 2020 at 17:50
I've heard tf.einsum is really slow. - Tribune

You can transpose just the last two axes with tf.linalg.matrix_transpose. More generally, you can swap any number of trailing axes by computing the leading axis indices dynamically and using relative (negative) indices for the axes you want to transpose:

```python
import tensorflow as tf

x = tf.ones([5, 3, 7, 11])
trailing_axes = [-1, -2]                              # new order of the trailing axes, as negative indices

leading = tf.range(tf.rank(x) - len(trailing_axes))   # [0, 1]
trailing = trailing_axes + tf.rank(x)                 # [3, 2]
new_order = tf.concat([leading, trailing], axis=0)    # [0, 1, 3, 2]
res = tf.transpose(x, new_order)
res.shape                                             # [5, 3, 11, 7]
```
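A minimal sketch, assuming TensorFlow 2.x, showing that both approaches work when even the rank is unknown at trace time: the tf.function input signature uses shape=None, so the permutation has to be computed from tf.rank at run time. The function names swap_last_two and swap_last_two_dynamic are hypothetical, chosen for illustration.

```python
import tensorflow as tf

# shape=None means the rank itself is unknown when the function is traced.
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def swap_last_two(x):
    # Transposes the innermost two axes regardless of how many leading axes exist.
    return tf.linalg.matrix_transpose(x)


@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def swap_last_two_dynamic(x):
    # Same result using the dynamic-permutation approach.
    trailing_axes = tf.constant([-1, -2])
    leading = tf.range(tf.rank(x) - 2)
    trailing = trailing_axes + tf.rank(x)
    return tf.transpose(x, tf.concat([leading, trailing], axis=0))


for shape in ([2, 3, 4, 5], [3, 4, 5, 6, 7]):
    x = tf.random.normal(shape)
    a = swap_last_two(x)
    b = swap_last_two_dynamic(x)
    tf.debugging.assert_equal(a, b)
    print(shape, "->", a.shape)
```

tf.linalg.matrix_transpose is the shorter option when only the last two axes move; the dynamic-permutation version also covers other reorderings of the trailing axes.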
Tribune answered 7/12, 2020 at 21:31
