How do I turn a Pytorch Dataloader into a numpy array to display image data with matplotlib?

I am new to Pytorch. I have been trying to learn how to view my input images before I begin training on my CNN. I am having a very hard time changing the images into a form that can be used with matplotlib.

So far I have tried this:

from multiprocessing import freeze_support

import torch
from torch import nn
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets
from torchvision.transforms import transforms
from torch.optim import Adam

import matplotlib.pyplot as plt
import numpy as np
import PIL

num_classes = 5
batch_size = 100
num_of_workers = 5

DATA_PATH_TRAIN = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\train'
DATA_PATH_TEST = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\test'

trans = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
    ])

train_dataset = datasets.ImageFolder(root=DATA_PATH_TRAIN, transform=trans)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers)

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    print(npimg)
    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))

def main():
    # get some random training images
    dataiter = iter(train_loader)
    images, labels = next(dataiter)

    # show images
    imshow(images)
    # print labels
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

if __name__ == "__main__":
    main()

However, this throws an error:

  [[0.27058825 0.18431371 0.31764707 ... 0.18823528 0.3882353
    0.27450982]
   [0.23137254 0.11372548 0.24313724 ... 0.16862744 0.14117646
    0.40784314]
   [0.25490198 0.19607842 0.30588236 ... 0.27450982 0.25882354
    0.34509805]
   ...
   [0.2784314  0.21960783 0.2352941  ... 0.5803922  0.46666667
    0.25882354]
   [0.26666668 0.16862744 0.23137254 ... 0.2901961  0.29803923
    0.2509804 ]
   [0.30980393 0.39607844 0.28627452 ... 0.1490196  0.10588235
    0.19607842]]

  [[0.2352941  0.06274509 0.15686274 ... 0.09411764 0.3019608
    0.19215685]
   [0.22745097 0.07843137 0.12549019 ... 0.07843137 0.10588235
    0.3019608 ]
   [0.20392156 0.13333333 0.1607843  ... 0.16862744 0.2117647
    0.22745097]
   ...
   [0.18039215 0.16862744 0.1490196  ... 0.45882353 0.36078432
    0.16470587]
   [0.1607843  0.10588235 0.14117646 ... 0.2117647  0.18039215
    0.10980392]
   [0.18039215 0.3019608  0.2117647  ... 0.11372548 0.06274509
    0.04705882]]]


 ...


 [[[0.8980392  0.8784314  0.8509804  ... 0.627451   0.627451
    0.627451  ]
   [0.8509804  0.8235294  0.7921569  ... 0.54901963 0.5568628
    0.56078434]
   [0.7921569  0.7529412  0.7176471  ... 0.47058824 0.48235294
    0.49411765]
   ...
   [0.3764706  0.38431373 0.3764706  ... 0.4509804  0.43137255
    0.39607844]
   [0.38431373 0.39607844 0.3882353  ... 0.4509804  0.43137255
    0.39607844]
   [0.3882353  0.4        0.39607844 ... 0.44313726 0.42352942
    0.39215687]]

  [[0.9254902  0.90588236 0.88235295 ... 0.60784316 0.6
    0.5921569 ]
   [0.88235295 0.85490197 0.8235294  ... 0.5411765  0.5372549
    0.53333336]
   [0.8235294  0.7882353  0.75686276 ... 0.47058824 0.47058824
    0.47058824]
   ...
   [0.50980395 0.5176471  0.5137255  ... 0.58431375 0.5647059
    0.53333336]
   [0.5137255  0.53333336 0.5254902  ... 0.58431375 0.5686275
    0.53333336]
   [0.5176471  0.53333336 0.5294118  ... 0.5764706  0.56078434
    0.5294118 ]]

  [[0.95686275 0.9372549  0.90588236 ... 0.18823528 0.19999999
    0.20784312]
   [0.9098039  0.8784314  0.8352941  ... 0.1607843  0.17254901
    0.18039215]
   [0.84313726 0.7921569  0.7490196  ... 0.1372549  0.14509803
    0.15294117]
   ...
   [0.03921568 0.05490196 0.05098039 ... 0.11764705 0.09411764
    0.02745098]
   [0.04705882 0.07843137 0.06666666 ... 0.12156862 0.10196078
    0.03529412]
   [0.05098039 0.0745098  0.07843137 ... 0.12549019 0.10196078
    0.04705882]]]


 [[[0.30588236 0.28627452 0.24313724 ... 0.2901961  0.26666668
    0.21568626]
   [0.8156863  0.6666667  0.5921569  ... 0.18039215 0.23921567
    0.21568626]
   [0.9019608  0.83137256 0.85490197 ... 0.21960783 0.36862746
    0.23921567]
   ...
   [0.7058824  0.83137256 0.85490197 ... 0.2627451  0.24313724
    0.20784312]
   [0.7137255  0.84313726 0.84705883 ... 0.26666668 0.29803923
    0.21568626]
   [0.7254902  0.8235294  0.8392157  ... 0.2509804  0.27058825
    0.2352941 ]]

  [[0.24705881 0.22745097 0.19215685 ... 0.2784314  0.25490198
    0.19607842]
   [0.59607846 0.37254903 0.29803923 ... 0.16470587 0.22745097
    0.20392156]
   [0.5921569  0.4509804  0.49803922 ... 0.20784312 0.3764706
    0.2352941 ]
   ...
   [0.42352942 0.4627451  0.42352942 ... 0.23921567 0.23137254
    0.19999999]
   [0.45882353 0.5176471  0.35686275 ... 0.23921567 0.26666668
    0.19607842]
   [0.41568628 0.44313726 0.34901962 ... 0.21960783 0.23921567
    0.21568626]]

  [[0.23137254 0.20784312 0.1490196  ... 0.30588236 0.28627452
    0.19607842]
   [0.61960787 0.3764706  0.26666668 ... 0.16470587 0.24313724
    0.21568626]
   [0.57254905 0.43137255 0.48235294 ... 0.2235294  0.40392157
    0.25882354]
   ...
   [0.4        0.42352942 0.37254903 ... 0.25490198 0.24705881
    0.21568626]
   [0.43137255 0.4509804  0.29411766 ... 0.25882354 0.28235295
    0.20392156]
   [0.38431373 0.3529412  0.25490198 ... 0.2352941  0.25490198
    0.23137254]]]


 [[[0.06274509 0.09019607 0.11372548 ... 0.5803922  0.5176471
    0.59607846]
   [0.09411764 0.14509803 0.1372549  ... 0.5294118  0.49803922
    0.5058824 ]
   [0.04705882 0.09411764 0.10196078 ... 0.45882353 0.42352942
    0.38431373]
   ...
   [0.15294117 0.12941176 0.1607843  ... 0.85882354 0.8509804
    0.80784315]
   [0.14509803 0.10588235 0.1607843  ... 0.8666667  0.85882354
    0.8       ]
   [0.1490196  0.10588235 0.16470587 ... 0.827451   0.8156863
    0.7921569 ]]

  [[0.06666666 0.12156862 0.17647058 ... 0.59607846 0.5529412
    0.6039216 ]
   [0.07058823 0.10588235 0.11764705 ... 0.56078434 0.5254902
    0.5372549 ]
   [0.03921568 0.0745098  0.09803921 ... 0.48235294 0.4392157
    0.4117647 ]
   ...
   [0.2117647  0.14509803 0.2784314  ... 0.43137255 0.3529412
    0.34117648]
   [0.2235294  0.11372548 0.2509804  ... 0.4509804  0.39607844
    0.2509804 ]
   [0.25490198 0.12156862 0.24705881 ... 0.38039216 0.36078432
    0.3254902 ]]

  [[0.05490196 0.09803921 0.12549019 ... 0.46666667 0.38039216
    0.45490196]
   [0.06274509 0.09803921 0.10196078 ... 0.44705883 0.41568628
    0.3882353 ]
   [0.03921568 0.06666666 0.0862745  ... 0.3764706  0.33333334
    0.28235295]
   ...
   [0.12156862 0.14509803 0.16862744 ... 0.15686274 0.0745098
    0.09411764]
   [0.10588235 0.11372548 0.16862744 ... 0.25882354 0.18431371
    0.05490196]
   [0.12156862 0.11372548 0.17254901 ... 0.2352941  0.17254901
    0.14117646]]]]
Traceback (most recent call last):
  File "image_loader.py", line 51, in <module>
    main()
  File "image_loader.py", line 46, in main
    imshow(images)
  File "image_loader.py", line 38, in imshow
    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))
  File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 598, in transpose
    return _wrapfunc(a, 'transpose', axes)
  File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 51, in _wrapfunc
    return getattr(obj, method)(*args, **kwds)
ValueError: repeated axis in transpose

I tried to print out the arrays to get the dimensions but I do not know what to make of this. It is very confusing.

Here is my direct question: How do I view the input images before training using the tensors in my DataLoader object?

Gasolier answered 8/8, 2018 at 22:35 Comment(0)

First of all, a DataLoader yields 4-dimensional tensors of shape [batch, channel, height, width]. Matplotlib and other image processing libraries usually require [height, width, channel]. You are right to use transpose, but the axes argument is wrong: every axis has to appear exactly once, which is why (1, 2, 0, 1) raises "repeated axis in transpose".
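To make the layout difference concrete, here is a minimal NumPy-only sketch. The random array merely stands in for a batch from the loader (the 4 x 3 x 32 x 32 shape is an assumption for illustration), and the try/except shows why an axes tuple such as (1, 2, 0, 1) is rejected.

```python
import numpy as np

# Stand-in for a batch from the DataLoader: [batch, channel, height, width].
# The sizes are illustrative assumptions, not taken from the question's data.
batch = np.random.rand(4, 3, 32, 32).astype(np.float32)

# A valid axes tuple names every axis of the array exactly once.
# (1, 2, 0, 1) names axis 1 twice and never names axis 3, so NumPy rejects it.
try:
    np.transpose(batch, (1, 2, 0, 1))
    transpose_error = None
except ValueError as err:
    transpose_error = str(err)

# One image is [channel, height, width]; matplotlib wants [height, width, channel].
single_hwc = np.transpose(batch[0], (1, 2, 0))

# The whole batch can be reordered in one call as well.
batch_nhwc = np.transpose(batch, (0, 2, 3, 1))

print(transpose_error)
print(single_hwc.shape, batch_nhwc.shape)
```

Note that plt.imshow still takes a single image, so the per-image form is the one to use for display.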

Your images tensor holds a whole batch, so first you need to pick one image (or write a for loop to save all of them). That is simply images[i]; I typically use i=0.

Then your transpose should convert the resulting [channel, height, width] tensor into a [height, width, channel] one. To do this, use np.transpose(image.numpy(), (1, 2, 0)), very much like your attempt.

Putting them together, you should have

plt.imshow(np.transpose(images[0].numpy(), (1, 2, 0)))

Depending on the use case, you may also need to call .detach() (to detach the tensor from the computational graph) and .cpu() (to move the data from the GPU to the CPU), which gives

plt.imshow(np.transpose(images[0].cpu().detach().numpy(), (1, 2, 0)))
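The steps above can be wrapped in a small helper. This is only a sketch: to_displayable is a hypothetical name, the random tensor stands in for a batch from train_loader, and the default mean and std of 0.5 assume the Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform from the question.

```python
import numpy as np
import torch

def to_displayable(img, mean=0.5, std=0.5):
    """Convert one [channel, height, width] tensor to a [height, width, channel] array."""
    img = img.detach().cpu()      # harmless if the tensor is already detached and on the CPU
    img = img * std + mean        # undo Normalize; same as img / 2 + 0.5 for mean = std = 0.5
    img = img.clamp(0.0, 1.0)     # keep float values in the [0, 1] range imshow expects
    return np.transpose(img.numpy(), (1, 2, 0))

# Stand-in for `images, labels = next(iter(train_loader))`; values lie in [-1, 1],
# like the output of the question's Normalize transform.
images = torch.rand(4, 3, 32, 32) * 2 - 1

first = to_displayable(images[0])
print(first.shape)  # (32, 32, 3)
```

Passing first to plt.imshow and then calling plt.show() should display the image after the transforms have been applied.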
Exudate answered 9/8, 2018 at 2:36 Comment(7)
Wow, thank you so much for this. You really made my day, you know! I can now see the images after the transforms have been applied. If you do not mind me asking, how would I go about seeing the predictions that the CNN made on each image? Right now I see accuracy and epochs. - Gasolier
@Gasolier Glad it helped. What are your outputs? Numbers? Images? - Exudate
So when I checked my outputs, it gives a tensor object with numbers from 0-4. I have 5 classes of flowers, so I am guessing 0 = 1, 1 = 2 and so on. I am trying to train the model and then give it a fresh input to see if it is actually working. - Gasolier
@Gasolier If you want to test a single image and view its result, then just print it out, like print(result). If there are a lot of images, you can calculate loss/accuracy. - Exudate
I should have clarified more. Right now my output is like this: Epoch {epoch} Train acc {accuracy} Loss {loss} Test acc {test acc}. I am looking to feed it a new image after training and see if it can correctly classify it. - Gasolier
@Gasolier Put some of your "new images" in another folder, create a dataset and dataloader for it, and run the network again as in training/testing. You can print the output afterwards. It should be very much like your training/testing process. If you encounter any problem implementing that, I think you can open another question. - Exudate
Thank you very much. You have been very helpful! - Gasolier
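For the follow-up in the comments, a rough sketch of turning the 0-4 outputs into class names might look like this. Everything here is a placeholder: the model is a dummy linear layer rather than the asker's CNN, the class names are invented, and the random tensor stands in for a batch from a second DataLoader. With ImageFolder, the real names are available as train_dataset.classes.

```python
import torch
from torch import nn

# Hypothetical class names. With ImageFolder the real list is `train_dataset.classes`,
# where index 0 is the first name, index 1 the second, and so on.
class_names = ["daisy", "dandelion", "rose", "sunflower", "tulip"]

# Placeholder network so the sketch runs; substitute the trained CNN here.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, len(class_names)))
model.eval()  # put layers such as dropout and batch norm into inference mode

# Stand-in for a batch of new images from a second DataLoader.
new_images = torch.rand(4, 3, 32, 32)

with torch.no_grad():                    # gradients are not needed for inference
    logits = model(new_images)           # shape [batch, num_classes]
    predicted = logits.argmax(dim=1)     # index of the highest score per image

predicted_names = [class_names[i] for i in predicted.tolist()]
print(predicted_names)
```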

This did the trick for me when I faced the same problem. A PyTorch dataset behaves much like a regular list as far as NumPy is concerned, hence this works.

train_np = np.array(train_loader.dataset)
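
If a single np.array call does not give a regular numeric array for a dataset of (image, label) pairs, one alternative is to collect the images and labels separately. The sketch below assumes every image has the same shape, and uses a TensorDataset of random data as a stand-in for the ImageFolder dataset.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset of (image, label) pairs; shapes and sizes are illustrative.
dataset = TensorDataset(torch.rand(10, 3, 32, 32), torch.randint(0, 5, (10,)))
loader = DataLoader(dataset, batch_size=4)

# Index the underlying dataset directly, so no batching or shuffling is involved.
pairs = [loader.dataset[i] for i in range(len(loader.dataset))]

# Stack images and labels separately so each array has a regular shape.
images_np = np.stack([img.numpy() for img, _ in pairs])
labels_np = np.array([int(label) for _, label in pairs])

print(images_np.shape, labels_np.shape)
```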
Goidelic answered 24/5, 2021 at 9:13 Comment(0)
