This is a very common problem in segmentation networks, where skip connections link the encoder to the decoding process. Such networks usually (depending on the actual architecture) require inputs whose side lengths are integer multiples of the largest stride (8, 16, 32, etc.).
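To see why, consider a minimal sketch; the two pooling stages below are a hypothetical stand-in for an encoder whose largest stride is 4:

import torch
import torch.nn.functional as F

x = torch.zeros(1, 3, 101, 128)             # height 101 is not a multiple of 4
down = F.max_pool2d(F.max_pool2d(x, 2), 2)  # two stride-2 stages -> (1, 3, 25, 32)
up = F.interpolate(down, scale_factor=4)    # -> (1, 3, 100, 128), not 101
print(up.shape[-2:], 'vs', x.shape[-2:])    # mismatched shapes break skip connections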
There are two main ways to handle this:
1. Resize the input to the nearest feasible size.
2. Pad the input to the next larger feasible size.
I prefer (2), because (1) slightly changes the value of every pixel and can introduce unnecessary blurriness. Note that in both methods we usually need to recover the original shape afterward.
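For completeness, here is a minimal sketch of option (1); the helper resize_to and the choice of bilinear interpolation are my own illustration, not part of the snippet below:

import torch
import torch.nn.functional as F

def resize_to(x, stride):
    h, w = x.shape[-2:]
    # nearest multiples of stride (at least one stride)
    new_h = max(stride, round(h / stride) * stride)
    new_w = max(stride, round(w / stride) * stride)
    x = F.interpolate(x, size=(new_h, new_w), mode='bilinear', align_corners=False)
    return x, (h, w)

x = torch.zeros(4, 3, 1080, 1900)
x_res, orig_size = resize_to(x, 16)  # -> (4, 3, 1088, 1904)
# ... feed x_res to the network, then resize the output back:
y = F.interpolate(x_res, size=orig_size, mode='bilinear', align_corners=False)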
My favorite code snippet for this task (symmetric padding for height/width):
import torch
import torch.nn.functional as F
def pad_to(x, stride):
    h, w = x.shape[-2:]
    # round height and width up to the next multiple of stride
    if h % stride > 0:
        new_h = h + stride - h % stride
    else:
        new_h = h
    if w % stride > 0:
        new_w = w + stride - w % stride
    else:
        new_w = w
    # split the extra pixels evenly between the two sides; the bottom/right
    # side gets the larger half when the difference is odd
    lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
    lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
    pads = (lw, uw, lh, uh)  # F.pad order for 2D: (left, right, top, bottom)
    # zero-padding by default.
    # See other modes at https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad
    out = F.pad(x, pads, "constant", 0)
    return out, pads
def unpad(x, pad):
    # pad = (left, right, top, bottom); with pads from pad_to, the bottom/right
    # pad is never smaller than the top/left one, so the negative stop index
    # below is non-zero whenever any padding was applied
    if pad[2] + pad[3] > 0:
        x = x[:, :, pad[2]:-pad[3], :]
    if pad[0] + pad[1] > 0:
        x = x[:, :, :, pad[0]:-pad[1]]
    return x
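Zero-padding introduces a hard black border that some models are sensitive to; if that is a problem, switching the padding mode inside pad_to is a one-line change (my own variation, not part of the referenced snippet):

out = F.pad(x, pads, mode="reflect")  # mirror image content into the border instead of zeros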
A test snippet:
x = torch.zeros(4, 3, 1080, 1920) # Raw data
x_pad, pads = pad_to(x, 16) # Padded data, feed this to your network
x_unpad = unpad(x_pad, pads) # Un-pad the network output to recover the original shape
print('Original: ', x.shape)
print('Padded: ', x_pad.shape)
print('Recovered: ', x_unpad.shape)
Output:
Original: torch.Size([4, 3, 1080, 1920])
Padded: torch.Size([4, 3, 1088, 1920])
Recovered: torch.Size([4, 3, 1080, 1920])
Reference: https://github.com/seoungwugoh/STM/blob/905f11492a6692dd0d0fa395881a8ec09b211a36/helpers.py#L33
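In practice you can fold the pad/unpad pair into the forward pass so callers never see the padded size. A minimal sketch (the wrapper class is hypothetical, assuming pad_to and unpad from above):

import torch

class PadToStride(torch.nn.Module):
    def __init__(self, net, stride=16):
        super().__init__()
        self.net = net        # your segmentation network
        self.stride = stride  # largest stride of the architecture

    def forward(self, x):
        x, pads = pad_to(x, self.stride)  # pad up to a feasible size
        y = self.net(x)                   # the network sees a multiple-of-stride input
        return unpad(y, pads)             # crop the output back to the original shape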
Comments:
- You could use torch.nn.MaxPool2d with a fixed kernel size = (8, 5); this would give you 240 x 216, then you could pad the array to meet the required size 256 x 256 without distorting the image too much. – Cynthia
- Max pooling, but it could be AveragePooling as well. – Cynthia
- … 256 x 256 chunk, this way you don't tamper with the resolution. – Cynthia
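For reference, a sketch of the approach suggested in the comments; note that PyTorch orders kernel_size as (kH, kW), so the comment's 8 x 5 becomes kernel_size=(5, 8) here, and the pad amounts are my own arithmetic to reach 256 x 256:

import torch
import torch.nn.functional as F

x = torch.zeros(4, 3, 1080, 1920)
pooled = F.max_pool2d(x, kernel_size=(5, 8))    # -> (4, 3, 216, 240)
# pooled = F.avg_pool2d(x, kernel_size=(5, 8))  # average pooling works as well
padded = F.pad(pooled, (8, 8, 20, 20))          # 240 -> 256 wide, 216 -> 256 tall
print(padded.shape)                             # torch.Size([4, 3, 256, 256])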