How to use Triton server "ensemble model" with 1:N input/output to create patches from large image?
I am trying to feed a very large image into Triton server. I need to divide the input image into patches and feed the patches one by one into a tensorflow model. The image has a variable size, so the number of patches N is variable for each call.

I think a Triton ensemble model that calls the following steps would do the job:

  1. A python model (pre-process) to create the patches
  2. The segmentation model
  3. Finally another python model (post-process) to merge the output patches into a big output mask
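The data flow of these three steps can be sketched in plain NumPy, independent of Triton. This is only an illustration of the 1:N and N:1 shapes involved, not a Triton configuration: `split_into_patches`, `fake_segmentation_model`, and `merge_patches` are hypothetical names, the patches here do not overlap, and the "model" is a stand-in that thresholds pixel brightness.

```python
import numpy as np

def split_into_patches(image, patch_size):
    """Step 1 (1:N): pad to a multiple of patch_size and cut into tiles."""
    h, w, _ = image.shape
    pad_h = (-h) % patch_size
    pad_w = (-w) % patch_size
    padded = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode="edge")
    rows = padded.shape[0] // patch_size
    cols = padded.shape[1] // patch_size
    patches = [
        padded[r * patch_size:(r + 1) * patch_size,
               c * patch_size:(c + 1) * patch_size]
        for r in range(rows)
        for c in range(cols)
    ]
    return np.stack(patches), (rows, cols)

def fake_segmentation_model(batch):
    """Step 2: hypothetical stand-in model; marks pixels brighter than 127."""
    return (batch.mean(axis=-1, keepdims=True) > 127).astype(np.uint8)

def merge_patches(patches, grid, target_hw):
    """Step 3 (N:1): place tiles back on the grid and crop the padding."""
    rows, cols = grid
    _, ph, pw, ch = patches.shape
    tiled = patches.reshape(rows, cols, ph, pw, ch)
    full = tiled.transpose(0, 2, 1, 3, 4).reshape(rows * ph, cols * pw, ch)
    return full[:target_hw[0], :target_hw[1]]

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(719, 640, 3), dtype=np.uint8)

patches, grid = split_into_patches(image, patch_size=220)
# N depends on the image size, so it differs from call to call.
masks = np.concatenate([fake_segmentation_model(p[None]) for p in patches])
mask = merge_patches(masks, grid, image.shape[:2])
```

Because the stand-in model works pixel by pixel, the merged mask is identical to running it on the whole image at once; a real segmentation network would generally need overlapping patches to behave well at the tile borders.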

However, this would require a config.pbtxt file with 1:N and N:1 relations, meaning the ensemble scheduler would need to call the 2nd step multiple times and the 3rd step once with the aggregated output.

Is this possible, or do I need to use some other technique?

Loveliesbleeding answered 26/4, 2021 at 11:7 Comment(4)
I was also trying something like this for the inference process. Are you doing this for inference too? - Allianora
Yes, it is for inference. - Loveliesbleeding
This seems like a more general question: slice a large image into smaller patches, pass them sequentially to the model, save each patch's result, and rejoin them at the end. However, AFAIK, a segmentation-like model needs global context to work as expected; if we slice the image into smaller patches, there is a good chance this spatial information will not be preserved. Thus the outcome (the image rejoined from patches) may not look natural (though it depends on the case). - Allianora
We have long experience working with patches, so that is not the problem; inference works well in TensorFlow :-) - Loveliesbleeding

Disclaimer

The answer below isn't an actual solution to the question above; I misunderstood the query. I'm leaving this response in case future readers find it useful.


Input

import cv2 
import matplotlib.pyplot as plt

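# OpenCV loads images as BGR; convert to RGB so matplotlib shows the right colours
input_img = cv2.cvtColor(cv2.imread('/content/2.jpeg'), cv2.COLOR_BGR2RGB)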
print(input_img.shape) # (719, 640, 3)
plt.imshow(input_img) 

Slice and Stitch

The following functionality is adapted from here. More details and discussion can be found here. Apart from the original code, we bring the necessary functions together in a single class (ImageSliceRejoin).

# ref: https://github.com/idealo/image-super-resolution
import numpy as np

class ImageSliceRejoin:
    def pad_patch(self, image_patch, padding_size, channel_last=True):
        """ Pads image_patch with padding_size edge values. """
        if channel_last:
            return np.pad(
                image_patch,
                ((padding_size, padding_size), 
                (padding_size, padding_size), (0, 0)),
                'edge',
            )
        else:
            return np.pad(
                image_patch,
                ((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
                'edge',
            )

    # function to split the image into patches        
    def split_image_into_overlapping_patches(self, image_array, patch_size, padding_size=2):
        """ Splits the image into partially overlapping patches.
        The patches overlap by padding_size pixels.
        Pads the image twice:
            - first to have a size multiple of the patch size,
            - then to have equal padding at the borders.
        Args:
            image_array: numpy array of the input image.
            patch_size: size of the patches from the original image (without padding).
            padding_size: size of the overlapping area.
        """
        xmax, ymax, _ = image_array.shape
        x_remainder = xmax % patch_size
        y_remainder = ymax % patch_size
        
        # modulo here is to avoid extending of patch_size instead of 0
        x_extend = (patch_size - x_remainder) % patch_size
        y_extend = (patch_size - y_remainder) % patch_size
        
        # make sure the image is divisible into regular patches
        extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')
        
        # add padding around the image to simplify computations
        padded_image = self.pad_patch(extended_image, padding_size, channel_last=True)
        
        xmax, ymax, _ = padded_image.shape
        patches = []
        
        x_lefts = range(padding_size, xmax - padding_size, patch_size)
        y_tops = range(padding_size, ymax - padding_size, patch_size)
        
        for x in x_lefts:
            for y in y_tops:
                x_left = x - padding_size
                y_top = y - padding_size
                x_right = x + patch_size + padding_size
                y_bottom = y + patch_size + padding_size
                patch = padded_image[x_left:x_right, y_top:y_bottom, :]
                patches.append(patch)
        
        return np.array(patches), padded_image.shape

    # join the patches back together
    def stich_together(self, patches, padded_image_shape, target_shape, padding_size=4):
        """ Reconstruct the image from overlapping patches.
        After scaling, shapes and padding should be scaled too.
        Args:
            patches: patches obtained with split_image_into_overlapping_patches
            padded_image_shape: shape of the padded image constructed in split_image_into_overlapping_patches
            target_shape: shape of the final image
            padding_size: size of the overlapping area.
        """
        xmax, ymax, _ = padded_image_shape

        # unpad patches
        patches = patches[:, padding_size:-padding_size, padding_size:-padding_size, :]

        patch_size = patches.shape[1]
        n_patches_per_row = ymax // patch_size
        complete_image = np.zeros((xmax, ymax, 3))

        row = -1
        col = 0
        for i in range(len(patches)):
            if i % n_patches_per_row == 0:
                row += 1
                col = 0
            complete_image[
            row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, :
            ] = patches[i]
            col += 1
        return complete_image[0: target_shape[0], 0: target_shape[1], :]

Initiate Slicing

import numpy as np 

isr = ImageSliceRejoin()
padding_size = 1

patches, p_shape = isr.split_image_into_overlapping_patches(
    input_img, 
    patch_size=220, 
    padding_size=padding_size
)

patches.shape, p_shape, input_img.shape
((12, 222, 222, 3), (882, 662, 3), (719, 640, 3))
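
The shapes above can be checked by hand. The following is a small sketch of the arithmetic only, using the values from this example (a 719 x 640 image, patch_size=220, padding_size=1); the variable names are chosen here for illustration and do not come from the original code.

```python
import math

height, width = 719, 640            # input_img.shape[:2]
patch_size, padding_size = 220, 1

# Each side is first extended to the next multiple of patch_size.
ext_h = math.ceil(height / patch_size) * patch_size   # 880
ext_w = math.ceil(width / patch_size) * patch_size    # 660

# pad_patch then adds padding_size pixels on every border.
padded_shape = (ext_h + 2 * padding_size, ext_w + 2 * padding_size, 3)

# One patch per patch_size x patch_size cell of the extended image,
# each carrying padding_size extra pixels on all four sides.
n_patches = (ext_h // patch_size) * (ext_w // patch_size)
patch_side = patch_size + 2 * padding_size

print(n_patches, patch_side, padded_shape)  # 12 222 (882, 662, 3)
```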

Verify

# plt.subplot needs integer grid sizes; np.ceil returns a float
n = int(np.ceil(patches.shape[0] / 2))
plt.figure(figsize=(20, 20))
patch_size = patches.shape[1]

for i in range(patches.shape[0]):
    patch = patches[i] 
    ax = plt.subplot(n, n, i + 1)
    patch_img = np.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.astype("uint8"))
    plt.axis("off")


Inference

I'm using the Image-Super-Resolution model for demonstration.

# import model
from ISR.models import RDN
model = RDN(weights='psnr-small')

# number of patches that will pass to model for inference: 
# here, batch_size < len(patches)
batch_size = 2

for i in range(0, len(patches), batch_size):
    # get some patches
    batch = patches[i: i + batch_size]

    # pass them to model to give patches output 
    batch = model.model.predict(batch)

    # save the output patches 
    if i == 0:
        collect = batch
    else:
        collect = np.append(collect, batch, axis=0)

Now collect holds the model output for every patch.

patches.shape, collect.shape
((12, 222, 222, 3), (12, 444, 444, 3))
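
The same batching pattern can be tried without downloading the ISR weights. The sketch below assumes a hypothetical stand-in, `fake_upscale`, in place of `model.model.predict` (it simply repeats pixels to double the resolution), and it collects the per-batch outputs in a list and concatenates once at the end, which avoids re-copying the growing array on every iteration as `np.append` does.

```python
import numpy as np

def fake_upscale(batch, scale=2):
    """Hypothetical stand-in for model.model.predict: nearest-neighbour upscaling."""
    return batch.repeat(scale, axis=1).repeat(scale, axis=2)

# Dummy patches with the same shape as in the example above.
patches = np.zeros((12, 222, 222, 3), dtype=np.float32)
batch_size = 2

outputs = []
for start in range(0, len(patches), batch_size):
    batch = patches[start:start + batch_size]   # the last batch may be smaller
    outputs.append(fake_upscale(batch))

# Concatenate once along the patch axis.
collect = np.concatenate(outputs, axis=0)
print(collect.shape)  # (12, 444, 444, 3)
```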

Rejoin Patches

scale = 2
padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
scaled_image_shape = tuple(np.multiply(input_img.shape[0:2], scale)) + (3,)

sr_img = isr.stich_together(
    collect,
    padded_image_shape=padded_size_scaled,
    target_shape=scaled_image_shape,
    padding_size=padding_size * scale,
)

Verify

print(input_img.shape, sr_img.shape)
# (719, 640, 3) (1438, 1280, 3)

fig, ax = plt.subplots(1,2)
fig.set_size_inches(18.5, 10.5)
ax[0].imshow(input_img)
ax[1].imshow(sr_img.astype('uint8'))


Allianora answered 7/5, 2021 at 13:58 Comment(2)
Thanks for the answer. We are already using a similar technique with Python and TensorFlow. The question was how to do this with Triton Server. The server will load a SavedModel and no Python code. Python can be used in the Triton Python backend, but its output will be passed to only one inference call on the model; there seems to be no way to express a 1:N or N:1 relation between models in Triton. - Loveliesbleeding
I see. I haven't used Triton server, so my answer was not end-to-end. But since your question mentioned the image-slicing part of the problem, I thought the above answer could add some value to the solution. :-) - Allianora
