I'm working on implementing a semantic segmentation network in TensorFlow, and I'm trying to figure out how to write out summary images of the labels during training. I want to encode the images in a similar style to the class segmentation annotations used in the Pascal VOC dataset.
For example, let's assume I have a network that trains on a batch size of 1 with 4 classes. The network's final predictions have shape [1, 3, 3, 4].
Essentially, I want to take the output predictions and run them through argmax over the class dimension to get a tensor containing the most likely class at each point in the output:
[[[0, 1, 3],
  [2, 0, 1],
  [3, 1, 2]]]
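To make the shapes concrete, here is a minimal NumPy sketch of that step. The `predictions` array is made up for illustration (a one-hot encoding of the label map above), not real network output:

```python
import numpy as np

# Label map from the example above, shape [1, 3, 3].
expected = np.array([[[0, 1, 3],
                      [2, 0, 1],
                      [3, 1, 2]]])

# Hypothetical predictions of shape [1, 3, 3, 4]: one score per class,
# built here by one-hot encoding the labels so the result is known.
predictions = np.eye(4, dtype=np.float32)[expected]

# Pick the highest-scoring class along the last (class) axis.
labels = np.argmax(predictions, axis=-1)
```

The result has shape [1, 3, 3], with the class dimension reduced away.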
The annotated images use a color palette of 256 colors to encode labels. I have a tensor containing all the color triples:
[[  0,   0,   0],
 [128,   0,   0],
 [  0, 128,   0],
 [128, 128,   0],
 [  0,   0, 128],
 ...
 [224, 224, 192]]
How could I obtain a tensor of shape [1, 3, 3, 3] (a single 3x3 color image) that indexes into the color palette using the values obtained from argmax?
[[palette[0], palette[1], palette[3]],
 [palette[2], palette[0], palette[1]],
 [palette[3], palette[1], palette[2]]]
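In NumPy terms, the lookup I'm after is plain integer-array indexing. A minimal sketch, assuming only the first five palette rows listed above (the real palette has more entries):

```python
import numpy as np

# First five palette entries from the listing above; truncated for illustration.
palette = np.array([[  0,   0,   0],
                    [128,   0,   0],
                    [  0, 128,   0],
                    [128, 128,   0],
                    [  0,   0, 128]], dtype=np.uint8)

# Class indices of shape [1, 3, 3], as in the example above.
labels = np.array([[[0, 1, 3],
                    [2, 0, 1],
                    [3, 1, 2]]])

# Indexing with an integer array replaces each label with its RGB triple,
# appending a channel dimension: [1, 3, 3] -> [1, 3, 3, 3].
image = palette[labels]
```

Each position in `image` then holds the palette row for the class at that position, which is the behavior I'd like to reproduce with TensorFlow ops.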
I could easily wrap some numpy and PIL code in tf.py_func, but I'm wondering if there is a pure TensorFlow way of obtaining this result.
EDIT:
For those curious, this is the solution I got using just numpy. It works quite well, but I still dislike the use of tf.py_func:
import numpy as np
import tensorflow as tf

def voc_colormap(N=256):
    bitget = lambda val, idx: ((val & (1 << idx)) != 0)

    cmap = np.zeros((N, 3), dtype=np.uint8)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r |= (bitget(c, 0) << 7 - j)
            g |= (bitget(c, 1) << 7 - j)
            b |= (bitget(c, 2) << 7 - j)
            c >>= 3

        cmap[i, :] = [r, g, b]
    return cmap

VOC_COLORMAP = voc_colormap()

def grayscale_to_voc(input, name="grayscale_to_voc"):
    return tf.py_func(grayscale_to_voc_impl, [input], tf.uint8, stateful=False, name=name)

def grayscale_to_voc_impl(input):
    return np.squeeze(VOC_COLORMAP[input])