I am working on the arXiv paper "Is the deconvolution layer the same as a convolutional layer?". The topic is upsampling a tensor `x` from shape `(N, C, H, W)` to shape `(N, C, H*r, W*r)` by an upscale factor `r`.
In that paper, the authors introduce an equivalence between the following two methods (the point being that the second should be more computationally efficient than the first):
- SubPixel method: create the sub-pixel image, then convolve in sub-pixel space with a kernel of shape `(C, C, Hk * r, Wk * r)`
- PixelShuffle method: convolve with a kernel of shape `(C * r * r, C, Hk, Wk)`, then apply periodic shuffling
These methods are pictured in Figures 6 and 7 of the paper, which I reproduced and animated below to highlight my questions about the equivalence statement. After the figures comes a PyTorch implementation supporting those questions.
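To make the two kernel shapes concrete: the stacked kernel of the PixelShuffle method can be rearranged into a sub-pixel-space kernel with `F.pixel_shuffle` itself. This is a minimal sketch using the example sizes from the script below; the `movedim` trick mirrors my own implementation, not code from the paper.

```python
import torch
import torch.nn.functional as F

# Example sizes: C input/output channels, upscale factor r, base kernel Hk x Wk.
C, r, Hk, Wk = 1, 2, 2, 2

# PixelShuffle method kernel: (C*r*r, C, Hk, Wk)
kernel_stacked = torch.rand(C * r * r, C, Hk, Wk)

# Rearranged into the SubPixel method kernel: (C, C, Hk*r, Wk*r)
kernel_subpix = F.pixel_shuffle(kernel_stacked.movedim(0, 1), r)
print(kernel_subpix.shape)  # torch.Size([1, 1, 4, 4])
```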
Questions
Am I missing something, or is the following true?
- SubPixel output size:
  - The SubPixel method gives an output image of size `(6, 6)`, not `(8, 8)` as represented in the figure.
  - The SubPixel method can give an output image of size `(8, 8)` by adding more padding in the sub-pixel space, but then the last row and last column of the output image are full of zeros.
- Except for "purple" pixels, there is a spatial shift of values between the two output images:
  - Purple pixels are aligned
  - Blue pixels are shifted along dim W
  - Green pixels are shifted along dim H
  - Red pixels are shifted along both dim H and dim W
  - e.g.:
    - PixShuff: output blue pixel `(0, 1)` is computed from input pixels `[(0, 0), (0, 1), (1, 0), (1, 1)]`
    - SubPix: output blue pixel `(0, 1)` is computed from input pixels `[(0, 1), (0, 2), (1, 1), (1, 2)]`. In PixShuff, those input pixels are used to compute output blue pixel `(0, 3)`.
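The size discrepancy in the first question comes down to valid-convolution arithmetic. Here is a small pure-Python check; the `+ 1` terms reflect the extra zero row/column my implementation pads, which is an implementation choice of mine, not something prescribed by the paper.

```python
def valid_conv_size(n, k):
    # Output length of a stride-1 convolution with no padding.
    return n - k + 1

r, H, Hk = 2, 4, 2

# Sub-pixel space for a 4x4 input: zero-interleaved grid of size H*r plus
# one extra zero row/col, convolved with an (Hk*r)-wide kernel.
print(valid_conv_size(H * r + 1, Hk * r))        # 6  -> the (6, 6) output

# Padding the input to 5x5 first widens the sub-pixel grid to 11.
print(valid_conv_size((H + 1) * r + 1, Hk * r))  # 8  -> the (8, 8) output
```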
Figures
SubPixel method
PixelShuffle method
Implementation
Notes:
- I added some assertions to make sure I understand the dimensions correctly.
- There are still some hard-coded values for padding, etc.
The main script:
import torch
import torch.nn.functional as F

torch.set_printoptions(precision=2)
torch.manual_seed(34)


def main():
    N, Cin, H, W = 1, 1, 4, 4
    Cout = Cin
    ratio = 2
    Hk, Wk = 2, 2

    x = torch.rand(N, Cin, H, W)
    x_padded = torch.zeros(N, Cin, H + 1, W + 1)
    x_padded[..., :-1, :-1] = x

    kernel_stacked = torch.rand(Cout * ratio * ratio, Cin, Hk, Wk)
    kernel_shuffled = F.pixel_shuffle(kernel_stacked.movedim(0, 1), ratio)

    x_up_pixshuff = upsample_conv_pixshuff(x_padded, ratio, kernel_stacked)
    x_up_subpix = upsample_subpix_conv(x, ratio, kernel_shuffled)
    x_up_subpix_padded = upsample_subpix_conv(x_padded, ratio, kernel_shuffled)

    purple_subpix = x_up_subpix_padded[..., ::2, ::2]
    purple_pixshuff = x_up_pixshuff[..., ::2, ::2]
    blue_subpix = x_up_subpix_padded[..., ::2, 1::2]
    blue_pixshuff = x_up_pixshuff[..., ::2, 1::2]
    green_subpix = x_up_subpix_padded[..., 1::2, ::2]
    green_pixshuff = x_up_pixshuff[..., 1::2, ::2]
    red_subpix = x_up_subpix_padded[..., 1::2, 1::2]
    red_pixshuff = x_up_pixshuff[..., 1::2, 1::2]

    # ... Long list of print statements given below


def upsample_conv_pixshuff(x, upscale_factor, kernel):
    N, Cin, H, W = x.shape
    Cout_r_r, _, Hk, Wk = kernel.shape
    assert kernel.shape[1] == Cin
    assert Cout_r_r == Cin * upscale_factor * upscale_factor
    x_up_stacked = F.conv2d(x, kernel)
    return F.pixel_shuffle(x_up_stacked, upscale_factor)


def upsample_subpix_conv(x, upscale_factor, kernel):
    N, Cin, Hin, Win = x.shape
    Cout, _, Hk, Wk = kernel.shape
    assert kernel.shape[1] == Cin
    assert Cout == Cin
    Hout, Wout = Hin * upscale_factor, Win * upscale_factor
    x_interleaved = torch.zeros(N, Cout, Hout + 1, Wout + 1)
    x_interleaved[..., :-1:upscale_factor, :-1:upscale_factor] = x
    return F.conv2d(x_interleaved, kernel)


if __name__ == "__main__":
    main()
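Before reading the printed output, a tiny self-contained reminder of how `F.pixel_shuffle` lays out channels (independent of the script above; the values 0-3 are arbitrary markers):

```python
import torch
import torch.nn.functional as F

# Channel c of the stacked tensor lands at offset (c // r, c % r)
# inside each r x r output block.
r = 2
t = torch.arange(4.0).reshape(1, 4, 1, 1)  # channels 0..3, one spatial pixel
print(F.pixel_shuffle(t, r).squeeze())
# tensor([[0., 1.],
#         [2., 3.]])
```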
The output given by the long list of print statements:
-> SubPixel method gives an output image of size `(6, 6)` and not `(8, 8)`
x_up_pixshuff.shape = torch.Size([1, 1, 8, 8])
x_up_subpix.shape = torch.Size([1, 1, 6, 6])
-> SubPixel method can give an output image of size `(8, 8)`
by adding more padding in the subpixel space,
but then the output's last row/col are full of zeros
x_up_subpix_padded.shape = torch.Size([1, 1, 8, 8])
torch.all(x_up_subpix_padded[..., :, -1] == 0) = tensor(True)
torch.all(x_up_subpix_padded[..., -1, :] == 0) = tensor(True)
-> There is a spatial shift between output images
- Purple pixels are aligned
torch.all(purple_subpix == purple_pixshuff) = tensor(True)
- Blue pixels are shifted on dim W
torch.all(blue_subpix[..., :-1] == blue_pixshuff[..., 1:]) = tensor(True)
- Green pixels are shifted on dim H
torch.all(green_subpix[..., :-1, :] == green_pixshuff[..., 1:, :]) = tensor(True)
- Red pixels are shifted on dim H and dim W
torch.all(red_subpix[..., :-1, :-1] == red_pixshuff[..., 1:, 1:]) = tensor(True)
The complete matrices:
x_up_pixshuff
tensor([[[[0.73, 0.43, 0.26, 0.33, 0.74, 0.64, 0.36, 0.17],
[1.17, 1.29, 0.17, 0.82, 0.61, 1.24, 0.79, 0.69],
[0.88, 0.62, 0.41, 0.21, 0.56, 0.50, 0.56, 0.28],
[1.15, 1.35, 1.07, 0.87, 0.89, 1.36, 1.12, 1.03],
[0.92, 0.90, 1.21, 0.93, 0.94, 0.78, 0.64, 0.37],
[0.84, 1.94, 1.67, 2.35, 1.32, 2.08, 0.96, 1.03],
[0.36, 0.44, 0.78, 0.71, 0.63, 0.55, 0.34, 0.23],
[0.23, 1.09, 0.61, 1.57, 0.51, 1.20, 0.30, 0.45]]]])
x_up_subpix
tensor([[[[0.73, 0.33, 0.26, 0.64, 0.74, 0.17],
[1.15, 0.87, 1.07, 1.36, 0.89, 1.03],
[0.88, 0.21, 0.41, 0.50, 0.56, 0.28],
[0.84, 2.35, 1.67, 2.08, 1.32, 1.03],
[0.92, 0.93, 1.21, 0.78, 0.94, 0.37],
[0.23, 1.57, 0.61, 1.20, 0.51, 0.45]]]])
x_up_subpix_padded
tensor([[[[0.73, 0.33, 0.26, 0.64, 0.74, 0.17, 0.36, 0.00],
[1.15, 0.87, 1.07, 1.36, 0.89, 1.03, 1.12, 0.00],
[0.88, 0.21, 0.41, 0.50, 0.56, 0.28, 0.56, 0.00],
[0.84, 2.35, 1.67, 2.08, 1.32, 1.03, 0.96, 0.00],
[0.92, 0.93, 1.21, 0.78, 0.94, 0.37, 0.64, 0.00],
[0.23, 1.57, 0.61, 1.20, 0.51, 0.45, 0.30, 0.00],
[0.36, 0.71, 0.78, 0.55, 0.63, 0.23, 0.34, 0.00],
[0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]]]])
The list of print statements for reference:
print("-> SubPixel method gives an output image of size `(6, 6)` and not `(8, 8)`\n")
print(f"{x_up_pixshuff.shape = }")
print(f"{x_up_subpix.shape = }")
print(
"\n-> SubPixel method can give an output image of size `(8, 8)`"
"\nby adding more padding in the subpixel space,"
"\nbut then the output's last row/col are full of zeros\n"
)
print(f"{x_up_subpix_padded.shape = }")
print(f"{torch.all(x_up_subpix_padded[..., :, -1] == 0) = }")
print(f"{torch.all(x_up_subpix_padded[..., -1, :] == 0) = }")
print("\n-> There is a spatial shift between output images")
print(" - Purple pixels are aligned")
print(f"{torch.all(purple_subpix == purple_pixshuff) = }")
print(" - Blue pixels are shifted on dim W")
print(f"{torch.all(blue_subpix[..., :-1] == blue_pixshuff[..., 1:]) = }")
print(" - Green pixels are shifted on dim H")
print(f"{torch.all(green_subpix[..., :-1, :] == green_pixshuff[..., 1:, :]) = }")
print(" - Red pixels are shifted on dim H and dim W")
print(f"{torch.all(red_subpix[..., :-1, :-1] == red_pixshuff[..., 1:, 1:]) = }")
print("\n-> For reference, the complete matrices:")
print("\nx_up_pixshuff")
print(x_up_pixshuff)
print("\nx_up_subpix")
print(x_up_subpix)
print("\nx_up_subpix_padded")
print(x_up_subpix_padded)
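Finally, the "computed from input pixels" claims in the e.g. item can be cross-checked with a one-hot probe: light up a single input pixel and see which output pixels respond. This is my own sanity check, reusing the padded-input construction from the script above, not anything from the paper.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
r, Hk, Wk = 2, 2, 2
kernel_stacked = torch.rand(r * r, 1, Hk, Wk)

# Padded 4x4 input that is zero except at input pixel (1, 1).
probe = torch.zeros(1, 1, 5, 5)
probe[0, 0, 1, 1] = 1.0

out = F.pixel_shuffle(F.conv2d(probe, kernel_stacked), r)

# Input pixel (1, 1) contributes to all 16 output positions (i, j) with
# 0 <= i, j <= 3 -- which includes output blue pixel (0, 1), as stated above.
print(out[0, 0].nonzero().tolist())
```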