Flattening two last dimensions of a tensor in TensorFlow

I'm trying to reshape a tensor from [A, B, C, D] into [A, B, C * D] and feed it into a dynamic_rnn. Assume that I don't know B, C, and D in advance (they're the result of a convolutional network).

I think in Theano such reshaping would look like this:

x = x.flatten(ndim=3)
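Theano's flatten(ndim=n) keeps the first n - 1 axes and collapses all remaining axes into the last one. A minimal NumPy sketch of those semantics, for comparison (theano_style_flatten is a hypothetical helper, not part of NumPy or Theano):

```python
import math

import numpy as np


def theano_style_flatten(x, ndim):
    """Keep the first ndim-1 axes and merge all remaining axes into one,
    mirroring the result shape of Theano's x.flatten(ndim=ndim)."""
    leading = x.shape[:ndim - 1]              # axes that are kept as-is
    trailing = math.prod(x.shape[ndim - 1:])  # size of the merged last axis
    return x.reshape(leading + (trailing,))


x = np.zeros((10, 20, 30, 40))
print(theano_style_flatten(x, 3).shape)  # (10, 20, 1200)
```

Because reshape uses row-major order, each merged row holds the C * D values of one [C, D] slice in their original order.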

It seems that in TensorFlow there's no easy way to do this; so far, here's what I came up with:

x_shape = tf.shape(x)
x = tf.reshape(x, [batch_size, x_shape[1], tf.reduce_prod(x_shape[2:])])

Even when the shape of x is known during graph building (i.e. print(x.get_shape()) prints absolute values, like [10, 20, 30, 40]), after the reshaping get_shape() becomes [10, None, None]. Again, assume the initial shape isn't known in general, so I can't operate with absolute values.

And when I pass x to dynamic_rnn, it fails:

ValueError: Input size (depth of inputs) must be accessible via shape inference, but saw value None.

Why is reshape unable to handle this case? What is the right way of replicating Theano's flatten(ndim=n) in TensorFlow with tensors of rank 4 and more?

Bowlder answered 31/10, 2017 at 11:06

It is not a flaw in reshape but a limitation of tf.nn.dynamic_rnn.

Your code to flatten the last two dimensions is correct, and reshape behaves correctly too: if the last two dimensions are unknown when you define the flattening operation, then so is their product, and None is the only appropriate value that can be returned at graph-construction time.
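The rule can be illustrated without TensorFlow. In a static shape each dimension is either an int or None (unknown), and a product with an unknown factor is itself unknown. A minimal pure-Python sketch of that rule (flattened_static_shape is a hypothetical helper, not a TensorFlow function):

```python
def flattened_static_shape(dims, ndim):
    """Static shape after merging every axis from ndim-1 onward.

    Each entry of dims is an int, or None for a size that is unknown
    at graph-construction time (as in TensorFlow's static shapes).
    """
    leading = list(dims[:ndim - 1])
    merged = 1
    for d in dims[ndim - 1:]:
        if d is None:
            return leading + [None]  # one unknown factor -> unknown product
        merged *= d
    return leading + [merged]


print(flattened_static_shape([10, 20, 30, 40], 3))      # [10, 20, 1200]
print(flattened_static_shape([None, None, 30, 40], 3))  # [None, None, 1200]
print(flattened_static_shape([10, 20, None, 40], 3))    # [10, 20, None]
```

One plausible reading of the question's [10, None, None] is that the last two entries were computed from tf.shape(x), a run-time tensor, so the known static sizes were never used; building those entries from Python ints where they are known would keep them.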

The culprit is tf.nn.dynamic_rnn, which expects a fully defined feature shape during construction, i.e. all dimensions apart from the first (batch size) and the second (time steps) must be known. This is perhaps unfortunate, but the current implementation does not seem to allow RNNs with a variable number of features, à la FCN.
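Since dynamic_rnn only needs the feature dimension to be static, one workaround is to build the target shape from the static shape wherever it is known and fall back to tf.shape only for the dimensions that are genuinely unknown. A minimal sketch, assuming the TensorFlow 2.x API (flatten_keep_static and to_rnn_input are hypothetical names, not TensorFlow functions):

```python
import math

import tensorflow as tf


def flatten_keep_static(x, ndim):
    """Merge every axis from ndim-1 onward into one, like Theano's flatten(ndim),
    keeping static sizes wherever they are known."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    # Leading axes: Python ints when known, run-time scalars otherwise.
    dims = [s if s is not None else dynamic[i]
            for i, s in enumerate(static[:ndim - 1])]
    tail = static[ndim - 1:]
    if all(s is not None for s in tail):
        dims.append(math.prod(tail))  # static feature size, e.g. C * D
    else:
        dims.append(tf.reduce_prod(dynamic[ndim - 1:]))  # only known at run time
    return tf.reshape(x, dims)


# Batch (A) and time (B) are unknown while tracing; C = 30 and D = 40 are known.
@tf.function(input_signature=[tf.TensorSpec([None, None, 30, 40], tf.float32)])
def to_rnn_input(x):
    return flatten_keep_static(x, 3)


print(to_rnn_input.get_concrete_function().structured_outputs.shape)
```

The traced output should have unknown batch and time dimensions but a static last dimension of 1200, which is what an RNN layer needs to infer its input size. This does not remove the limitation described above: if C or D is itself unknown, the feature size is still None.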

Littlest answered 21/3, 2018 at 13:01
Should we change the title of the question, then? This is no longer about how to reshape but about feeding a tensor with a partially known shape to dynamic_rnn, right? - Dortch

I tried some simple code based on your requirements. Since you are trying to reshape a CNN output, X has the same shape as a CNN output in TensorFlow: [batch, height, width, channels].

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.contrib)

HEIGHT = 100
WIDTH = 200
N_CHANNELS = 3

N_HIDDEN = 64

X = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, N_CHANNELS], name='input')  # output of CNN

# Static size of each dimension: shape[0] = BATCH_SIZE (None), shape[1] = HEIGHT,
# shape[2] = WIDTH, shape[3] = N_CHANNELS
shape = X.get_shape().as_list()

rnn_input = tf.reshape(X, [-1, shape[1], shape[2] * shape[3]])
print(rnn_input.shape)  # prints (?, 100, 600)

# Input for tf.nn.dynamic_rnn should have the shape [BATCH_SIZE, N_TIMESTEPS, INPUT_SIZE].

# Therefore, with the reshape above, N_TIMESTEPS = 100 and INPUT_SIZE = 600.

# Create the RNN here
lstm_cell = tf.contrib.rnn.BasicLSTMCell(N_HIDDEN, forget_bias=1.0)
outputs, _ = tf.nn.dynamic_rnn(lstm_cell, rnn_input, dtype=tf.float32)

Hope this helps.
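To see what this reshape does to the image: time step t of the RNN input is simply image row t, with its WIDTH * N_CHANNELS values laid out contiguously (TensorFlow's reshape uses the same row-major order as NumPy). A small NumPy check with made-up sizes (batch 2, height 4, width 5, 3 channels, chosen only for illustration):

```python
import numpy as np

batch, height, width, channels = 2, 4, 5, 3
images = np.arange(batch * height * width * channels).reshape(
    batch, height, width, channels)

# Same reshape as above: [batch, HEIGHT, WIDTH, CHANNELS] -> [batch, HEIGHT, WIDTH * CHANNELS]
sequence = images.reshape(-1, height, width * channels)
print(sequence.shape)  # (2, 4, 15)

# Each time step of the sequence is one image row, flattened.
for t in range(height):
    assert np.array_equal(sequence[:, t, :], images[:, t].reshape(batch, -1))
```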

Siler answered 31/10, 2017 at 12:06
Correct me if I misunderstood something, but I explicitly said "assume the initial shape isn't known so I can't operate with absolute values", which is what you do in shape[2] * shape[3]. - Bowlder
Correct me if I'm wrong, but A is the batch size, which depends on the number of input samples when training the model, so A is only known at training time. The output shape of a CNN, however, depends on the input shape, the filter shapes, and the number of output channels; it can be calculated by hand or checked with the get_shape() method. So, unlike A, B, C, and D should not be None. - Siler
It would probably be possible, but it would require computing and hardcoding a lot of values just to perform a rather simple transformation of the data. That is what I was trying to avoid (and something Lasagne/Theano lets you do in one line of code). The main part of the question is still "Why is reshape unable to handle this case?" - Bowlder

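I found a solution using .get_shape(), assuming x is a 4-D tensor.

This only works with the Keras Reshape layer. Since you are changing the architecture of the model anyway, it should work there. Note that it multiplies the static sizes of the last two dimensions, so those must be known when the layer is created.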

x = tf.keras.layers.Reshape((x.get_shape()[1], x.get_shape()[2] * x.get_shape()[3]))(x)  # target shape excludes the batch dimension

Hope this works!
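A note on the Reshape layer: its target shape leaves out the batch dimension, and at most one entry may be -1, in which case the layer infers that size from the others. A minimal sketch, assuming TensorFlow 2.x with tf.keras (the tensor sizes are made up for illustration):

```python
import tensorflow as tf

x = tf.zeros([8, 10, 30, 40])  # [A, B, C, D]

# Keep B and merge C and D; -1 lets the layer compute C * D itself.
flatten_last_two = tf.keras.layers.Reshape((x.shape[1], -1))
y = flatten_last_two(x)
print(y.shape)  # (8, 10, 1200)
```

This still needs B as a known integer; it only spares you from writing C * D explicitly.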

Affluent answered 21/4, 2021 at 12:54

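If you subclass tf.keras.models.Model or tf.keras.layers.Layer, the build method provides a nice way to do this: Keras calls build with the input shape before the layer is first used, so layers that depend on the concrete dimensions can be created there.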
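Stripped down to the question's [A, B, C, D] case, the build approach might look like the following minimal sketch, assuming TensorFlow 2.x with tf.keras (FlattenLastTwo is a hypothetical layer name). build() sees the static input shape, so C * D becomes a Python int, while A and B are read with tf.shape at run time and may stay unknown:

```python
import tensorflow as tf


class FlattenLastTwo(tf.keras.layers.Layer):
    """Hypothetical layer: [A, B, C, D] -> [A, B, C * D]."""

    def build(self, input_shape):
        # build() receives the static input shape, so C and D are plain ints
        # here whenever the upstream layers fix them (int(None) raises TypeError).
        self.features = int(input_shape[-2]) * int(input_shape[-1])
        super().build(input_shape)

    def call(self, x):
        shape = tf.shape(x)  # A and B may only be known at run time
        return tf.reshape(x, [shape[0], shape[1], self.features])


layer = FlattenLastTwo()
print(layer(tf.zeros([2, 3, 4, 5])).shape)  # (2, 3, 20)
```

If C or D is None when the layer is built, the int() conversion fails early instead of producing a None feature size.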

Here's an example:

import tensorflow as tf
from tensorflow.keras.layers import Conv1D, Attention, Layer, Reshape


class VisualAttention(Layer):
    def __init__(self, channels_out, key_is_value=True):
        super(VisualAttention, self).__init__()

        self.channels_out = channels_out
        self.key_is_value = key_is_value

        self.flatten_images = None  # see build method
        self.unflatten_images = None  # see build method

        self.query_conv = Conv1D(filters=channels_out, kernel_size=1, padding='same')
        self.value_conv = Conv1D(filters=channels_out, kernel_size=4, padding='same')
        self.key_conv = self.value_conv if key_is_value else Conv1D(filters=channels_out, kernel_size=4, padding='same')

        # Attention is non-causal by default; the `causal` argument is deprecated in newer Keras versions
        self.attention_layer = Attention(use_scale=False, dropout=0.)

    def build(self, input_shape):
        # input_shape is static here: (batch, height, width, channels)
        b, h, w, c = input_shape
        self.flatten_images = Reshape((h*w, c))
        self.unflatten_images = Reshape((h, w, self.channels_out))
        super().build(input_shape)

    def call(self, x, training=True):
        x = self.flatten_images(x)
        q = self.query_conv(x)
        v = self.value_conv(x)

        inputs = [q, v] if self.key_is_value else [q, v, self.key_conv(x)]
        output = self.attention_layer(inputs=inputs, training=training)
        return self.unflatten_images(output)

# test
import numpy as np
x = np.arange(8*28*32*3).reshape((8, 28, 32, 3)).astype('float32')
model = VisualAttention(8)
y = model(x)
print(y.shape)

Molasses answered 7/2, 2022 at 0:26
