conv2d_transpose() simply transposes the weights, flips them by 180 degrees, and then applies the standard conv2d(). Here, "transposes" means swapping the order of the last two "columns" (dimensions) of the weights tensor, i.e. the input-channel and output-channel axes. Please check the example below.
Here is an example that uses convolutions with stride=1 and padding='SAME'. It is a simple case, but the same reasoning can be extended to the other cases (with stride > 1, zeros also have to be inserted between the input activations before applying conv2d()).
Say we have:
- Input: MNIST image of 28x28x1, shape = [28,28,1]
- Convolutional layer: 32 filters of 7x7, weights shape = [7, 7, 1, 32], name = W_conv1
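Why does the output keep the 28x28 size? With padding='SAME', TensorFlow sizes the output as ceil(in / stride) and pads the input with zeros to match, splitting the padding as evenly as possible (any extra pixel goes at the bottom/right). The following is a minimal plain-Python sketch of that rule; the helper names are hypothetical, not TensorFlow functions.

```python
import math


def same_output_size(in_size, stride):
    """Spatial output size with padding='SAME': ceil(in_size / stride)."""
    return math.ceil(in_size / stride)


def same_padding(in_size, kernel, stride):
    """Return (total, leading, trailing) zero padding used by padding='SAME'.

    The leading (top/left) side gets the smaller half when the total is odd.
    """
    out = same_output_size(in_size, stride)
    total = max((out - 1) * stride + kernel - in_size, 0)
    return total, total // 2, total - total // 2


# 28x28 MNIST input, 7x7 filters, stride 1 (the example above)
print(same_output_size(28, 1))   # 28
print(same_padding(28, 7, 1))    # (6, 3, 3)
```

So each 7x7 filter sees 3 zero pixels on every side, and the 32 activation maps stay 28x28.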
If we convolve the input with these filters, the activations of the layer will have shape [1,28,28,32] (a batch of one image):
activations = sess.run(h_conv1, feed_dict={x: np.reshape(image, [1, 784])})
Where:
x_image = tf.reshape(x, [-1, 28, 28, 1])  # conv2d() needs a 4-D [batch, height, width, channels] input
W_conv1 = weight_variable([7, 7, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = conv2d(x_image, W_conv1, strides=[1, 1, 1, 1], padding='SAME') + b_conv1
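To make explicit what the forward conv2d() computes: TensorFlow's conv2d() is a cross-correlation, so the kernel is not flipped. The sketch below is a plain-Python version for a single image (batch axis dropped), assuming stride 1, padding='SAME' and odd kernel sizes; the function name is hypothetical.

```python
def conv2d_same(x, w):
    """Stride-1, padding='SAME' conv2d in the TensorFlow sense (cross-correlation).

    x: nested list [H][W][C_in] (one image, batch axis dropped)
    w: nested list [KH][KW][C_in][C_out], same layout as W_conv1
    Returns a nested list [H][W][C_out]. Assumes odd KH and KW, so the
    'SAME' padding is (K - 1) // 2 on every side.
    """
    H, W, C_in = len(x), len(x[0]), len(x[0][0])
    KH, KW, C_out = len(w), len(w[0]), len(w[0][0][0])
    ph, pw = (KH - 1) // 2, (KW - 1) // 2
    y = [[[0.0] * C_out for _ in range(W)] for _ in range(H)]
    for i in range(H):
        for j in range(W):
            for di in range(KH):
                for dj in range(KW):
                    a, b = i + di - ph, j + dj - pw
                    if not (0 <= a < H and 0 <= b < W):
                        continue  # zero padding contributes nothing
                    for c in range(C_in):
                        for o in range(C_out):
                            y[i][j][o] += x[a][b][c] * w[di][dj][c][o]
    return y


# 3x3 single-channel image and an all-ones 3x3 filter
image = [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]], [[7.0], [8.0], [9.0]]]
ones = [[[[1.0]] for _ in range(3)] for _ in range(3)]
out = conv2d_same(image, ones)
print(out[1][1][0], out[0][0][0])  # 45.0 12.0
```

The corner value is 12.0 because only the 2x2 in-bounds part of the window (1+2+4+5) overlaps the image; the rest falls on the zero padding.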
To obtain the "deconvolution" or "transposed convolution" we can apply conv2d_transpose() to the convolution activations like this:
deconv = conv2d_transpose(activations, W_conv1, output_shape=[1, 28, 28, 1], strides=[1, 1, 1, 1], padding='SAME')
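For stride 1 and padding='SAME', the transposed convolution can be written as a "scatter": every activation is added back onto the input window it was computed from, weighted by the same filter. A minimal plain-Python sketch under those assumptions (single image, odd kernel sizes, hypothetical function name):

```python
def conv2d_transpose_same(y, w):
    """Stride-1, padding='SAME' transposed convolution written as a scatter.

    y: nested list [H][W][C_out], e.g. the conv2d activations
    w: nested list [KH][KW][C_in][C_out], the filter of the forward conv2d
    Returns a nested list [H][W][C_in]. Every activation y[i][j][o] is added
    back onto the KH x KW input window it was computed from, weighted by w.
    Assumes odd KH and KW.
    """
    H, W, C_out = len(y), len(y[0]), len(y[0][0])
    KH, KW, C_in = len(w), len(w[0]), len(w[0][0])
    ph, pw = (KH - 1) // 2, (KW - 1) // 2
    x = [[[0.0] * C_in for _ in range(W)] for _ in range(H)]
    for i in range(H):
        for j in range(W):
            for di in range(KH):
                for dj in range(KW):
                    a, b = i + di - ph, j + dj - pw
                    if not (0 <= a < H and 0 <= b < W):
                        continue  # would land in the padding: dropped
                    for c in range(C_in):
                        for o in range(C_out):
                            x[a][b][c] += y[i][j][o] * w[di][dj][c][o]
    return x


# A single activation at the centre "paints" the (unflipped) kernel
kernel = [[[[1.0]], [[2.0]], [[3.0]]],
          [[[4.0]], [[5.0]], [[6.0]]],
          [[[7.0]], [[8.0]], [[9.0]]]]
impulse = [[[0.0], [0.0], [0.0]], [[0.0], [1.0], [0.0]], [[0.0], [0.0], [0.0]]]
out = conv2d_transpose_same(impulse, kernel)
print([[v[0] for v in row] for row in out])
# [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
```

By contrast, a plain (cross-correlating) conv2d() of the same impulse with this kernel returns the kernel rotated by 180 degrees, which is why the weights have to be flipped when conv2d() is used to reproduce conv2d_transpose().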
Or, to use conv2d() instead, we need to transpose and flip the weights:
transposed_weights = tf.transpose(W_conv1, perm=[0, 1, 3, 2])
Here we change the order of the "columns" (dimensions) from [0,1,2,3] to [0,1,3,2], i.e. we swap the input- and output-channel axes. So from [7, 7, 1, 32] we obtain a tensor with shape [7, 7, 32, 1]. Then we flip the weights:
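n_filters = 32
transposed_and_flipped_weights = np.zeros([7, 7, n_filters, 1], dtype=np.float32)
for i in range(n_filters):
    # Flip each 7x7 kernel by 180 degrees (reverse both spatial axes)
    transposed_and_flipped_weights[:, :, i, 0] = sess.run(tf.reverse(transposed_weights[:, :, i, 0], axis=[0, 1]))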
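The two weight manipulations (perm=[0, 1, 3, 2] and the 180-degree flip) can be mirrored on nested Python lists to check the shapes and index mapping. This is a sketch with hypothetical helper names; since the flip only touches the two spatial axes, the per-filter loop should be equivalent to a single tf.reverse(transposed_weights, axis=[0, 1]) call.

```python
def shape(t):
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0]
    return dims


def swap_channel_axes(w):
    """Analogue of tf.transpose(w, perm=[0, 1, 3, 2]): [KH][KW][A][B] -> [KH][KW][B][A]."""
    return [[[[w[i][j][a][b] for a in range(len(w[i][j]))]
              for b in range(len(w[i][j][0]))]
             for j in range(len(w[i]))]
            for i in range(len(w))]


def flip_180(w):
    """Reverse the first two (spatial) axes, like tf.reverse(w, axis=[0, 1]).

    Only the outer two list levels are reversed, so the channel entries of a
    [KH][KW][A][B] tensor stay where they are.
    """
    return [row[::-1] for row in w[::-1]]


# Stand-in for W_conv1: 7x7 kernels, 1 input channel, 32 output channels,
# filled so that each entry encodes its own position.
W = [[[[100 * i + 10 * j + o for o in range(32)]] for j in range(7)] for i in range(7)]
Wtf = flip_180(swap_channel_axes(W))
print(shape(W), shape(Wtf))  # [7, 7, 1, 32] [7, 7, 32, 1]
print(Wtf[0][0][5][0])       # 665, i.e. W[6][6][0][5]
```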
Then we can compute the convolution with conv2d() as:
strides = [1,1,1,1]
deconv = conv2d(activations,transposed_and_flipped_weights,strides=strides,padding='SAME')
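To check the equivalence without TensorFlow, here is a self-contained pure-Python sketch that implements both sides for a single image (batch axis dropped) and compares them on random data. Assumptions: stride 1, padding='SAME' and odd kernel sizes (with even kernels the 'SAME' padding is asymmetric, so the two results would be offset by one pixel); all function names are hypothetical.

```python
import random


def conv2d_same(x, w):
    """TF-style stride-1 'SAME' conv2d (cross-correlation), odd kernels only.
    x: [H][W][C_in], w: [KH][KW][C_in][C_out] -> [H][W][C_out]."""
    H, W, C_in = len(x), len(x[0]), len(x[0][0])
    KH, KW, C_out = len(w), len(w[0]), len(w[0][0][0])
    ph, pw = (KH - 1) // 2, (KW - 1) // 2
    y = [[[0.0] * C_out for _ in range(W)] for _ in range(H)]
    for i in range(H):
        for j in range(W):
            for di in range(KH):
                for dj in range(KW):
                    a, b = i + di - ph, j + dj - pw
                    if 0 <= a < H and 0 <= b < W:
                        for c in range(C_in):
                            for o in range(C_out):
                                y[i][j][o] += x[a][b][c] * w[di][dj][c][o]
    return y


def conv2d_transpose_same(y, w):
    """Stride-1 'SAME' transposed conv as a scatter (gradient of conv2d_same
    w.r.t. its input). y: [H][W][C_out], w: [KH][KW][C_in][C_out] -> [H][W][C_in]."""
    H, W, C_out = len(y), len(y[0]), len(y[0][0])
    KH, KW, C_in = len(w), len(w[0]), len(w[0][0])
    ph, pw = (KH - 1) // 2, (KW - 1) // 2
    x = [[[0.0] * C_in for _ in range(W)] for _ in range(H)]
    for i in range(H):
        for j in range(W):
            for di in range(KH):
                for dj in range(KW):
                    a, b = i + di - ph, j + dj - pw
                    if 0 <= a < H and 0 <= b < W:
                        for c in range(C_in):
                            for o in range(C_out):
                                x[a][b][c] += y[i][j][o] * w[di][dj][c][o]
    return x


def transpose_and_flip(w):
    """[KH][KW][C_in][C_out] -> [KH][KW][C_out][C_in], rotated by 180 degrees.
    Mirrors tf.reverse(tf.transpose(w, perm=[0, 1, 3, 2]), axis=[0, 1])."""
    KH, KW = len(w), len(w[0])
    C_in, C_out = len(w[0][0]), len(w[0][0][0])
    return [[[[w[KH - 1 - i][KW - 1 - j][c][o] for c in range(C_in)]
              for o in range(C_out)]
             for j in range(KW)]
            for i in range(KH)]


def rand_tensor(*dims):
    """Nested list of the given shape filled with uniform values in [-1, 1]."""
    if len(dims) == 1:
        return [random.uniform(-1, 1) for _ in range(dims[0])]
    return [rand_tensor(*dims[1:]) for _ in range(dims[0])]


random.seed(0)
H, W, KH, KW, C_in, C_out = 6, 5, 3, 3, 2, 4
weights = rand_tensor(KH, KW, C_in, C_out)   # plays the role of W_conv1
activations = rand_tensor(H, W, C_out)       # plays the role of the conv output

via_transpose = conv2d_transpose_same(activations, weights)
via_conv2d = conv2d_same(activations, transpose_and_flip(weights))

max_diff = max(abs(p - q)
               for r1, r2 in zip(via_transpose, via_conv2d)
               for v1, v2 in zip(r1, r2)
               for p, q in zip(v1, v2))
print(max_diff < 1e-9)  # True
```

Both versions sum exactly the same products, only in a different order, so the difference is floating-point rounding.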
This gives the same result as before. The very same result can also be obtained with conv2d_backprop_input():
deconv = conv2d_backprop_input([1,28,28,1],W_conv1,activations, strides=strides, padding='SAME')
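Why does conv2d_backprop_input() agree? It computes the gradient of conv2d() with respect to its input, and since conv2d() is linear in its input, that gradient is the transposed (adjoint) linear map. One way to see this is the dot-product test: for any x and g, the inner product of conv2d(x) with g equals the inner product of x with conv2d_transpose(g). The pure-Python sketch below checks this on random data, under the same assumptions as above (stride 1, 'SAME', odd kernels, hypothetical helper names).

```python
import random


def conv2d_same(x, w):
    """TF-style stride-1 'SAME' conv2d (cross-correlation), odd kernels only.
    x: [H][W][C_in], w: [KH][KW][C_in][C_out] -> [H][W][C_out]."""
    H, W, C_in = len(x), len(x[0]), len(x[0][0])
    KH, KW, C_out = len(w), len(w[0]), len(w[0][0][0])
    ph, pw = (KH - 1) // 2, (KW - 1) // 2
    y = [[[0.0] * C_out for _ in range(W)] for _ in range(H)]
    for i in range(H):
        for j in range(W):
            for di in range(KH):
                for dj in range(KW):
                    a, b = i + di - ph, j + dj - pw
                    if 0 <= a < H and 0 <= b < W:
                        for c in range(C_in):
                            for o in range(C_out):
                                y[i][j][o] += x[a][b][c] * w[di][dj][c][o]
    return y


def conv2d_transpose_same(y, w):
    """Stride-1 'SAME' transposed conv as a scatter.
    y: [H][W][C_out], w: [KH][KW][C_in][C_out] -> [H][W][C_in]."""
    H, W, C_out = len(y), len(y[0]), len(y[0][0])
    KH, KW, C_in = len(w), len(w[0]), len(w[0][0])
    ph, pw = (KH - 1) // 2, (KW - 1) // 2
    x = [[[0.0] * C_in for _ in range(W)] for _ in range(H)]
    for i in range(H):
        for j in range(W):
            for di in range(KH):
                for dj in range(KW):
                    a, b = i + di - ph, j + dj - pw
                    if 0 <= a < H and 0 <= b < W:
                        for c in range(C_in):
                            for o in range(C_out):
                                x[a][b][c] += y[i][j][o] * w[di][dj][c][o]
    return x


def rand_tensor(*dims):
    """Nested list of the given shape filled with uniform values in [-1, 1]."""
    if len(dims) == 1:
        return [random.uniform(-1, 1) for _ in range(dims[0])]
    return [rand_tensor(*dims[1:]) for _ in range(dims[0])]


def inner(s, t):
    """Sum of elementwise products of two nested lists with the same shape."""
    if isinstance(s, list):
        return sum(inner(a, b) for a, b in zip(s, t))
    return s * t


random.seed(1)
weights = rand_tensor(3, 3, 1, 4)   # like W_conv1, but 3x3 with 4 filters
x = rand_tensor(5, 5, 1)            # an input image
g = rand_tensor(5, 5, 4)            # an upstream gradient (out_backprop)

lhs = inner(conv2d_same(x, weights), g)             # <conv2d(x), g>
rhs = inner(x, conv2d_transpose_same(g, weights))   # <x, conv2d_transpose(g)>
print(abs(lhs - rhs) < 1e-9)  # True
```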
The results are shown here:
[Figure: test of conv2d(), conv2d_transpose() and conv2d_backprop_input()]
We can see that the three results are identical. For a closer look, please check my code at:
https://github.com/simo23/conv2d_transpose
Here I replicate the output of the conv2d_transpose() function using the standard conv2d().