How to speed up a Python Basemap choropleth animation

Taking ideas from various sources and combining them with my own, I set out to create an animated map that shades countries based on some value in my data.

The basic process is this:

  1. Run DB query to get dataset, keyed by country and time
  2. Use pandas to do some data manipulation (sums, avgs, etc)
  3. Initialize a Basemap object, then load an external shapefile
  4. Using the animation library, color the countries, one frame for each distinct "time" in the dataset.
  5. Save as gif or mp4 or whatever

This works just fine. The problem is that it is extremely slow. I have potentially over 100k time intervals (over several metrics) I want to animate, and I'm getting an average time of 15s to generate each frame, and it gets worse the more frames there are. At this rate, it will potentially take weeks of maxing out the cpu and memory on my computer to generate a single animation.

I know that matplotlib isn't known for being very fast (examples: 1 and 2), but I have read stories of people generating animations at 5+ fps, and I wonder what I'm doing wrong.

Some optimizations that I have done:

  1. Only recolor the countries in the animate function. This takes on average ~3s per frame, so while it could be improved, it's not what takes the most time.
  2. I use the blit option.
  3. I tried using a smaller plot size and a less detailed basemap, but the gains were marginal.

Perhaps a less detailed shapefile would speed up the coloring of the shapes, but as I said before, that's only a 3s per frame improvement.

Here is the code (minus a few identifying details):

import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import time

from math import pi
from sqlalchemy import create_engine
from mpl_toolkits.basemap import Basemap
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection
from geonamescache import GeonamesCache
from datetime import datetime


def get_dataset(avg_interval, startTime, endTime):
    ### SQL query
    # Returns a dataframe with fields [country, unixtime, metric1, metric2, metric3, metric4, metric5]
    # I use unixtime so I can group by any arbitrary interval to get sums and avgs of the metrics (hence the param avg_interval)
    return df

# Initialize plot figure
fig=plt.figure(figsize=(11, 6))
# facecolor replaces the axisbg keyword, which newer matplotlib releases no longer accept
ax = fig.add_subplot(111, facecolor='w', frame_on=False)

# Initialize map with Robinson projection
m = Basemap(projection='robin', lon_0=0, resolution='c')
# Load and read shapefile
shapefile = 'countries/ne_10m_admin_0_countries'
m.readshapefile(shapefile, 'units', color='#dddddd', linewidth=0.005)

# Get valid country code list
gc = GeonamesCache()
iso2_codes = list(gc.get_dataset_by_key(gc.get_countries(), 'fips').keys())

# Get dataset and remove invalid countries
# This one will get daily aggregates for the first week of the year
df = get_dataset(60*60*24, '2016-01-01', '2016-01-08')
df.set_index(["country"], inplace=True)
# Keep only rows whose country code is valid (.ix is gone from newer pandas releases)
df = df[df.index.isin(iso2_codes)].dropna()

num_colors = 20

# Get list of distinct times to iterate over in the animation
period = df["unixtime"].sort_values(ascending=True).unique()

# Assign bins to each value in the df
values = df["metric1"]
cm = plt.get_cmap('afmhot_r')
scheme= cm(1.*np.arange(num_colors)/num_colors)
bins = np.linspace(values.min(), values.max(), num_colors)
df["bin"] = np.digitize(values, bins) - 1

# Initialize animation return object
x,y = m([],[])
point = m.plot(x, y,)[0]

# Pre-zip country details and shape objects
# (a list, so it can be iterated again on every frame under Python 3 as well)
zipped = list(zip(m.units_info, m.units))
tbegin = time.time()

# Animate! This is the part that takes a long time. Most of the time taken seems to happen between frames...
def animate(i):
    # Clear the axis object so it doesn't draw over the old one
    ax.clear()
    # Dynamic title
    fig.suptitle('Num: {}'.format(datetime.utcfromtimestamp(int(i)).strftime('%Y-%m-%d %H:%M:%S')), fontsize=30, y=.95)
    tstart = time.time()

    # Get current frame dataset
    frame = df[df["unixtime"]==i]

    # Loop through every country
    for info, shape in zipped:
        iso2 = info['ISO_A2']
        if iso2 not in frame.index:
            # Gray if not in dataset
            color = '#dddddd'
        else:
            # Colored if in dataset
            color = scheme[int(frame.loc[iso2, "bin"])]

        # Get shape info for country, then color on the ax subplot
        patches = [Polygon(np.array(shape), True)]
        pc = PatchCollection(patches)
        pc.set_facecolor(color)
        ax.add_collection(pc)
    tend = time.time()
    #print "{}%: {} of {} took {}s".format(str(ind/tot*100), str(ind), str(tot), str(tend-tstart))
    print("{}: {}s".format(datetime.utcfromtimestamp(int(i)).strftime('%Y-%m-%d %H:%M:%S'), str(tend-tstart)))
    return None

# Initialize animation object
output = animation.FuncAnimation(fig, animate, period, interval=150, repeat=False, blit=False)
filestring = time.strftime("%Y%m%d%H%M%S")
# Save animation object as mp4
#output.save(filestring + '.mp4', fps=1, codec='ffmpeg', extra_args=['-vcodec', 'libx264'])
# Save animation object as gif
output.save(filestring + '.gif', writer='imagemagick')
tfinish = time.time()

print("Total time: {}s".format(str(tfinish-tbegin)))
print("{}s per frame".format(str((tfinish-tbegin)/len(df["unixtime"].unique()))))
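
The binning step in the code above can be checked in isolation. The sketch below uses made-up metric values (an assumption, not data from the question) to show where `np.digitize` places each one. With `num_colors` edges, only the maximum value reaches the last color index; the second variant, using `num_colors + 1` edges plus `np.clip`, is one possible way to get equal-width bins if that was the intent.

```python
import numpy as np

num_colors = 20
# Hypothetical metric values, chosen only to show where each one lands.
values = np.array([0.0, 2.5, 5.0, 9.9, 10.0])

# Same scheme as the question: num_colors edges, then digitize and subtract 1.
bins = np.linspace(values.min(), values.max(), num_colors)
idx = np.digitize(values, bins) - 1
print(idx)  # only the maximum value reaches index num_colors - 1

# Alternative (an assumption about intent): num_colors equal-width bins.
edges = np.linspace(values.min(), values.max(), num_colors + 1)
idx_equal = np.clip(np.digitize(values, edges) - 1, 0, num_colors - 1)
print(idx_equal)
```

Both variants produce indices in the range 0 to `num_colors - 1`, so either can index the `scheme` array safely.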

P.S. I know the code is sloppy and could use some cleanup. I'm open to any suggestions, especially if that cleanup would improve performance!

Edit 1: Here is an example of the output

2016-01-01 00:00:00: 3.87843298912s
2016-01-01 00:00:00: 4.08691620827s
2016-01-02 00:00:00: 3.40868711472s
2016-01-03 00:00:00: 4.21187019348s
Total time: 29.0233821869s
9.67446072896s per frame

The first few lines show the date being processed and the runtime of each frame. I have no clue why the first one is repeated. The final line is the total runtime of the program divided by the number of frames. Note that the average time is 2-3x the individual times, which makes me think that something happening "between" the frames is eating up a lot of time.
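
On the repeated first line: the matplotlib documentation states that when `FuncAnimation` is given no `init_func`, the result of drawing the first item in the frames sequence is used for the initial draw. That would account for one extra call with the first timestamp. The sketch below counts calls with and without an `init_func`; `count_calls` is a hypothetical helper, and it assumes Pillow is installed so that `PillowWriter` can write a GIF without ImageMagick.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import matplotlib.animation as animation


def count_calls(use_init_func):
    """Save a 3-frame animation and return the frame values animate() received."""
    fig, ax = plt.subplots()
    seen = []

    def animate(i):
        seen.append(i)  # record every call, including any initial draw

    def init():
        pass  # nothing to set up; it only replaces the implicit first-frame draw

    kwargs = {"init_func": init} if use_init_func else {}
    anim = animation.FuncAnimation(fig, animate, frames=[0, 1, 2],
                                   repeat=False, blit=False, **kwargs)
    with tempfile.TemporaryDirectory() as tmp:
        anim.save(os.path.join(tmp, "frames.gif"),
                  writer=animation.PillowWriter(fps=1))
    plt.close(fig)
    return seen


calls_without_init = count_calls(False)
calls_with_init = count_calls(True)
print(len(calls_without_init), len(calls_with_init))
```

If the counts differ by one, passing a cheap `init_func` avoids paying for the first frame twice.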

Edit 2: I ran some performance tests and determined that the time to generate each additional frame is greater than the last, growing in proportion to the number of frames, which indicates a quadratic-time process (or would it be exponential?). Either way, I'm very confused why this wouldn't be linear. If the dataset is already generated, and the maps take a constant time to regenerate, what variable is causing each extra frame to take longer than the previous?
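
A worked count makes the growth rate concrete. Assuming every `animate` call adds n new patches and the axes are never cleared, frame k has to draw k·n patches, so N frames draw n·N(N+1)/2 in total: quadratic, not exponential. The function below is a hypothetical model of that bookkeeping, not a measurement.

```python
def patches_drawn(n_frames, patches_per_frame, clear_each_frame):
    """Total number of patches rendered over n_frames frames."""
    total = 0
    on_axes = 0
    for _ in range(n_frames):
        if clear_each_frame:
            on_axes = 0               # ax.clear() drops the old patches
        on_axes += patches_per_frame  # this frame's patches are added
        total += on_axes              # everything on the axes is drawn
    return total


print(patches_drawn(10, 100, clear_each_frame=False))  # 5500 = 100 * 10 * 11 / 2
print(patches_drawn(10, 100, clear_each_frame=True))   # 1000 = 100 * 10
```

Under this model, clearing the axes (or reusing the same artists) brings the total back to linear in the number of frames.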

Edit 3: I just made the realization that I have no idea how the animation function works. The (x,y) and point variables were taken from an example that was just plotting moving dots, so it makes sense in that context. A map... not so much. I tried returning something map related from the animate function and got better performance. Returning the ax object (return ax,) makes the procedure run in linear time... but doesn't write anything to the gif. Anybody have any idea what I need to return from the animate function to make this work?
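
On the return value: with `blit=True`, `FuncAnimation` expects `animate` (and `init_func`, if given) to return an iterable of the artists that changed; with `blit=False` the return value is not used. A minimal sketch of that contract follows, using a single stand-in polygon rather than the real map. The shape and the colors are assumptions for illustration only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.patches import Polygon

fig, ax = plt.subplots()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

# One stand-in shape, created once and reused for every frame.
square = Polygon([(0.2, 0.2), (0.8, 0.2), (0.8, 0.8), (0.2, 0.8)], closed=True)
ax.add_patch(square)

# Hypothetical per-frame colors.
frame_colors = ["#dddddd", "#ffcc00", "#cc3300"]


def init():
    square.set_facecolor(frame_colors[0])
    return (square,)


def animate(i):
    square.set_facecolor(frame_colors[i])
    return (square,)  # the artists that changed, as blit=True expects


anim = animation.FuncAnimation(fig, animate, frames=len(frame_colors),
                               init_func=init, blit=True, repeat=False)
```

Returning the artist itself, rather than the axes, is what tells the blitting machinery which parts of the figure to redraw.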

Edit 4: Clearing the axis every frame lets the frames generate at a constant rate! Now I just have to work on general optimizations. I'll start with ImportanceOfBeingErnest's suggestion first. The previous edits are obsolete now.

Toweling answered 11/1, 2017 at 0:27 Comment(7)
codereview.stackexchange.com perhaps? - Prothrombin
Possibly. Though I wasn't sure whether my performance issues are due to coding inefficiencies, or some feature/trick that I am missing in matplotlib. So I put it here to start. - Toweling
It seems to me that in each timestep you add the complete PatchCollection to the axes. Is this the most time-consuming step? I am not sure if I understood the problem correctly, but it seems that the shapes stay the same over the whole animation (countries do not change their borders, right?). So wouldn't it be better to add the PatchCollection once and only change its colors within the animation? - Dionnadionne
Good idea! I'll see if I can make that work. - Toweling
I've tried to explain my problem better in the edit, where I show sample output and times. To summarize, each animate function call takes ~4s, while the total runtime gets worse in a nonlinear fashion, as if there is something else that is being done between animate() calls. - Toweling
Ignoring the first datapoint, the graph you show looks rather linear to me, as would be expected, because you constantly add the same PatchCollection, thereby linearly increasing the number of points that need to be drawn to the canvas for each time step. Concerning animate's return value, this should be None for blit=False and a list of artists to redraw if blitting is used. Did you consider my suggestion from above? Also, would you be able to turn your code into a runnable example (maybe there are some shapefiles available online?) so one could actually test it? - Dionnadionne
Hmm, I suppose that it is linear from that point of view. I figured that each frame would be a separate entity and have a constant number of points to draw. Not so if patches are being redrawn each frame. That means frame 1 would have n points, frame 2 2n, frame 3 3n, and so on... I did consider your pointer above, but have yet to implement a runnable example of it. I'll work on integrating with a public dataset then update the post. Thanks again! - Toweling
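
The suggestion in the comments, adding the PatchCollection once and only changing its colors, might look roughly like the sketch below. It is not the asker's code: the three squares stand in for the shapefile polygons, and `shape_codes`, `frames` and `missing` are hypothetical names and data introduced for illustration.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection


def square(x0):
    """Hypothetical stand-in for one shapefile polygon."""
    return np.array([(x0, 0), (x0 + 1, 0), (x0 + 1, 1), (x0, 1)])


shapes = [square(0), square(1), square(2)]
shape_codes = ["AA", "BB", "CC"]  # hypothetical country code for each shape

fig, ax = plt.subplots()
ax.set_xlim(0, 3)
ax.set_ylim(0, 1)

# Build the patches and the collection once, outside the animation loop.
patches = [Polygon(s, closed=True) for s in shapes]
collection = PatchCollection(patches, edgecolors="#dddddd", linewidths=0.5)
ax.add_collection(collection)

missing = "#dddddd"  # gray for countries absent from a frame
# Hypothetical per-frame data: country code -> color.
frames = [{"AA": "#ffcc00"}, {"AA": "#cc3300", "CC": "#660000"}]


def animate(i):
    # Only the face colors change; no artists are created or added here.
    colors = [frames[i].get(code, missing) for code in shape_codes]
    collection.set_facecolor(colors)
    return (collection,)
```

Because the axes hold a single collection for the whole run, the number of artists to draw stays constant from frame to frame, and `ax.clear()` is no longer needed.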
