Efficient image dilation in TensorFlow

I'm looking for an efficient way to implement morphological image dilation with a square kernel in TensorFlow. The obvious ways of doing it seem extremely inefficient compared to what is achievable, as OpenCV shows. See the results of running the source code pasted at the bottom: even the fastest TensorFlow method is about 30x slower than OpenCV. These timings are from a MacBook Air with an M1 chip.

Dilation of 640x480 image with a 25x25 kernel took: 
  0.61ms using opencv
  545.40ms using tf.nn.max_pool2d
  228.66ms using tf.nn.dilation2d naively
  17.63ms using tf.nn.dilation2d with row-col

Question: Does anyone know of a way to do image-dilation with TensorFlow that isn't extremely inefficient?

Source code for the current solutions:

import numpy as np
import cv2
import tensorflow as tf
import time


def tf_dilate(heatmap, width: int, method: str = 'rowcol_dilate'):
    """ Dilate the heatmap with a square kernel """
    if method=='maxpool':
        return tf.nn.max_pool2d(heatmap[None, :, :, None], ksize=width, padding='SAME', strides=(1, 1))[0, :, :, 0]
    elif method == 'naive_dilate':
        return tf.nn.dilation2d(heatmap[None, :, :, None], filters=tf.zeros((width, width, 1), dtype=heatmap.dtype),
                                        strides=(1, 1, 1, 1), padding="SAME", data_format="NHWC", dilations=(1, 1, 1, 1))[0, :, :, 0]
    elif method == 'rowcol_dilate':

        row_dilation = tf.nn.dilation2d(heatmap[None, :, :, None], filters=tf.zeros((1, width, 1), dtype=heatmap.dtype),
                                        strides=(1, 1, 1, 1), padding="SAME", data_format="NHWC", dilations=(1, 1, 1, 1))
        full_dilation = tf.nn.dilation2d(row_dilation, filters=tf.zeros((width, 1, 1), dtype=heatmap.dtype),
                                         strides=(1, 1, 1, 1), padding="SAME", data_format="NHWC", dilations=(1, 1, 1, 1))
        return full_dilation[0, :, :, 0]
    else:
        raise NotImplementedError(f'No method {method}')


def test_dilation_options(img_shape=(480, 640), kernel_size=25):

    img = np.random.randn(*img_shape).astype(np.float32)**2

    def get_result_and_time(version: str):

        tf_image = tf.constant(img, dtype=tf.float32)
        t_start = time.time()
        if version=='opencv':
            result = cv2.dilate(img, kernel=np.ones((kernel_size, kernel_size), dtype=np.float32))
            return time.time()-t_start, result
        else:
            result = tf_dilate(tf_image, width=kernel_size, method=version)
            return time.time()-t_start, result.numpy()

    t_opencv, result_opencv = get_result_and_time('opencv')
    t_maxpool, result_maxpool = get_result_and_time('maxpool')
    t_naive_dilate, result_naive_dilate = get_result_and_time('naive_dilate')
    t_rowcol_dilate, result_rowcol_dilate = get_result_and_time('rowcol_dilate')
    assert np.array_equal(result_opencv, result_maxpool), "Maxpool result did not match opencv result"
    assert np.array_equal(result_opencv, result_naive_dilate), "Naive dilation result did not match opencv result"
    assert np.array_equal(result_opencv, result_rowcol_dilate), "Row-col dilation result did not match opencv result"
    print(f'Dilation of {img_shape[1]}x{img_shape[0]} image with a {kernel_size}x{kernel_size} kernel took: '
          f'\n  {t_opencv*1000:.2f}ms using opencv'
          f'\n  {t_maxpool*1000:.2f}ms using tf.nn.max_pool2d'
          f'\n  {t_naive_dilate*1000:.2f}ms using tf.nn.dilation2d naively'
          f'\n  {t_rowcol_dilate*1000:.2f}ms using tf.nn.dilation2d with row-col'
          )


if __name__ == '__main__':
    test_dilation_options()
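The row-col variant works because a square flat structuring element is separable: the maximum over a width x width window equals the maximum, over the rows of that window, of each row's window maximum. A minimal NumPy sketch (not TensorFlow; the helper names are hypothetical, and an odd width with -inf border padding is assumed so padded cells never win the max) checks the two-pass version against a brute-force 2-D window max:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def max_filter_1d(x: np.ndarray, width: int, axis: int) -> np.ndarray:
    """Centered running max of odd `width` along `axis`; borders padded with -inf."""
    assert width % 2 == 1, "this sketch assumes an odd kernel width"
    half = width // 2
    pad = [(0, 0)] * x.ndim
    pad[axis] = (half, half)
    padded = np.pad(x, pad, mode='constant', constant_values=-np.inf)
    # sliding_window_view appends a trailing axis of length `width` holding each window
    return sliding_window_view(padded, width, axis=axis).max(axis=-1)


def square_dilate_direct(x: np.ndarray, width: int) -> np.ndarray:
    """Brute force: max over every width x width neighbourhood."""
    half = width // 2
    padded = np.pad(x, half, mode='constant', constant_values=-np.inf)
    return sliding_window_view(padded, (width, width)).max(axis=(-2, -1))


def square_dilate_rowcol(x: np.ndarray, width: int) -> np.ndarray:
    """Row pass then column pass, mirroring the question's 'rowcol_dilate'."""
    return max_filter_1d(max_filter_1d(x, width, axis=1), width, axis=0)


rng = np.random.default_rng(0)
img = (rng.standard_normal((48, 64)) ** 2).astype(np.float32)
assert np.array_equal(square_dilate_rowcol(img, 25), square_dilate_direct(img, 25))
```

With a brute-force window max, each 1-D pass costs about width comparisons per pixel, so the two passes do roughly 2*width instead of width**2, which is consistent with the row-col version being the fastest TensorFlow option in the timings above.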
Asked by Simard on 23/6, 2022 at 16:50. Comments:
Do you need the operation to be differentiable? – Obliging
A dilation with a 1D structuring element (as in your row-col version) can be computed with 3 comparisons per pixel, independent of the size of the kernel. Evidently TensorFlow does not implement that algorithm; it likely implements the naive algorithm with width-1 comparisons per pixel. – Calculus
Addendum to my previous comment: TensorFlow implements dilation with a gray-value structuring element, which cannot be computed with 3 comparisons per pixel; that is only true for the binary (flat) SE case, which is what OpenCV implements. The TensorFlow dilation has to do a lot more work than the OpenCV dilation because it is more generic. – Calculus
@Christoph No need for differentiability in my case. – Simard
@Chris Thanks, I have a feeling it's possible using tf.reduce_max with some fancy indexing. – Simard
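The 3-comparisons-per-pixel method mentioned in the comments is commonly known as the van Herk / Gil-Werman algorithm. Below is a NumPy sketch of it for a 1-D centered running max with an odd window k (the function name is hypothetical, and porting it to TensorFlow is not shown). The padded signal is cut into blocks of length k; a forward running max and a backward running max are computed inside each block, and one extra comparison per output merges them, because any window of length k spans at most two blocks:

```python
import numpy as np


def van_herk_max_1d(x: np.ndarray, k: int) -> np.ndarray:
    """Centered running max over an odd window k using about 3 comparisons per element."""
    assert k >= 1 and k % 2 == 1, "this sketch assumes an odd window"
    x = np.asarray(x, dtype=np.float64)
    n, half = x.shape[0], k // 2
    # Pad with -inf so that padded[i : i + k] is the window centered on x[i].
    padded = np.concatenate([np.full(half, -np.inf), x, np.full(half, -np.inf)])
    # Extend to a whole number of blocks of length k.
    n_blocks = -(-padded.size // k)
    padded = np.concatenate([padded, np.full(n_blocks * k - padded.size, -np.inf)])
    blocks = padded.reshape(n_blocks, k)
    # g[j]: max from the start of j's block up to j (forward scan).
    g = np.maximum.accumulate(blocks, axis=1).ravel()
    # h[j]: max from j to the end of j's block (backward scan).
    h = np.maximum.accumulate(blocks[:, ::-1], axis=1)[:, ::-1].ravel()
    # Window [i, i + k - 1] = tail of one block + head of the next: one comparison merges them.
    i = np.arange(n)
    return np.maximum(h[i], g[i + k - 1])


print(van_herk_max_1d(np.array([1.0, 3.0, 2.0, 0.0, 5.0]), 3))  # [3. 3. 3. 5. 5.]
```

Applying this along every row and then every column gives the flat square dilation with a per-pixel cost that does not grow with the kernel size; as the addendum notes, this shortcut only holds for flat (binary) structuring elements.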

Well, if you're fine with an approximate solution, there is always a "poor man's dilate", which approximates dilation with a weighted local average (a box filter) whose weights are the image raised to a power (heatmap**power, with power=4 by default). It runs in O((H+K)*(W+K)) time, where H and W are the image height and width and K is the kernel size.

It also has the benefit that the gradient flows not just through the local maxima but also through the contenders to the throne.
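Up to the small weight_eps term, this weighting computes sum(x**(p+1)) / sum(x**p) over each window (a Lehmer mean), which for non-negative values rises toward the window maximum as the power p grows. A small NumPy illustration on a single made-up window (the function name is hypothetical, and non-negative inputs are assumed, as with the question's squared image):

```python
import numpy as np


def lehmer_window_max(window: np.ndarray, power: int) -> float:
    """Weighted mean of window with weights window**power, i.e. sum(x**(p+1)) / sum(x**p)."""
    x = np.asarray(window, dtype=np.float64)
    weights = x ** power
    return float((x * weights).sum() / weights.sum())


window = np.array([0.1, 0.5, 0.9, 1.0])  # made-up, non-negative window; its max is 1.0
for p in (0, 1, 4, 16, 64):
    # p = 0 is the plain mean; larger p pushes the result toward the max
    print(p, round(lehmer_window_max(window, p), 4))
```

The trade-off is that a larger power gives a closer approximation of the max but makes the weights span a wider dynamic range, which may be one reason the answer offers a float64 option.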

See code:

from typing import NewType, Optional, Union

import numpy as np
import tensorflow as tf

TensorImage = NewType('TensorImage', tf.Tensor)  # A (height, width, n_colors) uint8 image
TensorFloatImage = NewType('TensorFloatImage', tf.Tensor)
TensorHeatmap = NewType('TensorHeatmap', tf.Tensor)  # A (height, width) heatmap

def tf_box_filter(image: Union[TensorImage, TensorFloatImage, TensorHeatmap], width: int, normalize: bool = True, weights: Optional[TensorHeatmap] = None,
                  weight_eps: float = 1e-6, norm_weights: bool = True):
    image = tf.cast(image, tf.float32) if image.dtype != tf.float64 else image
    if weights is not None:
        if norm_weights:
            weights = weights/(width**2)
        if len(image.shape) == 3:
            weights = weights[:, :, None]  # Lets us broadcast weights against image

        image = image * weights

    lwidth = width // 2 + 1
    rwidth = width - lwidth

    integral_image_container = tf.pad(image,
                                      paddings=[(lwidth, rwidth), (lwidth, rwidth)] + [(0, 0)] * (len(image.shape) - 2))
    integral_image_container = tf.cumsum(tf.cumsum(integral_image_container, axis=0), axis=1)
    box_image = integral_image_container[width:, width:] \
                - integral_image_container[width:, :-width] \
                - integral_image_container[:-width, width:] \
                + integral_image_container[:-width, :-width]

    if not normalize:
        return box_image if (weights is None or not norm_weights) else box_image*(width**2)
    elif weights is None:
        return box_image / (width ** 2)
    else:
        box_weights = tf_box_filter(weights, width=width, normalize=False)
        return (box_image + weight_eps) / (box_weights + weight_eps)


def tf_poor_mans_dilate(heatmap: TensorHeatmap, width: int, power: int = 4, cast_to_64: bool = False) -> TensorHeatmap:
    """ A 'poor man's' version of dilation, whose runtime is O((image_height+kernel_width)*(image_width+kernel_width))"""
    if cast_to_64:
        heatmap = tf.cast(heatmap, tf.float64)
    return tf_box_filter(heatmap, width, weights=heatmap**power, weight_eps=1e-9)
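tf_box_filter relies on a summed-area table (integral image): after two cumulative sums, every window sum comes from four lookups, so the cost per pixel does not depend on width. The NumPy sketch below (hypothetical names; float64 is used to limit cumulative-sum rounding, which is an assumption rather than something the answer does) reproduces the same padding and four-corner arithmetic and checks it against a direct window sum:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def box_sum_integral(img: np.ndarray, width: int) -> np.ndarray:
    """Window sums via a summed-area table, with the same padding as tf_box_filter."""
    lwidth = width // 2 + 1
    rwidth = width - lwidth
    c = np.pad(np.asarray(img, dtype=np.float64), [(lwidth, rwidth), (lwidth, rwidth)])
    c = c.cumsum(axis=0).cumsum(axis=1)  # c[i, j] = sum of padded[:i+1, :j+1]
    # Four lookups per output pixel, regardless of width.
    return c[width:, width:] - c[width:, :-width] - c[:-width, width:] + c[:-width, :-width]


def box_sum_direct(img: np.ndarray, width: int) -> np.ndarray:
    """Reference: explicit sum over rows/cols i - width//2 .. i + (width - 1 - width//2)."""
    lo = width // 2
    hi = width - 1 - lo
    p = np.pad(np.asarray(img, dtype=np.float64), [(lo, hi), (lo, hi)])
    return sliding_window_view(p, (width, width)).sum(axis=(-2, -1))


rng = np.random.default_rng(2)
img = rng.random((30, 40))
assert np.allclose(box_sum_integral(img, 25), box_sum_direct(img, 25))
```

Dividing the window sum by width**2 gives the box mean that tf_box_filter returns when normalize=True and no weights are passed.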


In a test it comes out around 3x faster than tf_dilate from the question with a 40x40 kernel; with a 20x20 kernel the gain is smaller (about 1.5x), so the advantage grows with kernel size.


def test_poor_mans_dilate(show=False):
    """ Can be faster for large images and kernels

    Dilating image of shape (1280, 720) with kernel of shape 40x40
        Real Dilate: Elapsed time is 0.09009s
        Poor Man's Dilate: Elapsed time is 0.02953s

    Dilating image of shape (640, 480) with kernel of shape 40x40
        Real Dilate: Elapsed time is 0.03089s
        Poor Man's Dilate: Elapsed time is 0.008736s

    Dilating image of shape (640, 480) with kernel of shape 20x20
        Real Dilate: Elapsed time is 0.01475s
        Poor Man's Dilate: Elapsed time is 0.009809s
    """
    img = tf.random.Generator.from_seed(1234).normal((640, 480))**4
    width = 20
    print(f'Dilating image of shape {img.shape} with kernel of shape {width}x{width}')
    with profile_context('Real Dilate', print_result=True):
        dil_img = tf_dilate(img, width=width)
    with profile_context("Poor Man's Dilate", print_result=True):
        poor_dil_img = tf_poor_mans_dilate(img, width=width)

    assert np.allclose(dil_img.numpy().max(), poor_dil_img.numpy().max(), rtol=0.001)
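The test relies on a profile_context helper that the answer does not show. A minimal stand-in, assuming it only times the enclosed block and prints the elapsed time in the format seen in the docstring, could look like the sketch below (the name matches the call site, but the body is hypothetical). Since the first call to a TensorFlow op can include one-time setup cost, running each method once before timing it may give numbers that are easier to compare.

```python
import time
from contextlib import contextmanager


@contextmanager
def profile_context(name: str, print_result: bool = True):
    """Hypothetical stand-in for the answer's unshown helper: times the enclosed block."""
    timing = {}
    start = time.perf_counter()
    try:
        yield timing
    finally:
        timing['elapsed'] = time.perf_counter() - start
        if print_result:
            print(f"    {name}: Elapsed time is {timing['elapsed']:.4g}s")


with profile_context('sleep 10 ms') as t:
    time.sleep(0.01)
```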
Answered by Simard on 7/7, 2022 at 5:42.
