Is this the correct way of whitening an image in python?

I am trying to zero-center and whiten the CIFAR-10 dataset, but the result I get looks like random noise!
CIFAR-10 contains 60,000 color images of size 32x32: 50,000 in the training set and 10,000 in the test set.
The following snippets of code show the steps I took to whiten the dataset:

# zero-center
mean = np.mean(data_train, axis = (0,2,3)) 
for i in range(data_train.shape[0]):
    for j in range(data_train.shape[1]):
        data_train[i,j,:,:] -= mean[j]

first_dim = data_train.shape[0] #50,000
second_dim = data_train.shape[1] * data_train.shape[2] * data_train.shape[3] # 3*32*32
shape = (first_dim, second_dim) # (50000, 3072) 

# compute the covariance matrix
cov = np.dot(data_train.reshape(shape).T, data_train.reshape(shape)) / data_train.shape[0] 
# compute the SVD factorization of the data covariance matrix
U,S,V = np.linalg.svd(cov)

print 'cov.shape = ',cov.shape
print U.shape, S.shape, V.shape

Xrot = np.dot(data_train.reshape(shape), U) # decorrelate the data
Xwhite = Xrot / np.sqrt(S + 1e-5)

print Xwhite.shape
data_whitened = Xwhite.reshape(-1,32,32,3)
print data_whitened.shape

outputs:

cov.shape =  (3072L, 3072L)
(3072L, 3072L) (3072L,) (3072L, 3072L)
(50000L, 3072L)
(50000L, 32L, 32L, 3L)
(32L, 32L, 3L)

and trying to show the resulting image :

import matplotlib.pyplot as plt
%matplotlib inline
from scipy.misc import imshow
print data_whitened[0].shape
fig = plt.figure()
plt.subplot(221)
plt.imshow(data_whitened[0])
plt.subplot(222)
plt.imshow(data_whitened[100])
plt.show()

[Image: the whitened samples, which look like random noise]

By the way, data_train[0].shape is (3,32,32), but if I reshape the whitened image to that shape, I get

TypeError: Invalid dimensions for image data

Could this be only a visualization issue? If so, how can I make sure that's the case?
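One way to tell a plotting problem apart from a whitening problem is to check the numbers directly: after PCA whitening, the covariance of the whitened data should be close to the identity matrix. The sketch below repeats the question's recipe on small synthetic data; the data, the sizes and the helper names pca_whiten and identity_error are assumptions made for illustration. Calling identity_error on the real Xwhite should work the same way, although forming a (3072, 3072) covariance takes some memory.

```python
import numpy as np

def pca_whiten(X, eps=1e-5):
    """PCA-whiten a zero-centred design matrix X of shape (N, D).

    Mirrors the question's steps: covariance, SVD, rotate, rescale.
    """
    cov = X.T @ X / X.shape[0]          # (D, D) covariance between features
    U, S, _ = np.linalg.svd(cov)        # U: eigenvectors, S: eigenvalues
    X_rot = X @ U                       # decorrelate the features
    return X_rot / np.sqrt(S + eps)     # give every direction (almost) unit variance

def identity_error(X_white):
    """Largest absolute deviation of the whitened covariance from the identity."""
    cov = X_white.T @ X_white / X_white.shape[0]
    return np.abs(cov - np.eye(cov.shape[0])).max()

# Hypothetical stand-in data: 2000 samples of 12 correlated features.
rng = np.random.default_rng(0)
C = 0.5 * np.ones((12, 12)) + 0.5 * np.eye(12)    # a known SPD covariance
X = rng.standard_normal((2000, 12)) @ np.linalg.cholesky(C).T
X -= X.mean(axis=0)                                # zero-centre every feature

X_white = pca_whiten(X)
err = identity_error(X_white)
```

A value of err near 0 means the whitening itself worked, and any remaining problem lies in the reshaping or plotting.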

Update:
Thanks to @AndrasDeak, I fixed the visualization code as follows, but the output still looks random:

data_whitened = Xwhite.reshape(-1,3,32,32).transpose(0,2,3,1)
print data_whitened.shape
fig = plt.figure()
plt.subplot(221)
plt.imshow(data_whitened[0])

[Image: the transposed result, which still looks random]

Update 2:
This is what I get when I run some of the commands given below. As can be seen, toimage shows the image just fine, but reshaping it messes up the image. [image omitted]

# output is of shape (N, 3, 32, 32)
X = X.reshape((-1,3,32,32))
# output is of shape (N, 32, 32, 3)
X = X.transpose(0,2,3,1)
# put data back into a design matrix (N, 3072)
X = X.reshape(-1, 3072)

plt.imshow(X[6].reshape(32,32,3))
plt.show()

[image omitted]

For some weird reason, this is what I got at first, but after several tries it changed to the previous image. [image omitted]

Anecdotage asked 13/1, 2017 at 13:28. Comments (14):
I'm not familiar with whitening, but yes, the error you're getting is due to the fact that plt.imshow expects an (M,N,3)-shaped array as an RGB image. But this problem goes deeper: I wouldn't expect your data_train to be shaped (N,3,32,32) either; it should contain a similar pattern of row-column-RGB_channel dimensions. This suggests that you may be misinterpreting the dimensions of your input, which could explain why your output is not what you expect it to be. - Athodyd
Oh, and unless I'm mistaken, the zero-centering you're doing is equivalent to the vectorized data_train -= np.mean(data_train, axis = (0,2,3))[:,None,None], making use of array broadcasting. - Athodyd
Last comment: I'd expect zero-centering to work image by image, centring each colour channel of each image. This would mean (in case the final 2 dimensions of data_train correspond to pixels) that you need np.mean(data_train,axis=(2,3)), and correspondingly data_train -= np.mean(data_train, axis = (2,3))[...,None,None]. Is that not right? - Athodyd
Maybe a stupid question, but can't you access the bytes in memory using ctypes and simply overwrite them with (255,255,255), assuming RGB? - Precast
@z0rberg's if I understand your suggestion: OP is trying to do whitening. - Athodyd
@AndrasDeak: Actually, the data_train shape is exactly (50000L, 3L, 32L, 32L). And yes, I deliberately tried to zero-center each channel to see how it affects the overall performance. Neither np.mean(data_train, axis=0) nor your suggestion made any difference to the final result. I also tried using scipy.misc's toimage function to display the result, but nothing interesting turned up; the result was the same! - Anecdotage
Well, just for visualization, you can avoid that error by using data_whitened[0,...].transpose(1,2,0). Then the RGB dimension is last, and imshow will happily plot it. - Athodyd
OK, I think I see (at least one) problem. You're only using reshapes in your code, yet you start from (3,32,32) and end up with (32,32,3). That's wrong: if you reshape your data rather than permuting the indices (with .transpose), you'll get your array elements all mixed up. I'm not sure if this is correct, but you might be looking for data_whitened = Xwhite.reshape(-1,3,32,32).permute(0,2,3,1). - Athodyd
Thanks, I see, you are right there, but that snippet gives the error "AttributeError: 'numpy.ndarray' object has no attribute 'permute'"! - Anecdotage
Sorry, MATLAB habits :) When I wrote .permute, I meant .transpose. Apologies. - Athodyd
Thanks, I just used it, but I still get the same image. I updated the question with the image. - Anecdotage
Well, at least it's not the exact same image, as far as I can see. Unfortunately, I'm unfamiliar with whitening, so I can't help you with the core of your code. Have you tried plotting some of your training data similarly, to see what it looks like? I vaguely remember somewhat weird behaviour of plt.imshow depending on the type of the input data. Are you working with unsigned integers or floats? - Athodyd
@AndrasDeak, I understand, and I really appreciate your time and kindness so far :) God bless you. I'll look into it and see if I can understand what is wrong with this code :) The images are floats, as far as I can see: data_train is [[[ -71.71074 -87.14036 -81.05044 ...,]]] and data_whitened[0] is [[[ 0.86028489, -0.85494366, 0.8545953 ],... - Anecdotage
@AndrasDeak oh. I'm so dumb... - Precast
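Two points raised in these comments, that the question's double loop is equivalent to a broadcast subtraction and that a channels-first image must be transposed rather than reshaped to become channels-last, can be checked on a small random array. This is a minimal sketch; the random data stands in for data_train and is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for data_train: 4 channels-first images, (N, 3, 32, 32).
data = rng.random((4, 3, 32, 32))

# 1) The question's double loop for per-channel zero-centring ...
looped = data.copy()
mean = looped.mean(axis=(0, 2, 3))             # shape (3,): one mean per channel
for i in range(looped.shape[0]):
    for j in range(looped.shape[1]):
        looped[i, j, :, :] -= mean[j]

# ... matches the broadcast one-liner suggested in the comments.
broadcast = data - data.mean(axis=(0, 2, 3))[:, None, None]
same_as_loop = np.allclose(looped, broadcast)

# 2) Channels-first -> channels-last needs a transpose, not a reshape.
img = data[0]                                  # (3, 32, 32)
right = img.transpose(1, 2, 0)                 # (32, 32, 3), pixel (5, 7) intact
wrong = img.reshape(32, 32, 3)                 # same shape, but values scrambled
right_ok = np.array_equal(right[5, 7], img[:, 5, 7])
wrong_ok = np.array_equal(wrong[5, 7], img[:, 5, 7])
```

Both arrays have shape (32, 32, 3), which is why imshow accepts either one; only the transposed version keeps the three colour values of each pixel together.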

Let's walk through this. As you point out, CIFAR contains images which are stored in a matrix; each image is a row, and each row has 3072 columns of uint8 numbers (0-255). Images are 32x32 pixels and pixels are RGB (three-channel colour).

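# https://www.cs.toronto.edu/~kriz/cifar.html
# wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
# tar xf cifar-10-python.tar.gz
import pickle
import numpy as np
# The batches were pickled with Python 2; in Python 3, open the file in binary
# mode and load with encoding='bytes' (the dictionary keys are then bytes).
with open('cifar-10-batches-py/data_batch_1', 'rb') as input_file:
    batch = pickle.load(input_file, encoding='bytes')
X = batch[b'data']   # shape is (N, 3072), dtype uint8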

It turns out that the columns are ordered a bit funny: all the red pixel values come first, then all the green pixels, then all the blue pixels. This makes it tricky to have a look at the images. This:

import matplotlib.pyplot as plt
plt.imshow(X[6].reshape(32,32,3))
plt.show()

gives this:

[Image: mixed-up colour channels]

So, just for ease of viewing, let's shuffle the dimensions of our matrix around with reshape and transpose:

# output is of shape (N, 3, 32, 32)
X = X.reshape((-1,3,32,32))
# output is of shape (N, 32, 32, 3)
X = X.transpose(0,2,3,1)
# put data back into a design matrix (N, 3072)
X = X.reshape(-1, 3072)

Now:

plt.imshow(X[6].reshape(32,32,3))
plt.show()

gives:

[Image: a peacock]
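Because the reshape and transpose steps above are easy to apply twice by accident, it can help to wrap them in a pair of named functions, one the inverse of the other, and leave the raw array untouched. The helper names below are hypothetical, and the random uint8 rows stand in for a real CIFAR-10 batch.

```python
import numpy as np

def cifar_rows_to_hwc(X):
    """(N, 3072) rows in CIFAR order (all R, then G, then B) -> (N, 32, 32, 3)."""
    return X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

def hwc_to_cifar_rows(images):
    """Inverse of cifar_rows_to_hwc: (N, 32, 32, 3) -> (N, 3072) in CIFAR order."""
    return images.transpose(0, 3, 1, 2).reshape(-1, 3072)

# Hypothetical stand-in for the uint8 rows loaded from data_batch_1.
rng = np.random.default_rng(0)
X_raw = rng.integers(0, 256, size=(5, 3072), dtype=np.uint8)

images = cifar_rows_to_hwc(X_raw)      # e.g. plt.imshow(images[i])
round_trip = hwc_to_cifar_rows(images)  # back to the original column order
```

Since X_raw itself is never reassigned, re-running a notebook cell that calls these helpers cannot shuffle the data a second time.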

OK, on to ZCA whitening. We're frequently reminded that it's super important to zero-center the data before whitening it. At this point, an observation about the code you include. From what I can tell, computer vision views colour channels as just another feature dimension; there's nothing special about the separate RGB values in an image, just like there's nothing special about the separate pixel values. They're all just numeric features. So, whereas you're computing the average pixel value, respecting colour channels (i.e., your mean is a tuple of r,g,b values), we'll just compute the mean image, i.e., a separate mean for each of the 3072 columns. Note that X is a big matrix with N rows and 3072 columns. We'll treat every column as being "the same kind of thing" as every other column.

# zero-centre the data (this calculates a separate mean for every
# pixel and colour channel, i.e. for every column)
X = X - X.mean(axis=0)
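To make the difference concrete: the line above subtracts one mean per column (a "mean image"), whereas the question subtracts one mean per colour channel. A minimal sketch on random stand-in data (an assumption, shaped like the channels-last design matrix built above) shows what each choice guarantees.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for the (N, 3072) channels-last design matrix.
X = rng.random((100, 3072)) * 255

# Per-feature centring (this answer): one mean per column, i.e. a mean image.
X_feat = X - X.mean(axis=0)                                    # mean shape (3072,)

# Per-channel centring (the question): one mean per colour channel.
channel_mean = X.reshape(-1, 32, 32, 3).mean(axis=(0, 1, 2))   # shape (3,)
X_chan = (X.reshape(-1, 32, 32, 3) - channel_mean).reshape(-1, 3072)

# Every column of X_feat has mean ~0; X_chan only guarantees that each
# colour channel has mean ~0 overall, while individual columns need not.
max_col_mean_feat = np.abs(X_feat.mean(axis=0)).max()
max_col_mean_chan = np.abs(X_chan.mean(axis=0)).max()
chan_means_after = X_chan.reshape(-1, 32, 32, 3).mean(axis=(0, 1, 2))
```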

At this point, let's also do Global Contrast Normalization, which is quite often applied to image data. I'll use the L2 norm, which makes every image have vector magnitude 1:

X = X / np.sqrt((X ** 2).sum(axis=1))[:,None]

One could easily use something else, like dividing each feature by its standard deviation (X = X / np.std(X, axis=0)) or min-max scaling to some interval like [-1,1].
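The options mentioned here differ in whether they act per image (row) or per feature (column), and in whether they have a parameter that must be estimated on the training set. A sketch under the assumption that X is the centred (N, 3072) matrix; the helper names and the random data are hypothetical.

```python
import numpy as np

def l2_normalize_rows(X):
    """Scale every image (row) to unit L2 norm, as in the answer."""
    return X / np.sqrt((X ** 2).sum(axis=1))[:, None]

def std_normalize_features(X, std=None):
    """Divide every feature (column) by its standard deviation.

    Pass the training-set std when transforming test data.
    """
    if std is None:
        std = X.std(axis=0)
    return X / std, std

def minmax_rows(X, lo=-1.0, hi=1.0):
    """Min-max scale every image (row) into [lo, hi]."""
    mn = X.min(axis=1, keepdims=True)
    mx = X.max(axis=1, keepdims=True)
    return lo + (X - mn) * (hi - lo) / (mx - mn)

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 3072))     # hypothetical centred stand-in data

X_l2 = l2_normalize_rows(X)
X_std, train_std = std_normalize_features(X)
X_mm = minmax_rows(X)
```

The L2 and min-max versions need nothing from the training set, whereas train_std has to be kept and reused on the test data.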

Nearly there. At this point, we haven't greatly modified our data, since we've only shifted and scaled it (an affine transform of each image). To display it, we need to get the image data back into the range [0,1], so let's use a helper function:

def show(i):
    i = i.reshape((32,32,3))
    m,M = i.min(), i.max()
    plt.imshow((i - m) / (M - m))
    plt.show()

show(X[6])

The peacock looks slightly brighter here, but that's just because we've stretched its pixel values to fill the interval [0,1]:

[Image: slightly brighter peacock]

ZCA whitening:

# compute the covariance of the image data
cov = np.cov(X, rowvar=True)   # cov is (N, N)
# singular value decomposition
U,S,V = np.linalg.svd(cov)     # U is (N, N), S is (N,)
# build the ZCA matrix
epsilon = 1e-5
zca_matrix = np.dot(U, np.dot(np.diag(1.0/np.sqrt(S + epsilon)), U.T))
# transform the image data (zca_matrix is (N, N))
zca = np.dot(zca_matrix, X)    # zca is (N, 3072)

Taking a look (show(zca[6])):

[Image: "whitened" peacock]

Now the peacock definitely looks different. You can see that the ZCA has rotated the image through colour space, so it looks like a picture on an old TV with the Tone setting out of whack. Still recognisable, though.

Presumably because of the epsilon value I used, the covariance of my transformed data isn't exactly the identity, but it's fairly close; at least, the largest entry in every row lies on the diagonal:

>>> (np.cov(zca, rowvar=True).argmax(axis=1) == np.arange(zca.shape[0])).all()
True
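The comment thread below this answer proposes a feature-wise variant (np.cov with rowvar=False, then multiplying X by the whitening matrix), and it also recommends saving the training mean and the whitening matrix so that the same transform can be applied to the test set. A sketch of that variant follows; the function names and the synthetic data are assumptions, and with real CIFAR-10 rows D would be 3072.

```python
import numpy as np

def fit_zca(X_train, eps=1e-5):
    """Estimate feature-wise ZCA parameters from training data of shape (N, D).

    Returns the per-feature mean and a symmetric (D, D) whitening matrix.
    """
    mean = X_train.mean(axis=0)
    Xc = X_train - mean
    cov = np.cov(Xc, rowvar=False)                 # (D, D): covariance between features
    U, S, _ = np.linalg.svd(cov)
    W = U @ np.diag(1.0 / np.sqrt(S + eps)) @ U.T  # rotate, rescale, rotate back
    return mean, W

def apply_zca(X, mean, W):
    """Whiten X of shape (M, D) with parameters estimated on the training set."""
    return (X - mean) @ W

# Hypothetical stand-in data: 12 correlated features with a non-zero mean.
rng = np.random.default_rng(1)
C = 0.5 * np.ones((12, 12)) + 0.5 * np.eye(12)
L = np.linalg.cholesky(C)
X_train = rng.standard_normal((3000, 12)) @ L.T + 5.0
X_test = rng.standard_normal((500, 12)) @ L.T + 5.0

mean, W = fit_zca(X_train)
Z_train = apply_zca(X_train, mean, W)
Z_test = apply_zca(X_test, mean, W)   # same mean and W, no refitting on test data
```

Here W is (D, D) no matter how many test images there are, and comparing np.cov(Z_train, rowvar=False) with np.eye(D) gives a stricter check than the argmax test above.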

Update 29 January

I'm not entirely sure how to sort out the issues you're having; your trouble seems to lie in the shape of your raw data at the moment, so I would advise you to sort that out first before you try to move on to zero-centring and ZCA.

On the one hand, the first of the four plots in your update looks good, suggesting that you've loaded the CIFAR data in the correct way. The second plot is produced by toimage, I think, which will automagically figure out which dimension has the colour data, which is a nice trick. On the other hand, the stuff that comes after that looks weird, so it seems something is going wrong somewhere. I confess I can't quite follow the state of your script, because I suspect you're working interactively (notebook), retrying things when they don't work (more on this in a second), and that you're using code that you haven't shown in your question. In particular, I'm not sure how you're loading the CIFAR data; your screenshot shows output from some print statements (Reading training data..., etc.), and then when you copy train_data into X and print the shape of X, the data has already been reshaped into (N, 3, 32, 32). Like I say, Update plot 1 would tend to suggest that the reshape has happened correctly. From plots 3 and 4, I think you're getting mixed up about matrix dimensions somewhere, so I'm not sure how you're doing the reshape and transpose.

Note that it's important to be careful with the reshape and transpose, for the following reason. The X = X.reshape(...) and X = X.transpose(...) lines overwrite X with the shuffled array. If you run them multiple times (for example, by accident in a Jupyter notebook), you will shuffle the axes of your matrix over and over, and plotting the data will start to look really weird. This image shows the progression as we iterate the reshape and transpose operations:

[Image: increasing iterations of reshape and transpose]

This progression does not cycle back, or at least, it doesn't cycle quickly. Because of periodic regularities in the data (like the 32-pixel row structure of the images), you tend to get banding in these improperly reshape-transposed images. I'm wondering if that's what's going on in the third of your four plots in your update, which looks a lot less random than the images in the original version of your question.
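Each accidental re-run applies the same fixed permutation to the 3072 columns, so the layout does eventually return to where it started; the question is how long that takes. A minimal sketch (the helper name shuffle_once is hypothetical) computes the period from the permutation's cycle lengths:

```python
import math
import numpy as np

def shuffle_once(X):
    """One (accidental) run of the answer's reshape/transpose/reshape cell."""
    return X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).reshape(-1, 3072)

# perm[j] is the original column that lands in position j after one pass.
perm = shuffle_once(np.arange(3072)[None, :])[0]

# The period is the least common multiple of the permutation's cycle lengths.
seen = np.zeros(3072, dtype=bool)
period = 1
for start in range(3072):
    length = 0
    j = start
    while not seen[j]:
        seen[j] = True
        j = perm[j]
        length += 1
    if length:
        period = period * length // math.gcd(period, length)
```

For the 3 x 32 x 32 layout this calculation gives 738 passes, which is consistent with the observation that the progression does not cycle quickly.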

The fourth plot of your update is a colour negative of the peacock. I'm not sure how you're getting that, but I can reproduce your output with:

plt.imshow(255 - X[6].reshape(32,32,3))
plt.show()

which gives:

[Image: colour negative of the peacock]

One way you could get this is if you were using my show helper function, and you mixed up m and M, like this:

def show(i):
    i = i.reshape((32,32,3))
    m,M = i.min(), i.max()
    plt.imshow((i - M) / (m - M))  # this will produce a negative img
    plt.show()
Vitta answered 27/1, 2017 at 12:46. Comments (12):
Thank you very much, but I get a completely different result from yours. I updated the question with 3 images which show how they look. - Anecdotage
@Anecdotage can you also include your toimage function in the question? What you're getting in In[59] looks OK; the image under In[60] looks wrong somehow, though (could it be a negative?). - Vitta
toimage is one of SciPy's functions; check this out: docs.scipy.org/doc/scipy-0.18.1/reference/generated/… Matplotlib fails to show the image and complains about the axis (I don't remember the exact error; that's why I used your function to show how it looks when used with matplotlib). In In[60] I'm only showing the same image; I have no idea what is happening there and why I'm not getting the correct output! - Anecdotage
@Anecdotage new section added to address some of the points in your update. - Vitta
@wildwilhelm: thank you very much, very well said. Yes, I had read CIFAR-10 in advance and converted it into the LMDB database format for use in Caffe (using C++, not Python!), and after your new update things make sense. I'm giving you all the points since you deserve them; if any (minor) issues come up, I'll simply comment here and discuss them with you, and if it's a major one, I'll ask a new question. Thank you very much for your thorough and well-explained answer :) - Anecdotage
@wildwilhelm: Do we also whiten the test set? If so, should we just repeat the process for the test set, or should I use the mean/std from the training data and then go on? - Anecdotage
Correct, you should apply the same transformations to the test set. This means saving the mean value of the training data and the zca_matrix. For contrast normalization, if you're doing that, and if you're using L2 like I did, you can just L2-normalize the test set too (i.e., there's no parameter to learn from the training data for L2 normalization). If you're using stdev normalization or something else, you should also estimate it on the training data and then apply the same estimate to the test data. - Vitta
Thanks. Are you sure about the zca_matrix? Its shape won't match the test data: the zca_matrix for the training data is much larger than the test data. Would you elaborate more on this? - Anecdotage
Hm. Maybe it's a better idea to do cov = np.cov(X, rowvar=False) and zca = np.dot(X, zca_matrix). Then zca_matrix would be of shape (3072, 3072) and could be used on the test set too. pylearn2 has a working implementation of ZCA whitening, and I think they do it that way. - Vitta
Thanks, but what's the difference between the two? I mean, setting rowvar to False vs True? - Anecdotage
rowvar=True calculates the covariance between training samples; rowvar=False calculates the covariance between dimensions. - Vitta
This line looks suspicious: X = X / np.sqrt((X ** 2).sum(axis=1))[:,None] Why normalize each image column separately rather than the whole image? - Burlburlap

I had the same issue: the resulting projected values are off.

A float RGB image is supposed to have values in [0, 1.0] for each channel, so rescale it before plotting:

def toimage(data):
    min_ = np.min(data)
    max_ = np.max(data)
    return (data-min_)/(max_ - min_)

NOTICE: use this function only for visualization!

However, notice how @wildwilhelm computes the "decorrelation" or "whitening" matrix:

zca_matrix = np.dot(U, np.dot(np.diag(1.0/np.sqrt(S + epsilon)), U.T))

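This works because U holds the eigenvectors of the covariance matrix. In general, if SVD(X) = U S V^T, then U is an eigenbasis of X X^T, not of X itself (https://en.wikipedia.org/wiki/Singular-value_decomposition); since the code applies the SVD to the symmetric covariance matrix, its U contains the eigenvectors of that covariance matrix and S its eigenvalues.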
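A small numerical check of this relationship; the random matrix is an arbitrary stand-in, not CIFAR data.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((6, 4))          # hypothetical small data matrix

# SVD of X itself: the columns of U are eigenvectors of X @ X.T (not of X),
# with eigenvalues s**2.
U, s, Vt = np.linalg.svd(X, full_matrices=False)
eig_ok = np.allclose(X @ X.T @ U, U * s**2)

# SVD of a symmetric positive semi-definite covariance matrix: U holds its
# eigenvectors and S its eigenvalues, which is why the ZCA code can use
# np.linalg.svd(cov) in place of an eigendecomposition.
cov = np.cov(X, rowvar=False)             # (4, 4)
Uc, Sc, _ = np.linalg.svd(cov)
cov_ok = np.allclose(cov @ Uc, Uc * Sc)
eigvals_ok = np.allclose(np.sort(Sc), np.linalg.eigvalsh(cov))
```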

As a final note, I would rather treat only the pixels as statistical units and the RGB channels as their modalities (variables), instead of treating images as statistical units and pixels as modalities. I've tried this on the CIFAR-10 dataset and it works quite nicely.
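One possible reading of this suggestion (an interpretation, not necessarily the author's exact code) is a colour-only ZCA: every pixel is one observation of the three RGB variables, so the covariance is 3 x 3 and cheap to compute even for the whole dataset. The function name and the synthetic batch below are assumptions.

```python
import numpy as np

def channel_zca(images, eps=1e-5):
    """Whiten only the colour channels of images shaped (N, H, W, 3).

    Every pixel is treated as one observation of the 3 RGB variables, so the
    covariance is a 3 x 3 matrix instead of a 3072 x 3072 one.
    """
    pixels = images.reshape(-1, 3).astype(np.float64)    # (N*H*W, 3)
    mean = pixels.mean(axis=0)
    centred = pixels - mean
    cov = centred.T @ centred / centred.shape[0]          # (3, 3)
    U, S, _ = np.linalg.svd(cov)
    W = U @ np.diag(1.0 / np.sqrt(S + eps)) @ U.T         # 3 x 3 ZCA matrix
    return (centred @ W).reshape(images.shape), mean, W

# Hypothetical stand-in batch: 8 images of 32x32 with correlated channels.
rng = np.random.default_rng(0)
base = rng.random((8, 32, 32, 1))
images = np.concatenate([base,
                         0.8 * base + 0.2 * rng.random((8, 32, 32, 1)),
                         0.5 * base + 0.5 * rng.random((8, 32, 32, 1))], axis=-1)

white, mean, W = channel_zca(images)
```

Unlike the 3072-dimensional version, this removes only the correlation between colour channels and leaves the correlation between neighbouring pixels in place; the returned mean and W can be reused on test images.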

IMAGE EXAMPLE 1: the top image has its RGB values "whitened"; the bottom one is the original.

[image omitted]

IMAGE EXAMPLE 2: training performance and loss without the ZCA transform

[image omitted]

IMAGE EXAMPLE 3: training performance and loss with the ZCA transform

[image omitted]

Stithy answered 13/11, 2017 at 1:0. Comments (2):
Could you elaborate a bit more? What's the problem here? I'm confused! - Anecdotage
Can you share your whitening script, in case it's any different from mine (paste.ee/p/WSHuv#s=0), or the whitened CIFAR-10 dataset itself? I keep running into memory issues. - Anecdotage

If you want to linearly scale each image to have zero mean and unit variance, you can apply the same per-image standardization as TensorFlow's tf.image.per_image_standardization. According to its documentation, you need to use the following formula to normalize each image independently:

(image - image_mean) / max(image_stddev, 1.0/sqrt(image_num_elements))

Keep in mind that the mean and the standard deviation should be computed over all values in the image. This means that we don't need to specify the axis/axes along which they are computed.

One way to implement this without TensorFlow is with NumPy, as follows:

import math
import numpy as np
from PIL import Image

# open image
image = Image.open("your_image.jpg")
image = np.array(image)

# standardize image
mean = image.mean()
stddev = image.std()
adjusted_stddev = max(stddev, 1.0/math.sqrt(image.size))
standardized_image = (image - mean) / adjusted_stddev
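The same formula extends to a whole batch by computing the mean and standard deviation over each image's own values. A minimal sketch, assuming a channels-last batch of shape (N, H, W, C); the function name standardize_batch is hypothetical and is not part of TensorFlow's API.

```python
import numpy as np

def standardize_batch(images):
    """Standardize each image in a batch shaped (N, H, W, C) independently.

    Applies (x - mean) / max(std, 1/sqrt(num_elements)) per image, with the
    mean and std taken over all values of that image.
    """
    images = images.astype(np.float64)
    flat = images.reshape(images.shape[0], -1)
    mean = flat.mean(axis=1)
    std = flat.std(axis=1)
    adjusted_std = np.maximum(std, 1.0 / np.sqrt(flat.shape[1]))
    out = (flat - mean[:, None]) / adjusted_std[:, None]
    return out.reshape(images.shape)

rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)  # stand-in batch
batch[3] = 128                     # a constant image, where the std floor applies

standardized = standardize_batch(batch)
```

The floor on the standard deviation only matters for (nearly) constant images, which would otherwise cause a division by zero. Note that this is per-image standardization, not ZCA whitening: it does not decorrelate pixels.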
Zwinglian answered 2/11, 2018 at 9:42

© 2022 - 2024 — McMap. All rights reserved.