Another approach is to threshold the heatmap, find connected components, and draw a box around each connected component.
An optional additional step is to downsample the thresholded mask before finding connected components. This speeds up the component search and merges nearby components (resulting in fewer bounding boxes).
Here's some TensorFlow code I used to do that:
from dataclasses import dataclass
from typing import NewType, Tuple

import tensorflow as tf
import tensorflow_addons as tfa
# A vector of indices. The NewType name matches the alias so that reprs and
# error messages refer to the right type.
TensorIndexVector = NewType('TensorIndexVector', tf.Tensor)
# An array of boxes, one row per box, with columns ordered
# (Left, Top, Right, Bottom) in pixel coordinates.
TensorLTRBBoxes = NewType('TensorLTRBBoxes', tf.Tensor)
@dataclass
class ConnectedComponentSegmenter(ITensorImageBoxer):
    """Create bounding boxes from a heatmap.

    The heatmap is thresholded relative to its mean, the resulting mask is
    downsampled, connected components are labelled on the downsampled mask,
    and one box is drawn around the salient pixels of each component.
    """

    # Heat must be this many times the mean heat to qualify as salient.
    times_mean_thresh: float = 100
    # Pad the box by this many pixels (may cause box edge to fall outside image).
    # NOTE(review): `pad` is not applied in `find_bounding_boxes` below --
    # confirm whether callers apply it or whether padding is missing here.
    pad: int = 10
    # We use downsampling to both reduce compute time and link nearby components.
    # Annotated so that it is a real dataclass field (settable through the
    # constructor); declared last so the positional order of the pre-existing
    # fields is unchanged.
    pre_pool_factor: int = 16

    def find_bounding_boxes(self, heatmap: TensorHeatmap) -> Tuple[TensorIndexVector, TensorLTRBBoxes]:
        """Generate bounding boxes from a heatmap.

        Args:
            heatmap: The heat image. It is indexed as a 2-D (row, column)
                array here and must have a statically known shape.

        Returns:
            A ``(box_ids, boxes)`` pair. ``boxes`` has one row per connected
            component with columns (left, top, right, bottom), each the
            inclusive min/max pixel index of that component's salient pixels.
            ``box_ids`` is ``0 .. num_boxes - 1``.
        """
        # 1 where the heat exceeds `times_mean_thresh` times the mean, else 0.
        salient_mask = tf.cast(heatmap / tf.reduce_mean(heatmap) > self.times_mean_thresh, tf.int32)

        # Label components on a coarser grid.
        # NOTE(review): `tf_max_downsample` is defined elsewhere; presumably it
        # max-pools by `factor` -- confirm.
        salient_mask_shrunk = tf_max_downsample(salient_mask, factor=self.pre_pool_factor)
        component_label_image_shrunk = tfa.image.connected_components(salient_mask_shrunk)

        # Bring the labels back to full resolution, then keep them only on
        # pixels that were salient at full resolution so that the boxes hug
        # the original salient pixels rather than the coarse blocks.
        labels_full_res = tf.image.resize(
            component_label_image_shrunk[:, :, None],
            size=(heatmap.shape[0], heatmap.shape[1]),
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        )[:, :, 0]
        component_label_image = labels_full_res * salient_mask

        # (row, col) coordinates of every labelled pixel and the label of each.
        nonzero_ij = tf.where(component_label_image)
        component_indices = tf.gather_nd(component_label_image, nonzero_ij)

        # Labels run from 0 to the largest label, so there are max + 1 segments.
        num_segments = tf.reduce_max(component_label_image_shrunk) + 1

        # Per-label extrema: column index gives left/right, row index gives top/bottom.
        box_lefts = tf.math.unsorted_segment_min(nonzero_ij[:, 1], component_indices, num_segments=num_segments)
        box_tops = tf.math.unsorted_segment_min(nonzero_ij[:, 0], component_indices, num_segments=num_segments)
        box_rights = tf.math.unsorted_segment_max(nonzero_ij[:, 1], component_indices, num_segments=num_segments)
        box_bottoms = tf.math.unsorted_segment_max(nonzero_ij[:, 0], component_indices, num_segments=num_segments)

        # Segment 0 never receives a pixel (`tf.where` only returns non-zero
        # labels), so it is dropped with `[1:]`.
        boxes = tf.concat(
            [box_lefts[1:, None], box_tops[1:, None], box_rights[1:, None], box_bottoms[1:, None]],
            axis=1,
        )

        # `tf.shape` instead of `len()`: the number of boxes is data-dependent,
        # so under `tf.function` the static leading dimension is unknown and
        # `len(boxes)` would raise.
        return tf.range(tf.shape(boxes)[0]), boxes