plotting confidence interval for linear regression line of a pandas time-series Dataframe
I have a sample time-series dataframe:

import pandas as pd

df = pd.DataFrame({'year': ['1990', '1991', '1992', '1993', '1994', '1995', '1996',
                            '1997', '1998', '1999', '2000'],
                   'count': [96, 184, 148, 154, 160, 149, 124, 274, 322, 301, 300]})

I want a linear regression line with a confidence interval band around it. I managed to plot the regression line, but I am finding it difficult to add the confidence interval band to the plot. Here is the snippet of my code for the linear regression plot:

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import ticker
from sklearn.linear_model import LinearRegression

# Not shown in the original snippet: the code expects a DatetimeIndex and a
# numeric 'date_ordinal' column. One way to build them (an assumption):
df.index = pd.to_datetime(df['year'])
df['date_ordinal'] = df.index.map(pd.Timestamp.toordinal)

X = df.date_ordinal.values.reshape(-1, 1)
y = df['count'].values.reshape(-1, 1)

reg = LinearRegression()
reg.fit(X, y)
predictions = reg.predict(X)

fig, ax = plt.subplots()
plt.scatter(X, y, color='blue', alpha=0.5)
plt.plot(X, predictions, alpha=0.5, color='black',
         label=r'$N$' + '= {:.2f}t + {:.2e}\n'.format(reg.coef_[0][0], reg.intercept_[0]))

plt.ylabel('count($N$)')
plt.xlabel(r'Year(t)')
plt.legend()

# Scientific notation on the y-axis
formatter = ticker.ScalarFormatter(useMathText=True)
formatter.set_scientific(True)
formatter.set_powerlimits((-1, 1))
ax.yaxis.set_major_formatter(formatter)

# Label every fifth tick with its calendar year
plt.xticks(ticks=df.date_ordinal[::5], labels=df.index.year[::5])

plt.grid()
plt.show()
plt.clf()

This gives me a nice linear regression plot for the time series.

Problem & desired output: I also need a confidence interval band around the regression line, as in the attached screenshot (a seaborn-style regression plot with a shaded band).

Help on this issue would be highly appreciated.

Trichina asked 28/5, 2021 at 11:4. Comments (2):
Does this help: #27116979? (Damales)
This will help: #27164614 (Damales)

The problem you are running into is that LinearRegression from sklearn.linear_model does not provide a simple way to obtain confidence intervals.

If you must use sklearn.linear_model.LinearRegression, you will have to compute the confidence interval yourself. One popular approach is bootstrapping, as done in a previous answer.
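A minimal sketch of the bootstrap approach, assuming the sample data from the question with the years treated as plain numbers. Each iteration resamples (year, count) pairs with replacement, refits a straight line and evaluates it on a grid; the 2.5th and 97.5th percentiles of those predictions at each grid point form a 95% band. np.polyfit stands in for LinearRegression to keep the sketch dependency-light, and bootstrap_band is a hypothetical helper name, not a library function.

```python
import numpy as np

def bootstrap_band(x, y, x_grid, n_boot=1000, ci=95, seed=0):
    """Hypothetical helper: percentile-bootstrap band for a straight-line fit."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_grid = np.asarray(x_grid, dtype=float)
    n = len(x)
    preds = np.empty((n_boot, len(x_grid)))
    for i in range(n_boot):
        # Resample (x, y) pairs with replacement and refit the line.
        # np.polyfit stands in for sklearn's LinearRegression here.
        idx = rng.integers(0, n, size=n)
        slope, intercept = np.polyfit(x[idx], y[idx], deg=1)
        preds[i] = intercept + slope * x_grid
    # Percentiles of the resampled predictions at each grid point
    tail = (100 - ci) / 2
    lower, upper = np.percentile(preds, [tail, 100 - tail], axis=0)
    return lower, upper

# Sample data from the question, with the years as numbers
years = np.arange(1990, 2001)
counts = np.array([96, 184, 148, 154, 160, 149, 124, 274, 322, 301, 300])
lower, upper = bootstrap_band(years, counts, years)
```

To use LinearRegression instead, call reg.fit(x[idx].reshape(-1, 1), y[idx]) and reg.predict(x_grid.reshape(-1, 1)) inside the loop. The band can then be shaded with ax.fill_between(years, lower, upper, alpha=0.3).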

However, as I interpret your question, you are looking for a quick way to do this inside a plot command, similar to the screenshot you attached. If your goal is purely visualization, you can simply use the seaborn package, which is also where your example image comes from.

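import seaborn as sns

# lmplot needs a numeric x column; the sample data stores years as strings
df['year'] = df['year'].astype(int)
sns.lmplot(x='year', y='count', data=df, fit_reg=True, ci=95, n_boot=1000)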

Here the three self-explanatory parameters of interest, fit_reg, ci and n_boot, are set explicitly to their default values. Refer to the seaborn documentation for a full description.

Seaborn computes this band itself by bootstrapping the fit (statsmodels is only needed for options such as lowess, logistic or robust fits). If you want something in between pure visualization and writing the confidence interval function from scratch yourself, I would instead refer you to the statsmodels package. Specifically, look at its documentation for calculating the confidence interval of an ordinary least squares (OLS) linear regression.
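For reference, the analytic band for a straight-line OLS fit with n points (the confidence interval of the fitted mean, which statsmodels reports as mean_ci_lower and mean_ci_upper) is given below; the symbols are standard notation rather than anything from the question. Adding 1 under the square root gives the wider prediction interval for individual observations.

```latex
\hat{y}(x_0) \pm t_{n-2,\,1-\alpha/2}\; s
  \sqrt{\frac{1}{n} + \frac{(x_0 - \bar{x})^2}{\sum_{i=1}^{n} (x_i - \bar{x})^2}},
\qquad
s^2 = \frac{1}{n-2} \sum_{i=1}^{n} \left(y_i - \hat{y}_i\right)^2
```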

The following code should give you a starting point for using statsmodels in your example:

import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt

df = pd.DataFrame({'year':['1990','1991','1992','1993','1994','1995','1996','1997','1998','1999','2000'],
                   'count':[96,184,148,154,160,149,124,274,322,301,300]})
df['year'] = df['year'].astype(float)

# Design matrix with an intercept column: [1, year]
X = sm.add_constant(df['year'].values)
ols_model = sm.OLS(df['count'].values, X)
est = ols_model.fit()
# 95% intervals for the coefficients (intercept, slope), not the band itself
out = est.conf_int(alpha=0.05, cols=None)

fig, ax = plt.subplots()
df.plot(x='year', y='count', linestyle='None', marker='s', ax=ax)
y_pred = est.predict(X)
x_pred = df.year.values
ax.plot(x_pred, y_pred)

# summary_frame() holds the fitted mean and its 95% confidence bounds
pred = est.get_prediction(X).summary_frame()
ax.plot(x_pred, pred['mean_ci_lower'], linestyle='--', color='blue')
ax.plot(x_pred, pred['mean_ci_upper'], linestyle='--', color='blue')

# Alternative way to plot the fitted line from the estimated parameters
def line(x, b=0, m=1):
    return m * x + b

ax.plot(x_pred, line(x_pred, est.params[0], est.params[1]), color='blue')

plt.show()

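This produces your desired output: the fitted line with dashed 95% confidence bounds.

All of the underlying values (coefficients, standard errors, interval bounds) are also accessible through standard statsmodels methods.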
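If you would rather compute the band yourself than rely on statsmodels, a minimal sketch follows. It assumes a simple straight-line fit and uses np.polyfit plus scipy.stats.t.ppf for the t critical value; ols_mean_band is a hypothetical helper name. The band is the fitted value plus or minus t times s times sqrt(1/n + (x0 - mean(x))^2 / Sxx), where s is the residual standard error and Sxx is the sum of squared deviations of x.

```python
import numpy as np
from scipy import stats

def ols_mean_band(x, y, x_grid, alpha=0.05):
    """Hypothetical helper: fitted line and its (1 - alpha) mean confidence band."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_grid = np.asarray(x_grid, dtype=float)
    n = len(x)
    slope, intercept = np.polyfit(x, y, deg=1)
    resid = y - (intercept + slope * x)
    s = np.sqrt(np.sum(resid ** 2) / (n - 2))      # residual standard error
    sxx = np.sum((x - x.mean()) ** 2)              # spread of x around its mean
    se_mean = s * np.sqrt(1.0 / n + (x_grid - x.mean()) ** 2 / sxx)
    t_crit = stats.t.ppf(1 - alpha / 2, df=n - 2)  # two-sided t critical value
    fit = intercept + slope * x_grid
    return fit, fit - t_crit * se_mean, fit + t_crit * se_mean

# Sample data from the question, with the years as numbers
years = np.arange(1990, 2001)
counts = np.array([96, 184, 148, 154, 160, 149, 124, 274, 322, 301, 300])
fit, lower, upper = ols_mean_band(years, counts, years)
```

The bounds should agree, up to floating-point rounding, with pred['mean_ci_lower'] and pred['mean_ci_upper'] from get_prediction(X).summary_frame(). For a shaded band like seaborn's, pass them to ax.fill_between(years, lower, upper, alpha=0.3). The band is narrowest at the mean year and widens toward the ends of the range.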

Freeload answered 28/5, 2021 at 15:38 Comment(0)
