Disclaimer
The answer below isn't an actual solution to the question above; I misunderstood the query. I'm leaving this response here in case future readers find it useful.
Input
import cv2
import matplotlib.pyplot as plt
# Load the demo image. cv2.imread returns None (instead of raising) when the
# file cannot be read, so fail early with a clear message.
input_img = cv2.imread('/content/2.jpeg')
if input_img is None:
    raise FileNotFoundError("cv2.imread could not read '/content/2.jpeg'")
# OpenCV decodes into BGR channel order while matplotlib displays RGB;
# convert once here so every later plot shows the right colours.
input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
print(input_img.shape)  # (719, 640, 3)
plt.imshow(input_img)
Slice and Stitch
The following functionality is adapted from here. More details and discussion can be found here. On top of the original code, we bring the necessary functionality together and put it in a single class (ImageSliceRejoin).
# ref: https://github.com/idealo/image-super-resolution
class ImageSliceRejoin:
    """Split an image into overlapping square patches and stitch them back.

    The helpers are stateless; they are grouped in a class only for
    convenience. Images are expected as channel-last numpy arrays of shape
    (rows, cols, channels).
    """

    def pad_patch(self, image_patch, padding_size, channel_last=True):
        """Pad both spatial axes of image_patch with replicated edge values.

        Args:
            image_patch: 3-D numpy array.
            padding_size: number of pixels added on every spatial side.
            channel_last: True for (rows, cols, channels) layout, False for
                (channels, rows, cols). The channel axis is never padded.

        Returns:
            The padded array.
        """
        spatial = (padding_size, padding_size)
        if channel_last:
            pad_width = (spatial, spatial, (0, 0))
        else:
            pad_width = ((0, 0), spatial, spatial)
        return np.pad(image_patch, pad_width, 'edge')

    # function to split the image into patches
    def split_image_into_overlapping_patches(self, image_array, patch_size, padding_size=2):
        """Split the image into partially overlapping patches.

        The patches overlap by padding_size pixels.
        Pads the image twice:
            - first to have a size multiple of the patch size,
            - then to have equal padding at the borders.

        Args:
            image_array: numpy array of the input image, (rows, cols, channels).
            patch_size: size of the patches from the original image (without padding).
            padding_size: size of the overlapping area.

        Returns:
            A tuple (patches, padded_shape): patches is an array of shape
            (n_patches, patch_size + 2 * padding_size,
            patch_size + 2 * padding_size, channels) ordered row by row, and
            padded_shape is the shape of the fully padded image.
        """
        xmax, ymax, _ = image_array.shape
        # Amount needed to reach the next multiple of patch_size
        # (0 when the size is already a multiple).
        x_extend = -xmax % patch_size
        y_extend = -ymax % patch_size

        # make sure the image is divisible into regular patches
        extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')
        # add padding around the image to simplify computations
        padded_image = self.pad_patch(extended_image, padding_size, channel_last=True)

        xmax, ymax, _ = padded_image.shape
        window = patch_size + 2 * padding_size
        # Window origins in padded coordinates: each window starts
        # padding_size pixels before its patch, i.e. at multiples of patch_size.
        x_starts = range(0, xmax - 2 * padding_size, patch_size)
        y_starts = range(0, ymax - 2 * padding_size, patch_size)
        patches = [
            padded_image[x:x + window, y:y + window, :]
            for x in x_starts
            for y in y_starts
        ]
        return np.array(patches), padded_image.shape

    # joining the patches
    def stich_together(self, patches, padded_image_shape, target_shape, padding_size=4):
        """Reconstruct the image from overlapping patches.

        After scaling, shapes and padding should be scaled too.

        Args:
            patches: patches obtained with split_image_into_overlapping_patches
            padded_image_shape: shape of the padded image constructed in
                split_image_into_overlapping_patches
            target_shape: shape of the final image
            padding_size: size of the overlapping area.

        Returns:
            A float array (np.zeros default dtype) of shape
            (target_shape[0], target_shape[1], channels), where channels is
            taken from the patches.
        """
        xmax, ymax, _ = padded_image_shape

        # unpad patches; explicit end indices are required because a
        # `padding_size:-padding_size` slice is empty when padding_size == 0
        x_end = patches.shape[1] - padding_size
        y_end = patches.shape[2] - padding_size
        patches = patches[:, padding_size:x_end, padding_size:y_end, :]

        patch_size = patches.shape[1]
        # The padded width includes the border on both sides; exclude it so the
        # column count stays correct even when 2 * padding_size >= patch_size.
        n_patches_per_row = (ymax - 2 * padding_size) // patch_size

        complete_image = np.zeros((xmax, ymax, patches.shape[-1]))
        for i, patch in enumerate(patches):
            row, col = divmod(i, n_patches_per_row)
            complete_image[
                row * patch_size: (row + 1) * patch_size,
                col * patch_size: (col + 1) * patch_size,
                :,
            ] = patch
        return complete_image[0: target_shape[0], 0: target_shape[1], :]
Initiate Slicing
import numpy as np
# Cut the input into 220-pixel tiles; every tile carries `padding_size`
# extra pixels of context on each side.
isr = ImageSliceRejoin()
padding_size = 1
patches, p_shape = isr.split_image_into_overlapping_patches(
    image_array=input_img, patch_size=220, padding_size=padding_size
)
# notebook cell output: patch stack shape, padded image shape, input shape
patches.shape, p_shape, input_img.shape
((12, 222, 222, 3), (882, 662, 3), (719, 640, 3))
Verify
# Smallest square grid that holds every patch. plt.subplot needs integer
# row/column counts, and ceil(N / 2) is too small for N == 2 and far too
# large for bigger N.
n = int(np.ceil(np.sqrt(patches.shape[0])))
plt.figure(figsize=(20, 20))
patch_size = patches.shape[1]
for i, patch in enumerate(patches):
    ax = plt.subplot(n, n, i + 1)
    # each patch is already (patch_size, patch_size, channels); no reshape needed
    plt.imshow(patch.astype("uint8"))
    plt.axis("off")
Inference
I'm using the Image-Super-Resolution model for demonstration.
# import model
from ISR.models import RDN
model = RDN(weights='psnr-small')
# number of patches that will pass to the model per inference call;
# here batch_size < len(patches), and the last batch is simply shorter
# when batch_size does not divide len(patches)
batch_size = 2
outputs = []
for i in range(0, len(patches), batch_size):
    # get some patches
    batch = patches[i: i + batch_size]
    # pass them to model to give patches output
    batch = model.model.predict(batch)
    # save the output patches
    outputs.append(batch)
# Join once at the end: np.append inside the loop would re-copy everything
# collected so far on every iteration.
collect = np.concatenate(outputs, axis=0)
Now, collect holds the model's output for each patch.
# Notebook-style check: shape of the input patch stack vs. the stacked model outputs.
patches.shape, collect.shape
((12, 222, 222, 3), (12, 444, 444, 3))
Rejoin Patches
# The model enlarges every patch by `scale`, so the padded shape, the target
# shape and the overlap all have to be scaled before stitching.
scale = 2
padded_size_scaled = (*np.multiply(p_shape[0:2], scale), 3)
scaled_image_shape = (*np.multiply(input_img.shape[0:2], scale), 3)
sr_img = isr.stich_together(
    collect, padded_size_scaled, scaled_image_shape, padding_size * scale
)
Verify
# Compare the shapes, then show the input and the stitched result side by side.
print(input_img.shape, sr_img.shape)
# (719, 640, 3) (1438, 1280, 3)
fig, ax = plt.subplots(1, 2, figsize=(18.5, 10.5))
for axis, picture in zip(ax, (input_img, sr_img.astype('uint8'))):
    axis.imshow(picture)