Okay, so I did a fairly major reshuffle. I think it separates out the major parts and will make it easy to make the code modular and split it into functions. The original code for the previous answer I gave is here.
Here's the new stuff; hopefully it's pretty self-explanatory.
# Setup our various global variables
population_mean = 7
population_std_dev = 1
samples = 100
histogram_bins = 50
# And setup our figure.
from matplotlib import pyplot
fig = pyplot.figure()
ax = fig.add_subplot(1,1,1)
from numpy.random import normal
hist_data = normal(population_mean, population_std_dev, samples)
ax.hist(hist_data, bins=histogram_bins, normed=True, color="blue", alpha=0.3)
from statsmodels.nonparametric.kde import KDEUnivariate
kde = KDEUnivariate(hist_data)
kde.fit()
# kde.support and kde.density hold the x and y values of the KDE fit.
ax.plot(kde.support, kde.density, color="red", lw=4)
#Gaussian function - though you can replace this with something of your choosing later.
from numpy import sqrt, exp, pi
r2pi = sqrt(2*pi)
def gaussian(x, mu, sigma):
    return exp(-0.5 * ((x - mu) / sigma)**2) / (sigma * r2pi)
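# Interpolate the KDE to produce a callable function.
# bounds_error=False is needed for fill_value=0 to apply outside the KDE support.
from scipy.interpolate import interp1d
kde_func = interp1d(kde.support, kde.density, kind="cubic", bounds_error=False, fill_value=0)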
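One detail worth checking with interp1d: as far as I know, fill_value is only used for out-of-range inputs when bounds_error=False is also passed; otherwise an out-of-range x raises a ValueError. Here is a minimal sketch of that behaviour, written for Python 3. The support and density arrays are hypothetical stand-ins for kde.support and kde.density, not the real KDE output.

```python
import numpy as np
from scipy.interpolate import interp1d

# Hypothetical stand-in for (kde.support, kde.density): a smooth bump on [-3, 3].
support = np.linspace(-3.0, 3.0, 61)
density = np.exp(-0.5 * support**2) / np.sqrt(2 * np.pi)

# With fill_value alone, evaluating outside the support still raises.
strict = interp1d(support, density, kind="cubic", fill_value=0)
try:
    strict(10.0)
    strict_raised = False
except ValueError:
    strict_raised = True

# With bounds_error=False, points outside the support evaluate to fill_value.
padded = interp1d(support, density, kind="cubic", bounds_error=False, fill_value=0)
outside = float(padded(10.0))
inside = float(padded(0.0))
```

In the answer's code the integration limits sit inside the KDE support, so this only matters if you evaluate kde_func further out.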
What you want to do is just standard curve fitting. There are numerous ways to do it, and you say you want to fit the curve by maximizing the overlap of the two functions (why?). The scipy routine curve_fit is a least squares fit, which tries to minimize the difference between the two functions. The difference is subtle: maximizing the overlap does not punish the fitting function for being larger than the data, whereas curve_fit does.
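To make that distinction concrete, here is a small sketch (Python 3, NumPy only) comparing the two metrics on a grid. The too_large candidate is a deliberately unnormalized, purely illustrative curve that sits above the data everywhere; the helper names are my own and not part of the code in this answer.

```python
import numpy as np

def gaussian(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def integrate_on_grid(y, x):
    # Plain trapezoidal rule on a fixed grid.
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))

x = np.linspace(0.0, 14.0, 2001)
data = gaussian(x, 7.0, 1.0)   # stands in for the KDE curve
perfect = data.copy()          # a candidate that matches the data exactly
too_large = 2.0 * data         # a candidate that is larger than the data everywhere

def overlap_metric(f, g):
    # Area under the pointwise minimum of the two curves.
    return integrate_on_grid(np.minimum(f, g), x)

def squared_difference(f, g):
    # Integrated squared residual, which is what a least squares fit penalizes.
    return integrate_on_grid((f - g) ** 2, x)

overlap_perfect = overlap_metric(data, perfect)
overlap_too_large = overlap_metric(data, too_large)
sq_perfect = squared_difference(data, perfect)
sq_too_large = squared_difference(data, too_large)
```

The overlap is identical for both candidates, while the squared difference is zero only for the exact match. With the normalized gaussian used in this answer the fit cannot exceed the data everywhere, so in practice the two metrics end up much closer than this extreme case suggests.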
I've included solutions using both techniques and profiled each of them:
# We need to *maximise* the overlap integral, so we minimise its negative.
from scipy.integrate import quad as integrate
def overlap(func1, func2, limits, func1_args=(), func2_args=()):
    def product_func(x):
        # Pointwise minimum of the two functions: the area they share.
        return min(func1(x, *func1_args), func2(x, *func2_args))
    return integrate(product_func, *limits)[0]  # quad returns (value, error estimate); keep the value.
limits = hist_data.min(), hist_data.max()
def gaussian_overlap(args):
    mu, sigma = args
    return -overlap(kde_func, gaussian, limits, func2_args=[mu, sigma])
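Since quad calls the integrand one point at a time, a possible alternative is to approximate the overlap integral on a fixed grid with vectorized NumPy calls. This is a different technique from the quad version above, sketched here for Python 3 under a few assumptions: the target array is a hypothetical stand-in for the KDE density, and grid_overlap and the starting guess are names and values I picked for illustration.

```python
import numpy as np
from scipy.optimize import minimize

def gaussian(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# Hypothetical stand-in for the KDE: a target density sampled on a fixed grid.
grid = np.linspace(2.0, 12.0, 2001)
target = gaussian(grid, 7.0, 1.0)

def grid_overlap(params):
    mu, sigma = params
    sigma = max(abs(sigma), 1e-6)  # keep sigma positive and away from zero
    candidate = gaussian(grid, mu, sigma)
    lower = np.minimum(target, candidate)
    # Trapezoidal rule; negated because the optimizer minimizes.
    return -float(np.sum(0.5 * (lower[1:] + lower[:-1]) * np.diff(grid)))

result = minimize(grid_overlap, x0=(6.5, 1.5), method="Powell")
mu_fit = float(result.x[0])
sigma_fit = abs(float(result.x[1]))
```

The trade-off is that the accuracy now depends on the grid spacing instead of quad's adaptive error control, so treat it as a sketch and not a drop-in replacement.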
And now the two different methods, the overlap metric:
import cProfile, pstats, StringIO
pr1 = cProfile.Profile()
pr1.enable()
from scipy.optimize import fmin_powell as minimize
mu_overlap_fit, sigma_overlap_fit = minimize(gaussian_overlap, (population_mean, population_std_dev))
pr1.disable()
s = StringIO.StringIO()
sortby = 'cumulative'
ps = pstats.Stats(pr1, stream=s).sort_stats(sortby)
ps.print_stats()
print s.getvalue()
3122462 function calls in 6.298 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 6.298 6.298 C:\Python27\lib\site-packages\scipy\optimize\optimize.py:2120(fmin_powell)
1 0.000 0.000 6.298 6.298 C:\Python27\lib\site-packages\scipy\optimize\optimize.py:2237(_minimize_powell)
57 0.000 0.000 6.296 0.110 C:\Python27\lib\site-packages\scipy\optimize\optimize.py:279(function_wrapper)
57 0.000 0.000 6.296 0.110 C:\Users\Will\Documents\Python_scripts\hist_fit.py:47(gaussian_overlap)
57 0.000 0.000 6.296 0.110 C:\Users\Will\Documents\Python_scripts\hist_fit.py:39(overlap)
57 0.000 0.000 6.296 0.110 C:\Python27\lib\site-packages\scipy\integrate\quadpack.py:42(quad)
57 0.000 0.000 6.295 0.110 C:\Python27\lib\site-packages\scipy\integrate\quadpack.py:327(_quad)
57 0.069 0.001 6.295 0.110 {scipy.integrate._quadpack._qagse}
66423 0.154 0.000 6.226 0.000 C:\Users\Will\Documents\Python_scripts\hist_fit.py:41(product_func)
4 0.000 0.000 6.167 1.542 C:\Python27\lib\site-packages\scipy\optimize\optimize.py:2107(_linesearch_powell)
4 0.000 0.000 6.166 1.542 C:\Python27\lib\site-packages\scipy\optimize\optimize.py:1830(brent)
4 0.000 0.000 6.166 1.542 C:\Python27\lib\site-packages\scipy\optimize\optimize.py:1887(_minimize_scalar_brent)
4 0.001 0.000 6.166 1.542 C:\Python27\lib\site-packages\scipy\optimize\optimize.py:1717(optimize)
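The profiling boilerplate is the same for every method, so it could be wrapped in a small helper. A minimal sketch, assuming Python 3 (where StringIO lives in the io module); profile_call and busy are hypothetical names used only for this example.

```python
import cProfile
import io
import pstats

def profile_call(func, *args, **kwargs):
    """Run func under cProfile and return (result, report sorted by cumulative time)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args, **kwargs)
    stream = io.StringIO()
    pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats()
    return result, stream.getvalue()

def busy(n):
    # Small stand-in workload so there is something to profile.
    return sum(i * i for i in range(n))

result, report = profile_call(busy, 1000)
```

Each fit could then be profiled with a single call, passing the minimizer and its arguments to the helper.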
And the scipy method, curve_fit:
pr2 = cProfile.Profile()
pr2.enable()
from scipy.optimize import curve_fit
(mu_curve_fit, sigma_curve_fit), _ = curve_fit(gaussian, kde.support, kde.density, p0=(population_mean, population_std_dev))
pr2.disable()
s = StringIO.StringIO()
sortby = 'cumulative'
ps = pstats.Stats(pr2, stream=s).sort_stats(sortby)
ps.print_stats()
print s.getvalue()
122 function calls in 0.001 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 0.001 0.001 C:\Python27\lib\site-packages\scipy\optimize\minpack.py:452(curve_fit)
1 0.000 0.000 0.001 0.001 C:\Python27\lib\site-packages\scipy\optimize\minpack.py:256(leastsq)
1 0.000 0.000 0.001 0.001 {scipy.optimize._minpack._lmdif}
19 0.000 0.000 0.001 0.000 C:\Python27\lib\site-packages\scipy\optimize\minpack.py:444(_general_function)
19 0.000 0.000 0.000 0.000 C:\Users\Will\Documents\Python_scripts\hist_fit.py:29(gaussian)
1 0.000 0.000 0.000 0.000 C:\Python27\lib\site-packages\scipy\linalg\basic.py:314(inv)
1 0.000 0.000 0.000 0.000 C:\Python27\lib\site-packages\scipy\optimize\minpack.py:18(_check_func)
You can see the curve_fit method is much faster, and the results:
from numpy import linspace
xs = linspace(-1, 1, num=1000) * sigma_overlap_fit * 6 + mu_overlap_fit
ax.plot(xs, gaussian(xs, mu_overlap_fit, sigma_overlap_fit), color="orange", lw=2)
xs = linspace(-1, 1, num=1000) * sigma_curve_fit * 6 + mu_curve_fit
ax.plot(xs, gaussian(xs, mu_curve_fit, sigma_curve_fit), color="purple", lw=2)
pyplot.show()
are very similar. I would recommend curve_fit. In this case it's about 6000x faster. The difference between the two results grows a little when the underlying data is more complex, but not by much, and you still get a huge speed-up. Here's an example for 6 uniformly distributed normal distributions being fit:
Go with curve_fit!
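If you want this without module-level globals, one option is to wrap the curve_fit route in a function that takes the samples and returns the fitted parameters. This is only a sketch for Python 3 under my own assumptions: it uses scipy.stats.gaussian_kde in place of statsmodels' KDEUnivariate to keep the dependencies down, and fit_gaussian_to_kde, grid_points and the padding of three standard deviations are hypothetical choices.

```python
import numpy as np
from scipy.optimize import curve_fit
from scipy.stats import gaussian_kde

def gaussian(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def fit_gaussian_to_kde(samples, grid_points=200):
    """Least squares fit of a Gaussian to a KDE of samples; returns (mu, sigma)."""
    samples = np.asarray(samples, dtype=float)
    kde = gaussian_kde(samples)  # scipy's KDE, swapped in for KDEUnivariate
    pad = 3.0 * samples.std()
    support = np.linspace(samples.min() - pad, samples.max() + pad, grid_points)
    density = kde(support)
    p0 = (samples.mean(), samples.std())  # start from the sample moments
    (mu, sigma), _ = curve_fit(gaussian, support, density, p0=p0)
    return float(mu), abs(float(sigma))

rng = np.random.default_rng(0)
samples = rng.normal(7.0, 1.0, 100)
mu_fit, sigma_fit = fit_gaussian_to_kde(samples)
```

Everything the fit needs comes in through the arguments, so the plotting can live in a separate function that just receives mu_fit and sigma_fit.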
/Library/Enthought/Canopy_64bit/User/lib/python2.7/site-packages/scipy/integrate/quadrature.py:699: AccuracyWarning: divmax (10) exceeded.
and something does not work when plotting. I agree; it does not look nice with all those globals. Normally I quite like plotting in functions; what do you find problematic about this, and what's the alternative? Also, being new to Python coding (coming from MATLAB), is there any alternative to defining globals in a case like this? – Blumenfeld