Why does using X[0] in MNIST classifier code give me an error?

I am learning classification with the MNIST dataset and hit an error I cannot figure out. I have searched extensively without finding a fix. Here is the code:

>>> from sklearn.datasets import fetch_openml
>>> mnist = fetch_openml('mnist_784', version=1)
>>> mnist.keys()

output: dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])

>>> X, y = mnist["data"], mnist["target"]
>>> X.shape

output:(70000, 784)

>>> y.shape

output:(70000,)

>>> X[0]

output:KeyError                                  Traceback (most recent call last)
c:\users\khush\appdata\local\programs\python\python39\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
   2897             try:
-> 2898                 return self._engine.get_loc(casted_key)
   2899             except KeyError as err:

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 0

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
<ipython-input-10-19c40ecbd036> in <module>
----> 1 X[0]

c:\users\khush\appdata\local\programs\python\python39\lib\site-packages\pandas\core\frame.py in __getitem__(self, key)
   2904             if self.columns.nlevels > 1:
   2905                 return self._getitem_multilevel(key)
-> 2906             indexer = self.columns.get_loc(key)
   2907             if is_integer(indexer):
   2908                 indexer = [indexer]

c:\users\khush\appdata\local\programs\python\python39\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
   2898                 return self._engine.get_loc(casted_key)
   2899             except KeyError as err:
-> 2900                 raise KeyError(key) from err
   2901 
   2902         if tolerance is not None:

KeyError: 0
Enos answered 30/12, 2020 at 11:19 Comment(1)
A KeyError usually comes from a dictionary-style lookup, so X seems to behave like a dictionary here. Try printing X to see what it contains. - Stcyr
C
14

The API of fetch_openml changed between versions. In earlier versions it returned a numpy.ndarray. Since scikit-learn 0.24.0 (December 2020), the as_frame argument of fetch_openml defaults to 'auto' (previously False), which gives you a pandas.DataFrame for the MNIST data. On a DataFrame, X[0] looks up a column labelled 0 rather than the first row, hence the KeyError. You can force the data to be read as a numpy.ndarray by setting as_frame=False. See the fetch_openml reference.
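To see the difference without downloading MNIST, here is a minimal sketch on synthetic data. The three-row array and the "pixel1" ... "pixel784" column names are assumptions made for illustration (they mimic the shape of the real dataset); only the indexing behaviour is the point.

```python
import numpy as np
import pandas as pd

# Stand-in for MNIST: 3 synthetic "images" of 784 pixels each (not real data).
pixels = np.arange(3 * 784).reshape(3, 784)

# Assumed column labels, chosen to resemble string-named pixel columns.
columns = [f"pixel{i}" for i in range(1, 785)]
X_frame = pd.DataFrame(pixels, columns=columns)
X_array = X_frame.to_numpy()

# On an ndarray, [0] selects the first row.
first_row = X_array[0]

# On a DataFrame, [0] looks for a COLUMN labelled 0, which does not exist.
try:
    X_frame[0]
    lookup_failed = False
except KeyError:
    lookup_failed = True

# Positional row access on a DataFrame goes through .iloc.
first_row_from_frame = X_frame.iloc[0].to_numpy()

print(lookup_failed, first_row.shape)
```

So with a DataFrame, either convert it first or index rows with .iloc.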

Canonicals answered 30/12, 2020 at 11:59 Comment(1)
Yes, I am using an older numpy version because the newer version gives me a runtime error. By the way, thank you so much. - Enos
K
12

I was also facing the same problem.

  • scikit-learn: 0.24.0
  • matplotlib: 3.3.3
  • Python: 3.9.1

I used the code below to resolve the issue.

import matplotlib as mpl
import matplotlib.pyplot as plt


# instead of some_digit = X[0]
some_digit = X.to_numpy()[0]
some_digit_image = some_digit.reshape(28,28)

plt.imshow(some_digit_image,cmap="binary")
plt.axis("off")
plt.show()
Kuhns answered 17/1, 2021 at 6:7 Comment(2)
Great. Are you using the book "Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow"? I am using the same one. - Enos
Yes, I have just started with this book. Also, could you accept the answer if it resolves the issue? - Kuhns
C
10

You don't need to downgrade your scikit-learn library if you use the code below:

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
mnist.keys()
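If the same code has to run whether X arrives as a DataFrame or as an ndarray, one option is a small wrapper. The function name first_sample below is hypothetical (it is not part of scikit-learn or pandas); this is only a sketch of the idea.

```python
import numpy as np
import pandas as pd


def first_sample(X, index=0):
    """Return row `index` as a 1-D NumPy array (hypothetical helper)."""
    if isinstance(X, pd.DataFrame):
        # DataFrame: select the row by position, then convert it.
        return X.iloc[index].to_numpy()
    # ndarray (or anything array-like): plain positional indexing.
    return np.asarray(X)[index]


# Synthetic check with a tiny 2 x 4 table instead of MNIST.
data = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
row_from_array = first_sample(data)
row_from_frame = first_sample(pd.DataFrame(data, columns=list("abcd")))
```

Both calls return the same first row, so the rest of the code does not depend on the as_frame setting.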
Cherimoya answered 3/2, 2021 at 20:59 Comment(0)
M
1

The dataset is loaded as a DataFrame. To access the images, you have two options:

Convert the DataFrame to an array

import matplotlib.pyplot as plt

# Convert the DataFrame into an array and take the first row
some_digit = X.to_numpy()[0]

# Reshape it to (28, 28). Note: 28 x 28 = 784; if the row does not hold
# exactly 784 values, the reshape fails and the image cannot be shown
some_digit_image = some_digit.reshape(28,28)

plt.imshow(some_digit_image,cmap="binary")
plt.axis("off")
plt.show()

Transform the row

# Transform the row of your choosing into an array
some_digit = X.iloc[0,:].values

# Reshape it to (28, 28). Note: 28 x 28 = 784; if the row does not hold
# exactly 784 values, the reshape fails and the image cannot be shown
some_digit_image = some_digit.reshape(28,28)

plt.imshow(some_digit_image,cmap="binary")
plt.axis("off")
plt.show()
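The reshape to (28, 28) works because each row holds 28 x 28 = 784 pixel values. A short sketch on placeholder arrays (zeros, not real digits) shows what happens when the sizes do and do not match:

```python
import numpy as np

# Placeholder for one flattened MNIST row: 784 values.
some_digit = np.zeros(784)
image = some_digit.reshape(28, 28)  # fine, since 28 * 28 == 784

# A row with the wrong number of values cannot be reshaped.
try:
    np.zeros(783).reshape(28, 28)
    mismatch_error = None
except ValueError as err:
    mismatch_error = err

print(image.shape)
print(mismatch_error)
```

NumPy raises a ValueError in the second case, which is a quick way to spot a row that was sliced incorrectly.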
Moller answered 31/8, 2022 at 16:46 Comment(0)
C
1

You can also pass parser='auto' as an extra argument when loading the MNIST dataset. Note that as_frame=False is still what makes the data come back as an array.

I load it this way:

mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
Crimp answered 12/5, 2023 at 20:11 Comment(0)
