matplotlib pyplot imshow tight spacing between images

I have some numpy image arrays, all of the same shape (say (64, 64, 3)). I want to plot them in a grid using pyplot.subplot(), but when I do, I get unwanted spacing between images, even when I use pyplot.subplots_adjust(hspace=0, wspace=0). Below is an example piece of code.

from matplotlib import pyplot
import numpy

def create_dummy_images():
    """
    Creates images, each of shape (64, 64, 3) and of dtype 8-bit unsigned integer.

    :return: 4 images in a list.
    """
    saturated_channel = numpy.ones((64, 64), dtype=numpy.uint8) * 255
    zero_channel = numpy.zeros((64, 64), dtype=numpy.uint8)
    red = numpy.array([saturated_channel, zero_channel, zero_channel]).transpose(1, 2, 0)
    green = numpy.array([zero_channel, saturated_channel, zero_channel]).transpose(1, 2, 0)
    blue = numpy.array([zero_channel, zero_channel, saturated_channel]).transpose(1, 2, 0)
    random = numpy.random.randint(0, 256, (64, 64, 3), dtype=numpy.uint8)
    return [red, green, blue, random]


if __name__ == "__main__":
    images = create_dummy_images()
    for i, image in enumerate(images):
        pyplot.subplot(2, 2, i + 1)
        pyplot.axis("off")
        pyplot.imshow(image)
    pyplot.subplots_adjust(hspace=0, wspace=0)
    pyplot.show()

Below is the output.

[Image: output of the code above, with unwanted gaps between the rows of images]

As you can see, there is unwanted vertical space between the images. One workaround is to hand-pick the right figure size, for example with matplotlib.rcParams['figure.figsize'] = (_, _) in Jupyter Notebook. However, the number of images I want to plot varies from one plot to the next, and hand-picking the figure size each time is very inconvenient (especially because I can't work out exactly what the size means in Matplotlib). So, is there a way for Matplotlib to work out automatically what size the figure should be, given that all my 64 x 64 images must sit flush against each other? (Or, for that matter, at a specified distance from each other?)
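A minimal sketch of one way to let the figure size follow from the images, assuming every image has the same shape and that you pick a dpi: figsize is in inches, so an image w pixels wide needs w / dpi inches at that dpi. The helper plot_flush_grid and its gap_px parameter are hypothetical names for illustration. The sketch sets zero outer margins and expresses the gap as a fraction of the axes size, because gridspec's wspace and hspace are measured relative to the average axes width and height.

```python
import math

import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; omit in Jupyter
import matplotlib.pyplot as plt
import numpy as np


def plot_flush_grid(images, ncols, dpi=100, gap_px=0):
    """Hypothetical helper: plot equally sized images in a grid, sizing the
    figure so that each axes is exactly one image in pixels, with gap_px
    pixels between neighbouring images."""
    h, w = images[0].shape[:2]          # assumes all images share this shape
    nrows = math.ceil(len(images) / ncols)

    # Total canvas size in pixels: the images plus the gaps between them.
    width_px = ncols * w + (ncols - 1) * gap_px
    height_px = nrows * h + (nrows - 1) * gap_px

    fig, axs = plt.subplots(
        nrows, ncols,
        figsize=(width_px / dpi, height_px / dpi),  # figsize is in inches
        dpi=dpi,
        squeeze=False,                              # always a 2-D array of axes
        gridspec_kw=dict(
            left=0, right=1, bottom=0, top=1,  # no outer margins
            wspace=gap_px / w,                 # gap as a fraction of axes width
            hspace=gap_px / h,                 # gap as a fraction of axes height
        ),
    )
    for ax in axs.flat:
        ax.axis("off")                         # also hides any unused cells
    for ax, image in zip(axs.flat, images):
        ax.imshow(image)                       # default aspect='equal' has nothing to shrink
    return fig, axs


# Usage: six random 64 x 64 RGB images in 3 columns, 4 pixels apart.
rng = np.random.default_rng(0)
images = [rng.integers(0, 256, (64, 64, 3), dtype=np.uint8) for _ in range(6)]
fig, axs = plot_flush_grid(images, ncols=3, gap_px=4)
```

Because each axes comes out exactly w x h pixels, the square-pixel constraint that imshow() applies by default is already satisfied, so the axes keep their grid positions and the spacing is exactly gap_px.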

Linnet asked 22/6, 2016 at 16:29

NOTE: the correct answer is in the update below the original answer.


Create your subplots first, then plot in them. I put them in a single row here for simplicity's sake:

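from matplotlib import pyplot

images = create_dummy_images()  # defined in the question
# One row of four axes with no spacing between them
fig, axs = pyplot.subplots(nrows=1, ncols=4, gridspec_kw={'wspace': 0, 'hspace': 0},
                           squeeze=True)
for i, image in enumerate(images):
    axs[i].axis("off")
    axs[i].imshow(image)
pyplot.show()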

[Image: the four images side by side in a single row]
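A sketch of the same subplots-first approach for the 2 x 2 case asked about in the comments, combined with aspect='auto' from the update. Four random uint8 images stand in for create_dummy_images(); the Agg backend line and the zero outer margins (left/right/bottom/top in gridspec_kw) are assumptions of this sketch, not part of the original answer. Note that aspect='auto' stretches each image to its cell, so the cells should have the images' proportions.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; omit in Jupyter
import matplotlib.pyplot as plt
import numpy as np

# Stand-ins for create_dummy_images() from the question: four 64 x 64 RGB images.
rng = np.random.default_rng(0)
images = [rng.integers(0, 256, (64, 64, 3), dtype=np.uint8) for _ in range(4)]

fig, axs = plt.subplots(
    nrows=2, ncols=2,
    figsize=(5, 5),                 # square figure, so each of the 2 x 2 cells is square
    gridspec_kw={"wspace": 0, "hspace": 0,
                 "left": 0, "right": 1, "bottom": 0, "top": 1},  # no outer margins
)
for ax, image in zip(axs.flat, images):  # .flat walks the 2 x 2 array row by row
    ax.axis("off")
    ax.imshow(image, aspect="auto")      # fill the axes instead of resizing them
```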

UPDATE:

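Never mind: the problem was not with your subplot definition but with imshow(), which resizes your axes after you've set them up correctly. By default imshow() uses aspect='equal', so it shrinks each axes to keep the pixels square, and the space freed up this way appears as gaps between the images.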

The solution is to pass aspect='auto' to imshow() so that the pictures fill the axes without changing them. If you want square images, you need to create a figure with the appropriate width/height ratio:

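from matplotlib import pyplot

pyplot.figure(figsize=(5, 5))  # square figure for a square 2 x 2 grid
images = create_dummy_images()  # defined in the question

for i, image in enumerate(images):
    pyplot.subplot(2, 2, i + 1)
    pyplot.axis("off")
    pyplot.imshow(image, aspect='auto')  # fill the axes instead of resizing them

pyplot.subplots_adjust(hspace=0, wspace=0)
pyplot.show()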

[Image: the 2 x 2 grid with the images flush against each other]
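The cause described in the update can be checked by comparing each axes' slot in the grid with the box that is actually drawn. In this minimal sketch, get_position(original=True) returns the position assigned by the grid, while get_position() after a draw returns the active position once the aspect has been applied. The helper cell_widths is hypothetical, and the deliberately wide 8 x 4 inch figure just makes the effect obvious.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the figure can be drawn off-screen
import matplotlib.pyplot as plt
import numpy as np

image = np.zeros((64, 64, 3), dtype=np.uint8)  # any square image shows the effect


def cell_widths(aspect):
    """Hypothetical helper: return (grid slot width, drawn axes width), in
    figure-fraction units, for the top-left axes of a 2 x 2 grid."""
    fig, axs = plt.subplots(2, 2, figsize=(8, 4),
                            gridspec_kw={"wspace": 0, "hspace": 0})
    for ax in axs.flat:
        ax.imshow(image, aspect=aspect)
    fig.canvas.draw()                                  # aspect is applied at draw time
    ax = axs[0, 0]
    widths = (ax.get_position(original=True).width,   # slot given by the grid
              ax.get_position().width)                # box actually drawn
    plt.close(fig)
    return widths


print(cell_widths("equal"))  # drawn width < slot width: the leftover space becomes a gap
print(cell_widths("auto"))   # drawn width == slot width: the images stay flush
```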

Jasik answered 23/6, 2016 at 20:45
Comments:
Could you demonstrate that this works when the 4 images are in a 2 x 2 grid, as opposed to in a single row? My method also works in a single row, but not in a 2 x 2 grid. - Linnet
@Linnet I've amended my answer above. - Jasik
In my case, using plt.subplot(2, 2, i + 1) with plt.figure(figsize=(20, 20)) and plt.subplots_adjust(hspace=0.05, wspace=0.05) is sufficient, and works best if you are using axes. - Jaco
