How to handle odd resolutions in a U-Net architecture in PyTorch
I'm implementing a U-Net-based architecture in PyTorch. At train time I have patches of size 256x256, which don't cause any problems. However, at test time I have full-HD images (1920x1080). This causes a problem with the skip connections.

Downsampling 1920x1080 three times gives 240x135. If I downsample once more, the resolution becomes 120x68, which when upsampled gives 240x136 (135 is odd, so halving and then doubling can never recover it). Now I cannot concatenate these two feature maps. How can I solve this?

PS: I thought this was a fairly common problem, but I couldn't find a solution, or even a mention of the problem, anywhere on the web. Am I missing something?
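
For concreteness, here is a minimal sketch (plain stride-2 max pooling and nearest-neighbour upsampling, not the actual Partial ConvNet blocks) that reproduces the mismatch:

import torch
import torch.nn as nn

x = torch.zeros(1, 1, 1080, 1920)  # N, C, H, W
pool = nn.MaxPool2d(2)             # stride-2 downsampling
up = nn.Upsample(scale_factor=2)   # stride-2 upsampling

d3 = pool(pool(pool(x)))  # torch.Size([1, 1, 135, 240])
d4 = pool(d3)             # 135/2 = 67.5 -> 67 here (68 with ceil mode or padded convs)
u4 = up(d4)               # torch.Size([1, 1, 134, 240]) != d3.shape -> cannot concatenate
print(d3.shape, d4.shape, u4.shape)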

Waligore answered 3/2, 2021 at 13:41 Comment(6)
Have you tried down-sampling with a different factor along each dimension using torch.nn.MaxPool2d? You could use a fixed kernel size of (8, 5); this would give you 240 x 216, and then you could pad the array to meet the required size of 256 x 256 without distorting the image too much. – Cynthia
PS: I suggested max pooling, but it could be average pooling as well. – Cynthia
No, I can't use that. Actually, I'm benchmarking the Partial ConvNet paper for my research. I don't know if I can modify the architecture like that. – Waligore
May I ask why you think this would not work? Or is it simply the fact that it would mean tampering with the architecture? Thanks. – Cynthia
Yes, it's simply about tampering with the architecture. Plus, using a different max pooling would cause problems during training, right? – Waligore
If I got your OP right, it would be for the skip connections. You would not be tampering with the architecture itself; you would be adding a pre-processing step on the input data before it enters the network. Now, if your goal is image segmentation, instead of down-sampling you could also try splitting your HD pictures into 256 x 256 chunks; this way you don't tamper with the resolution. – Cynthia
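
A rough sketch of the tiling idea from the last comment (non-overlapping 256 x 256 patches after right/bottom zero-padding; tile_256 is just an illustrative name):

import torch
import torch.nn.functional as F

def tile_256(x, tile=256):
    # Pad right/bottom so H and W become multiples of the tile size,
    # then cut into non-overlapping tile x tile patches with Tensor.unfold.
    _, _, h, w = x.shape
    pad_h = (tile - h % tile) % tile
    pad_w = (tile - w % tile) % tile
    x = F.pad(x, (0, pad_w, 0, pad_h))  # (left, right, top, bottom)
    patches = x.unfold(2, tile, tile).unfold(3, tile, tile)
    return patches  # (N, C, tiles_h, tiles_w, tile, tile)

x = torch.zeros(1, 3, 1080, 1920)
print(tile_256(x).shape)  # torch.Size([1, 3, 5, 8, 256, 256])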
This is a very common problem in segmentation networks, where skip connections are often involved in the decoding process. Networks usually (depending on the actual architecture) require an input whose side lengths are integer multiples of the largest stride (8, 16, 32, etc.).

There are two main ways:

  1. Resize input to the nearest feasible size.
  2. Pad the input to the next larger feasible size.

I prefer (2), because (1) introduces small pixel-level changes across the whole image, leading to unnecessary blurriness. Note that in both methods we usually need to recover the original shape afterward.
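
For completeness, a minimal sketch of option (1), assuming bilinear resizing of the input is acceptable (resize_to_multiple is just an illustrative name, not from the reference below):

import torch
import torch.nn.functional as F

def resize_to_multiple(x, stride=16):
    # Snap each side to the nearest multiple of the largest stride.
    h, w = x.shape[-2:]
    new_h = max(stride, round(h / stride) * stride)
    new_w = max(stride, round(w / stride) * stride)
    return F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)

x = torch.zeros(1, 3, 1080, 1920)
print(resize_to_multiple(x).shape)  # torch.Size([1, 3, 1088, 1920])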

My favorite code snippet for this task (symmetric padding for height/width):

import torch
import torch.nn.functional as F

def pad_to(x, stride):
    h, w = x.shape[-2:]

    # Round h and w up to the next multiples of `stride`.
    if h % stride > 0:
        new_h = h + stride - h % stride
    else:
        new_h = h
    if w % stride > 0:
        new_w = w + stride - w % stride
    else:
        new_w = w
    # Split the padding as evenly as possible between the two sides.
    lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
    lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
    pads = (lw, uw, lh, uh)  # (left, right, top, bottom), the order F.pad expects

    # zero-padding by default.
    # See others at https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad
    out = F.pad(x, pads, "constant", 0)

    return out, pads

def unpad(x, pad):
    # Slice off the padding added by pad_to; pad = (left, right, top, bottom).
    if pad[2]+pad[3] > 0:
        x = x[:,:,pad[2]:-pad[3],:]
    if pad[0]+pad[1] > 0:
        x = x[:,:,:,pad[0]:-pad[1]]
    return x

A test snippet:

x = torch.zeros(4, 3, 1080, 1920) # Raw data
x_pad, pads = pad_to(x, 16) # Padded data, feed this to your network 
x_unpad = unpad(x_pad, pads) # Un-pad the network output to recover the original shape

print('Original: ', x.shape)
print('Padded: ', x_pad.shape)
print('Recovered: ', x_unpad.shape)

Output:

Original:  torch.Size([4, 3, 1080, 1920])
Padded:  torch.Size([4, 3, 1088, 1920])
Recovered:  torch.Size([4, 3, 1080, 1920])

Reference: https://github.com/seoungwugoh/STM/blob/905f11492a6692dd0d0fa395881a8ec09b211a36/helpers.py#L33

Owner answered 3/2, 2021 at 15:35 Comment(1)
Thanks! Padding seems to be a better option. – Waligore
