I've built upon @DanHickstein's answer to cover cases of plot, scatter and axhline/axvline for scaling either the x or y axis. It can be called as simple as autoscale()
to work on the most recent axes. If you wish to edit it, please fork it on gist.
def autoscale(ax=None, axis='y', margin=0.1):
'''Autoscales the x or y axis of a given matplotlib ax object
to fit the margins set by manually limits of the other axis,
with margins in fraction of the width of the plot
Defaults to current axes object if not specified.
'''
import matplotlib.pyplot as plt
import numpy as np
if ax is None:
ax = plt.gca()
newlow, newhigh = np.inf, -np.inf
for artist in ax.collections + ax.lines:
x,y = get_xy(artist)
if axis == 'y':
setlim = ax.set_ylim
lim = ax.get_xlim()
fixed, dependent = x, y
else:
setlim = ax.set_xlim
lim = ax.get_ylim()
fixed, dependent = y, x
low, high = calculate_new_limit(fixed, dependent, lim)
newlow = low if low < newlow else newlow
newhigh = high if high > newhigh else newhigh
margin = margin*(newhigh - newlow)
setlim(newlow-margin, newhigh+margin)
def calculate_new_limit(fixed, dependent, limit):
'''Calculates the min/max of the dependent axis given
a fixed axis with limits
'''
if len(fixed) > 2:
mask = (fixed>limit[0]) & (fixed < limit[1])
window = dependent[mask]
low, high = window.min(), window.max()
else:
low = dependent[0]
high = dependent[-1]
if low == 0.0 and high == 1.0:
# This is a axhline in the autoscale direction
low = np.inf
high = -np.inf
return low, high
def get_xy(artist):
'''Gets the xy coordinates of a given artist
'''
if "Collection" in str(artist):
x, y = artist.get_offsets().T
elif "Line" in str(artist):
x, y = artist.get_xdata(), artist.get_ydata()
else:
raise ValueError("This type of object isn't implemented yet")
return x, y
It, like its predecessor, is a bit hacky, but that is necessary because collections and lines have different methods for returning the xy coordinates, and because axhline/axvline is tricky to work with since it only has two datapoints.
Here it is in action:
fig, axes = plt.subplots(ncols = 4, figsize=(12,3))
(ax1, ax2, ax3, ax4) = axes
x = np.linspace(0,100,300)
noise = np.random.normal(scale=0.1, size=x.shape)
y = 2*x + 3 + noise
for ax in axes:
ax.plot(x, y)
ax.scatter(x,y, color='red')
ax.axhline(50., ls='--', color='green')
for ax in axes[1:]:
ax.set_xlim(20,21)
ax.set_ylim(40,45)
autoscale(ax3, 'y', margin=0.1)
autoscale(ax4, 'x', margin=0.1)
ax1.set_title('Raw data')
ax2.set_title('Specificed limits')
ax3.set_title('Autoscale y')
ax4.set_title('Autoscale x')
plt.tight_layout()