It seems my implementation is incorrect, and I'm not sure what exactly I'm doing wrong:
Here is the histogram of my image:
So the threshold should be around 170-ish? I'm getting a threshold of 130.
Here is my code:
#Otsu in Python
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
def load_image(file_name):
    """Load an image as greyscale and build its 256-bin histogram.

    The histogram is drawn as a bar chart, written to ``histogram.png``
    and shown on screen.

    Args:
        file_name: Path of the image file to open with PIL.

    Returns:
        A tuple ``(BINS, counts, pixels, bw_data, total_counts)``:
        ``BINS`` is the array of 257 bin edges 0..256, ``counts`` the
        number of pixels in each bin, ``pixels`` the left bin edges
        0..255, ``bw_data`` the greyscale image as an int32 array and
        ``total_counts`` the sum of ``counts``.
    """
    # The context manager closes the underlying file once the greyscale
    # copy has been made.
    with Image.open(file_name) as img:
        bw = img.convert('L')
    bw_data = np.array(bw).astype('int32')

    BINS = np.arange(257)
    counts, pixels = np.histogram(bw_data, BINS)
    # np.histogram returns the bin edges; drop the final edge so that
    # ``pixels`` lines up one-to-one with ``counts``.
    pixels = pixels[:-1]

    plt.bar(pixels, counts, align='center')
    # Set the axis limits before saving so the saved figure matches the
    # one that is displayed.
    plt.xlim(-1, 256)
    plt.savefig('histogram.png')
    plt.show()

    total_counts = np.sum(counts)
    # Sanity check: every pixel must have landed in exactly one bin.
    assert total_counts == bw_data.shape[0] * bw_data.shape[1]
    return BINS, counts, pixels, bw_data, total_counts
def within_class_variance(hist_counts=None, hist_pixels=None):
    """Find the split of a histogram with the lowest within-class variance.

    Implements the exhaustive search of Otsu's method; see
    http://www.labbookpages.co.uk/software/imgProc/otsuThreshold.html
    for details.

    For every split index ``i`` the histogram is divided into class 1
    (bins ``[:i]``) and class 2 (bins ``[i:]``), and the weighted sum of
    the two class variances is used as the cost. Splits that leave either
    class empty are skipped: their mean and variance are undefined (0/0)
    and the resulting NaN would otherwise block every later comparison.

    Args:
        hist_counts: Number of pixels per bin. Defaults to the
            module-level ``counts``.
        hist_pixels: Grey level of each bin, same length as
            ``hist_counts``. Defaults to the module-level ``pixels``.

    Returns:
        A dict with keys ``'cost'``, ``'mean_1'``, ``'mean_2'``,
        ``'var_1'``, ``'var_2'`` and ``'pixel'``. ``'pixel'`` is the index
        of the last bin that belongs to class 1, so values greater than it
        belong to class 2. If only one bin is populated no split is
        possible; the result then has ``'pixel'`` set to that bin, a cost
        of 0.0 and NaN for the class 2 statistics.

    Raises:
        ValueError: If the histogram contains no pixels at all.
    """
    # Fall back to the module-level histogram so that existing
    # zero-argument callers keep working.
    if hist_counts is None:
        hist_counts = counts
    if hist_pixels is None:
        hist_pixels = pixels
    hist_counts = np.asarray(hist_counts)
    hist_pixels = np.asarray(hist_pixels)

    total = np.sum(hist_counts)
    if total == 0:
        raise ValueError('histogram is empty')

    keys = None
    for i in range(1, len(hist_counts)):
        weight_1 = np.sum(hist_counts[:i])
        weight_2 = np.sum(hist_counts[i:])
        if weight_1 == 0 or weight_2 == 0:
            # One class is empty: nothing to compare for this split.
            continue

        prob_1 = weight_1 / total
        prob_2 = weight_2 / total
        mean_1 = np.sum(hist_counts[:i] * hist_pixels[:i]) / weight_1
        mean_2 = np.sum(hist_counts[i:] * hist_pixels[i:]) / weight_2
        var_1 = np.sum(((hist_pixels[:i] - mean_1) ** 2) * hist_counts[:i]) / weight_1
        var_2 = np.sum(((hist_pixels[i:] - mean_2) ** 2) * hist_counts[i:]) / weight_2
        cost = (prob_1 * var_1) + (prob_2 * var_2)

        if keys is None:
            print('first_cost', cost)
        # Strict '<' keeps the first of several equally good splits.
        if keys is None or cost < keys['cost']:
            keys = {'cost': cost, 'mean_1': mean_1, 'mean_2': mean_2,
                    'var_1': var_1, 'var_2': var_2,
                    'pixel': i - 1}  # class 1 is bins [:i], so its last bin is i-1

    if keys is None:
        # Only a single bin is populated, so every split had an empty class.
        level = int(np.flatnonzero(hist_counts)[0])
        keys = {'cost': 0.0, 'mean_1': float(hist_pixels[level]),
                'mean_2': float('nan'), 'var_1': 0.0,
                'var_2': float('nan'), 'pixel': level}
    return keys
if __name__ == "__main__":
    # Script entry point: compute the Otsu threshold of 'fish.jpg' and
    # save/show the binarised image.
    file_name = 'fish.jpg'
    BINS, counts, pixels, bw_data, total_counts = load_image(file_name)
    keys = within_class_variance()
    print(keys['pixel'])

    # Binarise with a single comparison against the untouched greyscale
    # data. Masking in place in two steps (set "> t" to 1, then "< t" to 0)
    # zeroes the freshly written 1s as well, because 1 is itself below the
    # threshold, leaving only pixels exactly equal to the threshold.
    otsu_img = (bw_data > keys['pixel']).astype('uint8')

    # Show the 0/1 mask in black and white rather than the default colormap.
    plt.imshow(otsu_img, cmap='gray')
    plt.savefig('otsu.png')
    plt.show()
The resulting Otsu image looks like this:
Here is the fish image (It has a shirtless guy holding a fish so may not be safe for work):
Link : https://i.sstatic.net/EDTem.jpg
EDIT:
It turns out that by changing the threshold to 255 (The differences are more pronounced)
`otsu_img[otsu_img > keys['pixel']]=1` and `otsu_img[otsu_img < keys['pixel']]=0`.
What you're doing here is setting all pixels above your threshold (let's say 130) to 1. Next, you're finding all pixels below 130, including those you just set to 1, and setting them to 0. What you've got left is all pixels with a value of exactly 130. The rest is 0. Also, you're doing this on a color image, meaning you are thresholding the three channels separately and re-composing it as an RGB image. Convert to a gray-value image first! – Glasses