Pixel-wise loss weight for image segmentation in Keras

I am currently using a modified version of the U-Net (https://arxiv.org/pdf/1505.04597.pdf) to segment cell organelles in microscopy images. Since I am using Keras, I took the code from https://github.com/zhixuhao/unet. However, in this version no weight map is implemented to force the network to learn the border pixels.

The results that I have obtained so far are quite good, but the network fails to separate objects that are close to each other. So I want to try to make use of the weight map mentioned in the paper. I have been able to generate the weight map (based on the given formula) for each label image, but I could not find out how to use this weight map to train my network and thus solve the problem described above.
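
For reference, the weight map defined in the U-Net paper adds a border term to a class-balancing map w_c, where d_1 and d_2 are the distances to the border of the nearest and the second-nearest cell (the paper uses w_0 = 10 and sigma of about 5 pixels). The map then scales the pixel-wise cross-entropy, where p_l(x)(x) is the softmax probability of the true label l(x) at pixel x. The loss is written here with a leading minus sign so that it is minimised. In other words, the weight map does not change the labels; it only changes how much each pixel contributes to the loss.

```latex
w(\mathbf{x}) = w_c(\mathbf{x}) + w_0 \cdot \exp\!\left(-\frac{\bigl(d_1(\mathbf{x}) + d_2(\mathbf{x})\bigr)^2}{2\sigma^2}\right)

E = -\sum_{\mathbf{x} \in \Omega} w(\mathbf{x}) \, \log p_{\ell(\mathbf{x})}(\mathbf{x})
```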

Do weight maps and label images have to be combined somehow, or is there a Keras function that will allow me to make use of the weight maps? I am a biologist who only recently started to work with neural networks, so my understanding is still limited. Any help or advice would be greatly appreciated.

Jeremyjerez asked 9/5, 2018 at 14:5

In case it is still relevant: I needed to solve this recently. You can paste the code below into a Jupyter notebook to see how it works.

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from skimage.measure import label
from scipy.ndimage import distance_transform_edt

def generate_random_circles(n=100, d=256):
    """Return a d x d binary mask with n randomly placed, possibly overlapping discs."""
    circles = np.random.randint(0, d, (n, 3))
    mask = np.zeros((d, d), dtype=int)
    for x0, y0, r in circles:
        # Pixel (i, j) is inside the disc if its distance to (x0, y0) is <= r/d*10.
        inside = lambda i, j: ((i - x0)**2 + (j - y0)**2) <= (r/d*10)**2
        mask += np.fromfunction(inside, mask.shape)
    mask = np.clip(mask, 0, 1)

    return mask

def unet_weight_map(y, wc=None, w0=10, sigma=5):

    """
    Generate weight maps as specified in the U-Net paper
    for boolean mask.

    "U-Net: Convolutional Networks for Biomedical Image Segmentation"
    https://arxiv.org/pdf/1505.04597.pdf

    Parameters
    ----------
    y: Numpy array
        2D array of shape (image_height, image_width) representing binary mask
        of objects.
    wc: dict
        Dictionary of class weights (pixel value in y -> weight).
    w0: int
        Border weight parameter.
    sigma: int
        Border width parameter.

    Returns
    -------
    Numpy array
        Training weights. A 2D array of shape (image_height, image_width).
        The border term is only added to background pixels.
    """

    labels = label(y)                           # give each connected object its own id
    no_labels = labels == 0                     # background pixels
    label_ids = np.unique(labels[labels > 0])   # object ids, background excluded

    if len(label_ids) > 1:
        distances = np.zeros((y.shape[0], y.shape[1], len(label_ids)))

        for i, label_id in enumerate(label_ids):
            # Distance from every pixel to the nearest pixel of object label_id.
            distances[:,:,i] = distance_transform_edt(labels != label_id)

        distances = np.sort(distances, axis=2)
        d1 = distances[:,:,0]                   # distance to nearest object
        d2 = distances[:,:,1]                   # distance to second-nearest object
        w = w0 * np.exp(-1/2*((d1 + d2) / sigma)**2) * no_labels
    else:
        # Fewer than two objects: there are no separating borders to emphasise.
        w = np.zeros_like(y, dtype=float)
    if wc:
        # Float dtype so that non-integer class weights are not truncated.
        class_weights = np.zeros_like(y, dtype=float)
        for k, v in wc.items():
            class_weights[y == k] = v
        w = w + class_weights
    return w

y = generate_random_circles()

wc = {
    0: 1, # background
    1: 5  # objects
}

w = unet_weight_map(y, wc)

plt.imshow(w)
plt.colorbar()
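
The code above builds the weight map but does not show how to use it during training. One common approach (an assumption on my part, not part of this answer) is a custom loss that multiplies the per-pixel cross-entropy by the weight map. Because a standard Keras loss only receives y_true and y_pred, the weight map is often packed into y_true as an extra channel and split off again inside the loss. The sketch below is a NumPy reference version of that idea; the function names are hypothetical, and a Keras version would need the same arithmetic written with TensorFlow ops.

```python
import numpy as np

def stack_labels_and_weights(y, w):
    """Hypothetical helper: pack a binary mask y and its weight map w
    (both of shape (H, W)) into one (H, W, 2) target array.
    Channel 0 holds the mask, channel 1 the per-pixel weight."""
    return np.stack([y.astype(np.float32), w.astype(np.float32)], axis=-1)

def weighted_binary_crossentropy(y_true_packed, y_pred, eps=1e-7):
    """Hypothetical pixel-wise weighted binary cross-entropy (NumPy version).

    y_true_packed: (..., H, W, 2) array from stack_labels_and_weights.
    y_pred: (..., H, W, 1) foreground probabilities, e.g. a sigmoid output.
    """
    y = y_true_packed[..., 0]
    w = y_true_packed[..., 1]
    # Clip predictions to avoid log(0) for saturated outputs.
    p = np.clip(y_pred[..., 0], eps, 1 - eps)
    # Ordinary binary cross-entropy at every pixel ...
    ce = -(y * np.log(p) + (1 - y) * np.log(1 - p))
    # ... scaled by the weight map, then averaged over all pixels.
    return float(np.mean(w * ce))
```

For a batch of masks, the packed targets could be built with np.stack([stack_labels_and_weights(y, unet_weight_map(y, wc)) for y in masks]). Built-in metrics such as accuracy would then also see the extra weight channel, so they may need custom versions as well.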
Corrigan answered 6/11, 2018 at 21:5

I think you want to use class_weight in Keras. This is actually simple to introduce in your model if you have already calculated the class weights.

  1. Create a dictionary with your class labels and their associated weights. For example

    class_weight = {0: 10.9,
            1: 20.8,
            2: 1.0,
            3: 50.5}
    
  2. Or create a 1D Numpy array of the same length as your number of classes. For example

    class_weight = [10.9, 20.8, 1.0, 50.5]
    
  3. Pass this parameter during training in your model.fit or model.fit_generator

    model.fit(x, y, batch_size=batch_size, epochs=num_epochs, verbose=1, class_weight=class_weight)
    

See the Keras documentation for model.fit for more details.
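
The answer assumes the class weights have already been calculated. A sketch of one common heuristic for computing them from pixel counts is shown below: inverse class frequency, the same idea as scikit-learn's "balanced" mode. This formula is my assumption; neither the answer nor the paper's w_c definition is reproduced exactly here. The second helper expands the per-class dict into a per-pixel map, which corresponds to the w_c(x) term of the U-Net weight map. Note that, as far as I know, tf.keras in TensorFlow 2.x rejects class_weight for targets with more than two dimensions, so for segmentation such a per-pixel map is usually applied via sample_weight or a custom loss instead. Both function names are hypothetical.

```python
import numpy as np

def inverse_frequency_class_weights(label_images, n_classes):
    """Hypothetical helper: weight each class inversely to its pixel count.

    label_images: integer array of any shape (e.g. (N, H, W)) with values
    0..n_classes-1. Returns {class_index: weight}, the same dict format as
    the class_weight example above. Weight = total / (n_classes * count), so a
    perfectly balanced dataset gives 1.0 for every class.
    """
    counts = np.bincount(np.asarray(label_images).ravel(), minlength=n_classes)
    total = counts.sum()
    weights = {}
    for c, n in enumerate(counts[:n_classes]):
        # Classes that never occur get weight 0 instead of dividing by zero.
        weights[c] = float(total / (n_classes * n)) if n > 0 else 0.0
    return weights

def class_weight_map(labels, class_weights):
    """Hypothetical helper: turn the per-class dict into a per-pixel map
    with the same shape as labels (the w_c term of the U-Net weight map)."""
    lookup = np.zeros(max(class_weights) + 1, dtype=float)
    for c, weight in class_weights.items():
        lookup[c] = weight
    # Fancy indexing replaces every label value by its class weight.
    return lookup[labels]
```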

Gillman answered 9/5, 2018 at 14:47
Thanks for your reply. What exactly does the term "class" refer to in this case? Is it the number of labels within my images? The weight map is a distance map with the same shape as the input image (a 2D array), in which the borders between cells have higher values so that the border pixels get more weight. I have nonetheless tried the approach you described above, but the test accuracy dropped to less than 4% (from 80% before). Any suggestions? - Jeremyjerez
Yes, "class" means the different labels into which you want to segment the image. I will have a look at the paper again and get back to you. - Gillman
The distance map you mentioned is for the separation border. The weighting by the frequency of pixels per label class (w_c in the paper) can be implemented with the method I described in the answer. - Gillman
