Is (Convolution + PixelShuffle) the same as SubPixel convolution?
I am working through the arXiv paper Is the deconvolution layer the same as a convolutional layer?. The goal is to upsample, by an upscale factor r, a tensor x of shape (N, C, H, W) to shape (N, C, H*r, W*r).

In that paper they introduce an equivalence between the following two methods (the point being that the second one should be more computationally efficient than the first):

  • SubPixel method: create the sub-pixel image, then convolve in sub-pixel space with a kernel of shape (C, C, Hk * r, Wk * r)
  • PixelShuffle method: convolve with a kernel of shape (C * r * r, C, Hk, Wk), then apply periodic shuffling
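For readers less familiar with the periodic shuffling step, here is a minimal toy example (my own, not from the paper) of what `F.pixel_shuffle` does:

```python
import torch
import torch.nn.functional as F

# Toy example: pixel_shuffle maps (N, C*r*r, H, W) -> (N, C, H*r, W*r).
# Input channel c of each group of r*r channels lands at offset
# (c // r, c % r) of the corresponding r x r output block.
r = 2
x = torch.arange(4.0).reshape(1, 4, 1, 1)  # 4 channels, each a 1x1 map
y = F.pixel_shuffle(x, r)                  # shape (1, 1, 2, 2)
print(y)
# tensor([[[[0., 1.],
#           [2., 3.]]]])
```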

These methods are pictured in Figures 6 and 7 of the paper, which I reproduced and animated below to highlight my doubts about the equivalence statement. After the figures comes a PyTorch implementation that supports those doubts.

Questions

Am I missing something, or is the following true?

  1. SubPixel output size:
    • SubPixel method gives an output image of size (6, 6), not (8, 8) as represented in the figure.
    • SubPixel method can give an output image of size (8, 8) by adding more padding in the sub-pixel space, but then the last row and last column of the output image are all zeros.
  2. Except for the "purple" pixels, there is a spatial shift of values between the two output images:
    • Purple pixels are aligned
    • Blue pixels are shifted on dim W
    • Green pixels are shifted on dim H
    • Red pixels are shifted on dim H and dim W
    • e.g.:
      • PixShuff: the output blue pixel (0, 1) is computed from input pixels [(0, 0), (0, 1), (1, 0), (1, 1)]
      • SubPix: the output blue pixel (0, 1) is computed from input pixels [(0, 1), (0, 2), (1, 1), (1, 2)]. In PixShuff, those input pixels are used to compute the output blue pixel (0, 3).
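The size claim in question 1 can be checked with simple convolution arithmetic (using the dimensions from my script below: H = W = 4, r = 2, Hk = Wk = 2, and "valid" convolutions):

```python
# Output size of a "valid" convolution is in_size - kernel_size + 1.
H, r, Hk = 4, 2, 2

# SubPixel method: zero-interleaved input plus one extra row/col of padding
# gives a sub-pixel space of size H*r + 1 = 9, convolved with an Hk*r = 4 kernel:
subpix_out = (H * r + 1) - Hk * r + 1      # 9 - 4 + 1 = 6

# PixelShuffle method: input padded to H + 1 = 5, convolved with an Hk = 2
# kernel, then upscaled by r via pixel shuffling:
pixshuff_out = ((H + 1) - Hk + 1) * r      # 4 * 2 = 8

print(subpix_out, pixshuff_out)            # 6 8
```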

Figures

SubPixel method

SubPixel  (Figure 6)

PixelShuffle method

PixelShuffle (Figure 7)

Implementation

Notes:

  • I added some assertions to make sure I understand the dimensions correctly.
  • There are still some hard-coded values for padding etc.

The main script:

import torch
import torch.nn.functional as F

torch.set_printoptions(precision=2)
torch.manual_seed(34)


def main():
    N, Cin, H, W = 1, 1, 4, 4
    Cout = Cin
    ratio = 2
    Hk, Wk = 2, 2

    x = torch.rand(N, Cin, H, W)
    x_padded = torch.zeros(N, Cin, H + 1, W + 1)
    x_padded[..., :-1, :-1] = x

    kernel_stacked = torch.rand(Cout * ratio * ratio, Cin, Hk, Wk)
    kernel_shuffled = F.pixel_shuffle(kernel_stacked.movedim(0, 1), ratio)

    x_up_pixshuff = upsample_conv_pixshuff(x_padded, ratio, kernel_stacked)
    x_up_subpix = upsample_subpix_conv(x, ratio, kernel_shuffled)
    x_up_subpix_padded = upsample_subpix_conv(x_padded, ratio, kernel_shuffled)

    purple_subpix = x_up_subpix_padded[..., ::2, ::2]
    purple_pixshuff = x_up_pixshuff[..., ::2, ::2]
    blue_subpix = x_up_subpix_padded[..., ::2, 1::2]
    blue_pixshuff = x_up_pixshuff[..., ::2, 1::2]
    green_subpix = x_up_subpix_padded[..., 1::2, ::2]
    green_pixshuff = x_up_pixshuff[..., 1::2, ::2]
    red_subpix = x_up_subpix_padded[..., 1::2, 1::2]
    red_pixshuff = x_up_pixshuff[..., 1::2, 1::2]

    # ... Long list of print statements given below


def upsample_conv_pixshuff(x, upscale_factor, kernel):
    N, Cin, H, W = x.shape
    Cout_r_r, _, Hk, Wk = kernel.shape  # do not shadow the input's H, W
    assert kernel.shape[1] == Cin
    assert Cout_r_r == Cin * upscale_factor * upscale_factor

    x_up_stacked = F.conv2d(x, kernel)
    return F.pixel_shuffle(x_up_stacked, upscale_factor)


def upsample_subpix_conv(x, upscale_factor, kernel):
    N, Cin, Hin, Win = x.shape
    Cout, _, Hk, Wk = kernel.shape
    assert kernel.shape[1] == Cin
    assert Cout == Cin
    Hout, Wout = Hin * upscale_factor, Win * upscale_factor

    x_interleaved = torch.zeros(N, Cout, Hout + 1, Wout + 1)
    x_interleaved[..., :-1:upscale_factor, :-1:upscale_factor] = x
    return F.conv2d(x_interleaved, kernel)


if __name__ == "__main__":
    main()
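As noted above, the padding is hard-coded. A more general version of `upsample_subpix_conv` (my sketch; it reproduces the behaviour of the padded variant above, trailing zeros included) could size the sub-pixel canvas from the kernel instead:

```python
import torch
import torch.nn.functional as F


def upsample_subpix_conv_general(x, upscale_factor, kernel):
    """Like upsample_subpix_conv, but the canvas size is derived from the
    kernel instead of being hard-coded (sketch; assumes a 'valid' conv)."""
    N, Cin, Hin, Win = x.shape
    Cout, _, Hkr, Wkr = kernel.shape  # kernel already lives in sub-pixel space
    Hout, Wout = Hin * upscale_factor, Win * upscale_factor

    # Canvas sized so the valid convolution returns exactly (Hout, Wout).
    canvas = torch.zeros(N, Cin, Hout + Hkr - 1, Wout + Wkr - 1)
    canvas[..., :Hout:upscale_factor, :Wout:upscale_factor] = x
    return F.conv2d(canvas, kernel)


x = torch.rand(1, 1, 4, 4)
kernel = torch.rand(1, 1, 4, 4)
y = upsample_subpix_conv_general(x, 2, kernel)
print(y.shape)  # torch.Size([1, 1, 8, 8])
```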

The output given by the long list of print statements:

-> SubPixel method gives an output image of size `(6, 6)` and not (8, 8)

x_up_pixshuff.shape = torch.Size([1, 1, 8, 8])
x_up_subpix.shape   = torch.Size([1, 1, 6, 6])

-> SubPixel method can give an output image of size `(8, 8)`
by adding more padding in the subpixel space,
but then the output's last row/col are full of zeros

x_up_subpix_padded.shape = torch.Size([1, 1, 8, 8])
torch.all(x_up_subpix_padded[..., :, -1] == 0) = tensor(True)
torch.all(x_up_subpix_padded[..., -1, :] == 0) = tensor(True)

-> There is a spatial shift between output images
    - Purple pixels are aligned
torch.all(purple_subpix == purple_pixshuff) = tensor(True)
    - Blue pixels are shifted on dim W
torch.all(blue_subpix[..., :-1] == blue_pixshuff[..., 1:]) = tensor(True)
    - Green pixels are shifted on dim H
torch.all(green_subpix[..., :-1, :] == green_pixshuff[..., 1:, :]) = tensor(True)
    - Red pixels are shifted on dim H and dim W
torch.all(red_subpix[..., :-1, :-1] == red_pixshuff[..., 1:, 1:]) = tensor(True)

The complete matrices:

x_up_pixshuff
tensor([[[[0.73, 0.43, 0.26, 0.33, 0.74, 0.64, 0.36, 0.17],
          [1.17, 1.29, 0.17, 0.82, 0.61, 1.24, 0.79, 0.69],
          [0.88, 0.62, 0.41, 0.21, 0.56, 0.50, 0.56, 0.28],
          [1.15, 1.35, 1.07, 0.87, 0.89, 1.36, 1.12, 1.03],
          [0.92, 0.90, 1.21, 0.93, 0.94, 0.78, 0.64, 0.37],
          [0.84, 1.94, 1.67, 2.35, 1.32, 2.08, 0.96, 1.03],
          [0.36, 0.44, 0.78, 0.71, 0.63, 0.55, 0.34, 0.23],
          [0.23, 1.09, 0.61, 1.57, 0.51, 1.20, 0.30, 0.45]]]])

x_up_subpix
tensor([[[[0.73, 0.33, 0.26, 0.64, 0.74, 0.17],
          [1.15, 0.87, 1.07, 1.36, 0.89, 1.03],
          [0.88, 0.21, 0.41, 0.50, 0.56, 0.28],
          [0.84, 2.35, 1.67, 2.08, 1.32, 1.03],
          [0.92, 0.93, 1.21, 0.78, 0.94, 0.37],
          [0.23, 1.57, 0.61, 1.20, 0.51, 0.45]]]])

x_up_subpix_padded
tensor([[[[0.73, 0.33, 0.26, 0.64, 0.74, 0.17, 0.36, 0.00],
          [1.15, 0.87, 1.07, 1.36, 0.89, 1.03, 1.12, 0.00],
          [0.88, 0.21, 0.41, 0.50, 0.56, 0.28, 0.56, 0.00],
          [0.84, 2.35, 1.67, 2.08, 1.32, 1.03, 0.96, 0.00],
          [0.92, 0.93, 1.21, 0.78, 0.94, 0.37, 0.64, 0.00],
          [0.23, 1.57, 0.61, 1.20, 0.51, 0.45, 0.30, 0.00],
          [0.36, 0.71, 0.78, 0.55, 0.63, 0.23, 0.34, 0.00],
          [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]]]])

The list of print statements for reference:

    print("-> SubPixel method gives an output image of size `(6, 6)` and not (8, 8)\n")
    print(f"{x_up_pixshuff.shape = }")
    print(f"{x_up_subpix.shape   = }")

    print(
        "\n-> SubPixel method can give an output image of size `(8, 8)`"
        "\nby adding more padding in the subpixel space,"
        "\nbut then the output's last row/col are full of zeros\n"
    )
    print(f"{x_up_subpix_padded.shape = }")
    print(f"{torch.all(x_up_subpix_padded[..., :, -1] == 0) = }")
    print(f"{torch.all(x_up_subpix_padded[..., -1, :] == 0) = }")

    print("\n-> There is a spatial shift between output images")
    print("    - Purple pixels are aligned")
    print(f"{torch.all(purple_subpix == purple_pixshuff) = }")
    print("    - Blue pixels are shifted on dim W")
    print(f"{torch.all(blue_subpix[..., :-1] == blue_pixshuff[..., 1:]) = }")
    print("    - Green pixels are shifted on dim H")
    print(f"{torch.all(green_subpix[..., :-1, :] == green_pixshuff[..., 1:, :]) = }")
    print("    - Red pixels are shifted on dim H and dim W")
    print(f"{torch.all(red_subpix[..., :-1, :-1] == red_pixshuff[..., 1:, 1:]) = }")

    print("\n-> For reference, the complete matrices:")
    print("\nx_up_pixshuff")
    print(x_up_pixshuff)
    print("\nx_up_subpix")
    print(x_up_subpix)
    print("\nx_up_subpix_padded")
    print(x_up_subpix_padded)
Answered 30/4, 2022 at 13:58
