I'm having an issue plotting an RGB image using Python's Basemap module with latitude and longitude data. Now, I am able to make the plots that I want, but the problem is how slow it is, since it is able to plot single channel data much faster than the RGB data, and in general, plotting RGB images on their own is also fast. Since I have lat/lon data, that is where things get complicated. I've checked out the solution to this problem:
How to plot an irregular spaced RGB image using python and basemap?
which is how I got to where I am right now. It essentially comes down to the following issue. When using the pcolormesh
method in basemap, to plot RGB data you have to define a colorTuple parameter which will map the RGB data point by point. Since the array size is on the order of 2000x1000, this takes awhile to do. A snippet of what I'm talking about is seen below (full working code further down):
if one_channel:
m.pcolormesh(lons, lats, img[:,:,0], latlon=True)
else:
# This is the part that is slow, but I don't know how to
# accurately plot the data otherwise.
mesh_rgb = img[:, :-1, :]
colorTuple = mesh_rgb.reshape((mesh_rgb.shape[0] * mesh_rgb.shape[1]), 3)
# What you put in for the image doesn't matter because of the color mapping
m.pcolormesh(lons, lats, img[:,:,0], latlon=True,color=colorTuple)
When plotting just one channel, it can make the map in about 10 seconds or so. When plotting the RGB data, it can take 3-4 minutes. Given that there is only 3 times as much data, I feel that there must be a better way, especially since plotting RGB data can go just as fast as one channel data when you are making rectangular images.
So, my questions is: Is there any way to make this calculation faster, either with other plotting modules (Bokeh for instance) or by changing the color mapping in any way? I've tried using imshow
with carefully chosen map boundaries, but since it just stretches the image to the full extent of the map, this isn't really good enough for accurate mapping of the data.
Below is a stripped down version of my code that will work for an example with the correct modules:
from pyhdf.SD import SD,SDC
import numpy as np
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap
def get_hdf_attr(infile,dataset,attr):
f = SD(infile,SDC.READ)
data = f.select(dataset)
index = data.attr(attr).index()
attr_out = data.attr(index).get()
f.end()
return attr_out
def get_hdf_dataset(infile,dataset):
f = SD(infile,SDC.READ)
data = f.select(dataset)[:]
f.end()
return data
class make_rgb:
def __init__(self,file_name):
sds_250 = get_hdf_dataset(file_name, 'EV_250_Aggr1km_RefSB')
scales_250 = get_hdf_attr(file_name, 'EV_250_Aggr1km_RefSB', 'reflectance_scales')
offsets_250 = get_hdf_attr(file_name, 'EV_250_Aggr1km_RefSB', 'reflectance_offsets')
sds_500 = get_hdf_dataset(file_name, 'EV_500_Aggr1km_RefSB')
scales_500 = get_hdf_attr(file_name, 'EV_500_Aggr1km_RefSB', 'reflectance_scales')
offsets_500 = get_hdf_attr(file_name, 'EV_500_Aggr1km_RefSB', 'reflectance_offsets')
data_shape = sds_250.shape
along_track = data_shape[1]
cross_track = data_shape[2]
rgb = np.zeros((along_track, cross_track, 3))
rgb[:, :, 0] = (sds_250[0, :, :] - offsets_250[0]) * scales_250[0]
rgb[:, :, 1] = (sds_500[1, :, :] - offsets_500[1]) * scales_500[1]
rgb[:, :, 2] = (sds_500[0, :, :] - offsets_500[0]) * scales_500[0]
rgb[rgb > 1] = 1.0
rgb[rgb < 0] = 0.0
lin = np.array([0, 30, 60, 120, 190, 255]) / 255.0
nonlin = np.array([0, 110, 160, 210, 240, 255]) / 255.0
scale = interp1d(lin, nonlin, kind='quadratic')
self.img = scale(rgb)
def plot_image(self):
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)
ax.set_yticks([])
ax.set_xticks([])
plt.imshow(self.img, interpolation='nearest')
plt.show()
def plot_geo(self,geo_file,one_channel=False):
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)
lats = get_hdf_dataset(geo_file, 0)
lons = get_hdf_dataset(geo_file, 1)
lat_0 = np.mean(lats)
lat_range = [np.min(lats), np.max(lats)]
lon_0 = np.mean(lons)
lon_range = [np.min(lons), np.max(lons)]
map_kwargs = dict(projection='cass', resolution='l',
llcrnrlat=lat_range[0], urcrnrlat=lat_range[1],
llcrnrlon=lon_range[0], urcrnrlon=lon_range[1],
lat_0=lat_0, lon_0=lon_0)
m = Basemap(**map_kwargs)
if one_channel:
m.pcolormesh(lons, lats, self.img[:,:,0], latlon=True)
else:
# This is the part that is slow, but I don't know how to
# accurately plot the data otherwise.
mesh_rgb = self.img[:, :-1, :]
colorTuple = mesh_rgb.reshape((mesh_rgb.shape[0] * mesh_rgb.shape[1]), 3)
m.pcolormesh(lons, lats, self.img[:,:,0], latlon=True,color=colorTuple)
m.drawcoastlines()
m.drawcountries()
plt.show()
if __name__ == '__main__':
# https://ladsweb.nascom.nasa.gov/archive/allData/6/MOD021KM/2015/183/
data_file = 'MOD021KM.A2015183.1005.006.2015183195350.hdf'
# https://ladsweb.nascom.nasa.gov/archive/allData/6/MOD03/2015/183/
geo_file = 'MOD03.A2015183.1005.006.2015183192656.hdf'
# Very Fast
make_rgb(data_file).plot_image()
# Also Fast, takes about 10 seconds
make_rgb(data_file).plot_geo(geo_file,one_channel=True)
# Much slower, takes several minutes
make_rgb(data_file).plot_geo(geo_file)