I am using a U-Net based model to perform image segmentation on biomedical images. Each image is 224x224 and I have four classes, including the background class. Each mask is sized (224, 224, 4), so my generator creates batches of NumPy arrays sized (16, 224, 224, 4). I recast the mask values as either 1 or 0, so for each class a 1 is present in the relevant channel. The images are also scaled by 1/255. I use the dice score as the performance metric during training and (1 - dice score) as the loss function.

I get scores up to 0.89 during training, but when I predict on my test set the model always predicts the background class. I'm only training for 10 epochs on a few hundred images (although I do have access to far more), which may be affecting the model, but I would have expected at least some predictions of the other classes, so I'm assuming the main problem is class imbalance.

From looking online, the sample_weight argument could be the answer, but I'm not sure how to implement the actual weighting: presumably I need to apply the weights to the array of pixels at some point in the model, perhaps using a layer, but I'm not sure how. Any help would be much appreciated.
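For reference, the dice_coef and dice_coef_loss I pass to model.compile below are the usual soft-dice pair, roughly like this (trimmed to the essentials; the smoothing constant here is illustrative and may differ from my exact version):

from keras import backend as K

def dice_coef(y_true, y_pred, smooth=1.0):
    # Soft dice computed over all pixels and all four channels
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    # The loss described above: 1 - dice score
    return 1.0 - dice_coef(y_true, y_pred)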
import os
import glob
import cv2
import numpy as np
import keras
from skimage import img_as_bool
from skimage.transform import resize

class DataGenerator(keras.utils.Sequence):
    def __init__(self, imgIds, maskIds, imagePath, maskPath, batchSize=16,
                 imageSize=(224, 224, 3), nClasses=2, shuffle=False):
        self.imgIds = imgIds
        self.maskIds = maskIds
        self.imagePath = imagePath
        self.maskPath = maskPath
        self.batchSize = batchSize
        self.imageSize = imageSize
        self.nClasses = nClasses
        self.shuffle = shuffle

    def __load__(self, imgName, maskName):
        # Load and resize the image, then scale to [0, 1]
        img = cv2.imread(os.path.join(self.imagePath, imgName))
        img = cv2.resize(img, (self.imageSize[0], self.imageSize[1]))
        img = img / 255.0
        # Load the 3-channel mask (the raw masks are 4000x4000) and stack on a
        # fourth channel, set to 255 wherever the first class channel is empty,
        # to act as the background class
        mask = cv2.imread(os.path.join(self.maskPath, maskName))
        mask = np.dstack((mask, np.zeros((4000, 4000))))
        mask[:, :, 3][mask[:, :, 0] == 0] = 255
        # Binarise, resize, then cast back so each channel holds 0s and 1s
        mask = mask.astype(bool)
        mask = img_as_bool(resize(mask, (self.imageSize[0], self.imageSize[1])))
        mask = mask.astype('uint8')
        return (img, mask)

    def __getitem__(self, index):
        # Slice this batch directly; shrinking self.batchSize in place for the
        # final partial batch would corrupt every epoch after the first
        start = index * self.batchSize
        end = min(start + self.batchSize, len(self.imgIds))
        batchImgs = self.imgIds[start:end]
        batchMasks = self.maskIds[start:end]
        batchFiles = [self.__load__(imgFile, maskFile)
                      for imgFile, maskFile in zip(batchImgs, batchMasks)]
        images, masks = zip(*batchFiles)
        return np.array(images), np.array(masks)

    def on_epoch_end(self):
        # Honour the shuffle flag (previously stored but never used),
        # keeping image/mask pairs aligned
        if self.shuffle:
            pairs = list(zip(self.imgIds, self.maskIds))
            np.random.shuffle(pairs)
            self.imgIds, self.maskIds = map(list, zip(*pairs))

    def __len__(self):
        return int(np.ceil(len(self.imgIds) / self.batchSize))
class Unet():
    def __init__(self, imgSize):
        self.imgSize = imgSize

    def convBlocks(self, x, filters, kernelSize=(3, 3), padding='same', strides=1):
        # Pre-activation conv block: BN -> ReLU -> Conv
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.Conv2D(filters, kernelSize, padding=padding, strides=strides)(x)
        return x

    def identity(self, x, xInput, f, padding='same', strides=1):
        # 1x1 projection of the block input so shapes match for the residual add
        skip = keras.layers.Conv2D(f, kernel_size=(1, 1), padding=padding, strides=strides)(xInput)
        skip = keras.layers.BatchNormalization()(skip)
        output = keras.layers.Add()([skip, x])
        return output

    def residualBlock(self, xIn, f, stride):
        res = self.convBlocks(xIn, f, strides=stride)
        res = self.convBlocks(res, f, strides=1)
        output = self.identity(res, xIn, f, strides=stride)
        return output

    def upSampling(self, x, xInput):
        # Upsample and concatenate with the matching encoder feature map
        x = keras.layers.UpSampling2D((2, 2))(x)
        x = keras.layers.Concatenate()([x, xInput])
        return x

    def encoder(self, x, filters, kernelSize=(3, 3), padding='same', strides=1):
        e1 = keras.layers.Conv2D(filters[0], kernelSize, padding=padding, strides=strides)(x)
        e1 = self.convBlocks(e1, filters[0])
        shortcut = keras.layers.Conv2D(filters[0], kernel_size=(1, 1), padding=padding, strides=strides)(x)
        shortcut = keras.layers.BatchNormalization()(shortcut)
        e1Output = keras.layers.Add()([e1, shortcut])
        e2 = self.residualBlock(e1Output, filters[1], stride=2)
        e3 = self.residualBlock(e2, filters[2], stride=2)
        e4 = self.residualBlock(e3, filters[3], stride=2)
        e5 = self.residualBlock(e4, filters[4], stride=2)
        return e1Output, e2, e3, e4, e5

    def bridge(self, x, filters):
        b1 = self.convBlocks(x, filters, strides=1)
        b2 = self.convBlocks(b1, filters, strides=1)
        return b2

    def decoder(self, b2, e1, e2, e3, e4, filters, kernelSize=(3, 3), padding='same', strides=1):
        x = self.upSampling(b2, e4)
        d1 = self.convBlocks(x, filters[4])
        d1 = self.convBlocks(d1, filters[4])
        d1 = self.identity(d1, x, filters[4])
        x = self.upSampling(d1, e3)
        d2 = self.convBlocks(x, filters[3])
        d2 = self.convBlocks(d2, filters[3])
        d2 = self.identity(d2, x, filters[3])
        x = self.upSampling(d2, e2)
        d3 = self.convBlocks(x, filters[2])
        d3 = self.convBlocks(d3, filters[2])
        d3 = self.identity(d3, x, filters[2])
        x = self.upSampling(d3, e1)
        d4 = self.convBlocks(x, filters[1])
        d4 = self.convBlocks(d4, filters[1])
        d4 = self.identity(d4, x, filters[1])
        return d4

    def ResUnet(self, filters=[16, 32, 64, 128, 256]):
        # Use the stored image size rather than hard-coding 224
        inputs = keras.layers.Input((self.imgSize, self.imgSize, 3))
        e1, e2, e3, e4, e5 = self.encoder(inputs, filters)
        b2 = self.bridge(e5, filters[4])
        d4 = self.decoder(b2, e1, e2, e3, e4, filters)
        # Four-channel softmax head: one probability map per class
        x = keras.layers.Conv2D(4, (1, 1), padding='same', activation='softmax')(d4)
        model = keras.models.Model(inputs, x)
        return model
imagePath = 'output/t2'
maskPath = 'output/t1'
# Sort so each image filename lines up with its mask filename
# (glob returns files in arbitrary order)
imgIds = sorted(glob.glob(os.path.join(imagePath, '*')))
maskIds = sorted(glob.glob(os.path.join(maskPath, '*')))
imgIds = [os.path.basename(f) for f in imgIds]
maskIds = [os.path.basename(f) for f in maskIds]

trainImgIds = imgIds[:300]
trainMaskIds = maskIds[:300]
validImgIds = imgIds[300:350]
validMaskIds = maskIds[300:350]
# Shared generator settings (batch of 16, four classes including background),
# passed to both generators so they behave identically
params = {'batchSize': 16, 'imageSize': (224, 224, 3), 'nClasses': 4}
trainGenerator = DataGenerator(trainImgIds, trainMaskIds, imagePath, maskPath, **params)
validGenerator = DataGenerator(validImgIds, validMaskIds, imagePath, maskPath, **params)
trainSteps = len(trainImgIds) // trainGenerator.batchSize
validSteps = len(validImgIds) // validGenerator.batchSize
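As a quick sanity check (not part of the training run itself), one batch from the generator comes out with the shapes described at the top:

imgs, masks = trainGenerator[0]
print(imgs.shape)   # (16, 224, 224, 3)
print(masks.shape)  # (16, 224, 224, 4)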
unet = Unet(224)
model = unet.ResUnet()
model.summary()
adam = keras.optimizers.Adam()
model.compile(optimizer=adam, loss=dice_coef_loss, metrics=[dice_coef])
hist = model.fit_generator(trainGenerator, validation_data=validGenerator,
                           steps_per_epoch=trainSteps, validation_steps=validSteps,
                           verbose=1, epochs=6)
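For completeness, this is roughly how I predict on the test set (the test generator set-up is omitted here; testGenerator is just another DataGenerator over my held-out IDs). Taking the per-pixel argmax is where I see only the background class:

preds = model.predict_generator(testGenerator)  # shape (nSamples, 224, 224, 4)
predClasses = np.argmax(preds, axis=-1)         # per-pixel class index
print(np.unique(predClasses))                   # in my case this only ever contains 3 (the background channel)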