Weighting samples in multiclass image segmentation using Keras

I am using a U-Net based model to perform image segmentation on biomedical images. Each image is 224x224 and I have four classes, including the background class. Each mask is sized (224x224x4), so my generator creates batches of numpy arrays sized (16x224x224x4). I recast the mask values as either 1 or 0, so for each class a 1 is present in the relevant channel. The image is also scaled by 1/255. I use the Dice score as the performance metric during training and 1 - Dice score as the loss function.

I seem to get scores up to 0.89 during training, but when I predict on my test set I always predict the background class. I am only training for 10 epochs on a few hundred images (although I do have access to far more), which may be affecting the model, but I would have thought I'd still get predictions of other classes, so I am assuming the main problem is class imbalance. From looking online, the sample_weight argument could be the answer, but I am not sure how to implement the actual weighting: presumably I need to apply the weights to the array of pixels at some point in the model, perhaps using a layer, but I am not sure how. Any help would be much appreciated.

import os
import glob

import cv2
import numpy as np
import keras
from skimage import img_as_bool
from skimage.transform import resize


class DataGenerator(keras.utils.Sequence):
     def __init__(self, imgIds, maskIds, imagePath, maskPath, batchSize=16, imageSize=(224, 224, 3), nClasses=2, shuffle=False):
       self.imgIds = imgIds
       self.maskIds = maskIds
       self.imagePath = imagePath
       self.maskPath = maskPath
       self.batchSize = batchSize
       self.imageSize = imageSize
       self.nClasses = nClasses
       self.shuffle = shuffle


     def __load__(self, imgName, maskName):

       img = cv2.imread(os.path.join(self.imagePath,imgName))
       img = cv2.resize(img, (self.imageSize[0], self.imageSize[1]))

       mask = cv2.imread(os.path.join(self.maskPath, maskName))
       # append an empty 4th channel (raw masks are 4000x4000), then flag
       # background pixels (where channel 0 is 0) in that new channel
       mask = np.dstack((mask, np.zeros((4000, 4000))))
       mask[:,:,3][mask[:,:,0]==0] = 255
       mask = mask.astype(np.bool)
       mask = img_as_bool(resize(mask, (self.imageSize[0], self.imageSize[1])))
       mask = mask.astype('uint8')

       img = img/255.0

       return (img, mask)


     def __getitem__(self, index):

       # use a local batch size so a short final batch does not permanently
       # shrink self.batchSize for later batches
       batchSize = self.batchSize
       if (index + 1) * batchSize > len(self.imgIds):
           batchSize = len(self.imgIds) - index * self.batchSize

       batchImgs = self.imgIds[index * self.batchSize:index * self.batchSize + batchSize]
       batchMasks = self.maskIds[index * self.batchSize:index * self.batchSize + batchSize]

       batchfiles = [self.__load__(imgFile, maskFile) for imgFile, maskFile in zip(batchImgs, batchMasks)]

       images, masks = zip(*batchfiles)

       return np.array(list(images)), np.array(list(masks))


     def __len__(self):
       return int(np.ceil(len(self.imgIds)/self.batchSize))


class Unet():
   def __init__(self, imgSize):
       self.imgSize = imgSize


   def convBlocks(self, x, filters, kernelSize=(3,3), padding='same', strides=1):

       x = keras.layers.BatchNormalization()(x)
       x = keras.layers.Activation('relu')(x)
       x = keras.layers.Conv2D(filters, kernelSize, padding=padding, strides=strides)(x)

       return x


   def identity(self, x, xInput, f, padding='same', strides=1):

      skip = keras.layers.Conv2D(f, kernel_size=(1, 1), padding=padding, strides=strides)(xInput)
      skip = keras.layers.BatchNormalization()(skip)
      output = keras.layers.Add()([skip, x])

      return output


   def residualBlock(self, xIn, f, stride):

      res = self.convBlocks(xIn, f, strides=stride)
      res = self.convBlocks(res, f, strides=1)
      output = self.identity(res, xIn, f, strides=stride)

      return output


   def upSampling(self, x, xInput):

      x = keras.layers.UpSampling2D((2,2))(x)
      x = keras.layers.Concatenate()([x, xInput])

      return x


   def encoder(self, x, filters, kernelSize=(3,3), padding='same', strides=1):

      e1 = keras.layers.Conv2D(filters[0], kernelSize, padding=padding, strides=strides)(x)
      e1 = self.convBlocks(e1, filters[0])

      shortcut = keras.layers.Conv2D(filters[0], kernel_size=(1, 1), padding=padding, strides=strides)(x)
      shortcut = keras.layers.BatchNormalization()(shortcut)
      e1Output = keras.layers.Add()([e1, shortcut])

      e2 = self.residualBlock(e1Output, filters[1], stride=2)
      e3 = self.residualBlock(e2, filters[2], stride=2)
      e4 = self.residualBlock(e3, filters[3], stride=2)
      e5 = self.residualBlock(e4, filters[4], stride=2)

      return e1Output, e2, e3, e4, e5


   def bridge(self, x, filters):

      b1 = self.convBlocks(x, filters, strides=1)
      b2 = self.convBlocks(b1, filters, strides=1)

      return b2


   def decoder(self, b2, e1, e2, e3, e4, filters, kernelSize=(3,3), padding='same', strides=1):

      x = self.upSampling(b2, e4)
      d1 = self.convBlocks(x, filters[4])
      d1 = self.convBlocks(d1, filters[4])
      d1 = self.identity(d1, x, filters[4])

      x = self.upSampling(d1, e3)
      d2 = self.convBlocks(x, filters[3])
      d2 = self.convBlocks(d2, filters[3])
      d2 = self.identity(d2, x, filters[3])

      x = self.upSampling(d2, e2)
      d3 = self.convBlocks(x, filters[2])
      d3 = self.convBlocks(d3, filters[2])
      d3 = self.identity(d3, x, filters[2])

      x = self.upSampling(d3, e1)
      d4 = self.convBlocks(x, filters[1])
      d4 = self.convBlocks(d4, filters[1])
      d4 = self.identity(d4, x, filters[1])

      return d4 


   def ResUnet(self, filters=[16, 32, 64, 128, 256]):

      inputs = keras.layers.Input((self.imgSize, self.imgSize, 3))

      e1, e2, e3, e4, e5 = self.encoder(inputs, filters)
      b2 = self.bridge(e5, filters[4])
      d4 = self.decoder(b2, e1, e2, e3, e4, filters)

      x = keras.layers.Conv2D(4, (1, 1), padding='same', activation='softmax')(d4)
      model = keras.models.Model(inputs, x)

      return model


imagePath = 'output/t2'
maskPath = 'output/t1'

imgIds = glob.glob(os.path.join(imagePath, '*'))
maskIds = glob.glob(os.path.join(maskPath, '*'))

imgIds = [os.path.basename(f) for f in imgIds]
maskIds = [os.path.basename(f) for f in maskIds]

trainImgIds = imgIds[:300]
trainMaskIds = maskIds[:300]
validImgIds = imgIds[300:350]
validMaskIds = maskIds[300:350]

trainGenerator = DataGenerator(trainImgIds, trainMaskIds, imagePath, maskPath, **params)
validGenerator = DataGenerator(validImgIds, validMaskIds, imagePath, maskPath)

trainSteps = len(trainImgIds)//trainGenerator.batchSize
validSteps = len(validImgIds)//validGenerator.batchSize

unet = Unet(224)
model = unet.ResUnet()
model.summary()

adam = keras.optimizers.Adam()
model.compile(optimizer=adam, loss=dice_coef_loss, metrics=[dice_coef])

hist = model.fit_generator(trainGenerator, validation_data=validGenerator,
                           steps_per_epoch=trainSteps, validation_steps=validSteps,
                           verbose=1, epochs=6)
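
For reference, dice_coef and dice_coef_loss above are the Dice metric and 1 - Dice loss mentioned earlier; a representative soft-Dice implementation (a sketch, not necessarily the exact functions used) looks like this:

from keras import backend as K

# soft Dice over the flattened one-hot masks; smooth avoids division by zero
def dice_coef(y_true, y_pred, smooth=1.0):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred)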
Stancil answered 16/2, 2020 at 20:52 Comment(1)
The solution can be found here: Keras - how to use class_weight with 3D data #3653 (github.com/keras-team/keras/issues/3653). If you figure out how to actually implement this (the syntax), please put it in an answer here, as I didn't have time to figure it out and could use it too. (Blackbeard)

To follow up on this, I got it to work using sample_weight. It is quite nice if you know what you have to do. Unfortunately, the documentation is not really clear on this, presumably because this feature was originally added for time series data.

  • You need to reshape your 2D, image-sized output into a vector before the loss function when you specify your model.
  • Use sample_weight_mode="temporal" when you compile the model. This allows you to pass in a weight matrix for training where each row is the weight vector for a single (flattened) sample; a minimal sketch follows below.
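
A minimal sketch of these two steps, reusing names from the question (d4, inputs, the 224x224 size and four classes are assumptions carried over, not tested code):

    # Flatten the per-pixel softmax output so Keras' "temporal" sample
    # weighting (one weight per timestep, i.e. per pixel) can be applied.
    x = keras.layers.Conv2D(4, (1, 1), padding='same', activation='softmax')(d4)
    x = keras.layers.Reshape((224 * 224, 4))(x)
    model = keras.models.Model(inputs, x)

    # a per-pixel loss such as categorical_crossentropy is needed so the
    # per-pixel weights actually scale individual loss terms
    model.compile(optimizer='adam', loss='categorical_crossentropy',
                  sample_weight_mode='temporal')

    # The masks must be flattened the same way, and each batch needs a weight
    # matrix of shape (batchSize, 224 * 224), e.g. one per-class weight looked
    # up for every pixel:
    #   maskFlat = masks.reshape((-1, 224 * 224, 4))
    #   weights  = classWeights[np.argmax(maskFlat, axis=-1)]   # classWeights: shape (4,)
    # The generator then returns (images, maskFlat, weights) for each batch.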

I hope that helps.

Sandi answered 24/2, 2020 at 11:52 Comment(1)
Where would I pass in the weight matrix? And does the weight vector match the shape of our 2D image-sized output, so that each element in the row contains a weight belonging to the corresponding pixel's class? (Stancil)

I am using Keras as well. In my view it is not the sample weights in particular that you need. First, you would be better off converting to grayscale images, and you need to redesign your problem architecture, like this:

Build two models:

1. A segmentation model, regardless of the class type, that detects and segments image pixels or regions of interest (ROI); you can extract them as patches.
Let's say your ROI are the pixels of value 1 (positives); then most probably the background pixels of value 0 (negatives) are the dominant class, hence the data are unbalanced, so you need a loss function that penalizes false negatives more than false positives, something like balanced_cross_entropy:

import tensorflow as tf

def balanced_cross_entropy(beta):
  def convert_to_logits(y_pred):
      # clip to keep predictions strictly inside (0, 1) before converting to logits
      y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())

      return tf.log(y_pred / (1 - y_pred))  # tf.math.log in TF 2.x

  def loss(y_true, y_pred):
    y_pred = convert_to_logits(y_pred)
    pos_weight = beta / (1 - beta)
    loss = tf.nn.weighted_cross_entropy_with_logits(logits=y_pred, targets=y_true, pos_weight=pos_weight)  # targets= is labels= in TF 2.x

    # or reduce_sum and/or axis=-1
    return tf.reduce_mean(loss * (1 - beta))

  return loss

Then in your model, use 20% weight for negative pixels and 80% for positive ones, or adjust it as you see fit.

model.compile(optimizer=Adam(), loss=balanced_cross_entropy(0.2), metrics=["accuracy"])
  2. A classifier model which processes the ROIs detected (or patches extracted) by the autoencoder model and detects the class among the (now) 3 classes, after training on labelled patches.

  3. For the first part you can optionally add a threshold module.

  4. The sample weights will be useful in the classifier model if some of your classes are under-represented. Let's say your class 3 (index 2) is rare; then you assign more weight to the images of class 3, or you can use class_weight:

     class_weights = {0: 0.1, 1: 0.1, 2: 0.8}
     model.fit_generator(train_gen, class_weight=class_weights)

     You may also use data augmentation techniques.

  5. To load a saved model with the customized loss function, use custom_objects:

     model = load_model(filePath, 
              custom_objects={'loss': balanced_cross_entropy(0.2)})
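
As a further illustration of point 4, per-sample weights can also be computed directly from integer class labels and passed to fit; the names below (xTrain, yTrain, yTrainOneHot) are placeholders, not taken from the thread:

     import numpy as np

     # one weight per class, indexed by the integer label of each sample
     classWeights = np.array([0.1, 0.1, 0.8])
     sampleWeights = classWeights[yTrain]        # yTrain: integer labels, shape (N,)

     model.fit(xTrain, yTrainOneHot, sample_weight=sampleWeights,
               batch_size=16, epochs=10)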
    
Avon answered 25/2, 2020 at 23:34 Comment(2)
I find this slightly confusing. My problem is multi-class, and I should be able to train the model to find multiple classes while using a weighted loss function to offset the imbalance across all of the classes. Where does the autoencoder come into this; is it to build the masks? What is a threshold module? And the classification model I use at the end needs to be fully convolutional, and I thought that for segmentation problems we can't use the class_weight parameter? (Stancil)
The auto-encoder is the U-Net; use it to segment the area of interest, and use a weighted loss function, otherwise your model will classify the pixels as background. Then you take the segments and pass them to another deep convolutional NN that classifies them into one or more of the classes. (Avon)
