Why is Pymc3 ADVI worse than MCMC in this logistic regression example?
I am aware of the mathematical differences between ADVI and MCMC, but I am trying to understand the practical implications of using one or the other. I am running a very simple logistic regression example on data I created in this way:

import pandas as pd
import pymc3 as pm
import matplotlib.pyplot as plt
import numpy as np

def logistic(x, b, noise=None):
    L = x.T.dot(b)
    if noise is not None:
        L = L+noise
    return 1/(1+np.exp(-L))

x1 = np.linspace(-10., 10, 10000)
x2 = np.linspace(0., 20, 10000)
bias = np.ones(len(x1))
X = np.vstack([x1,x2,bias]) # Add intercept
B =  [-10., 2., 1.] # Sigmoid params for X + intercept

# Noisy mean (note: with scale=0. the added noise is actually all zeros)
pnoisy = logistic(X, B, noise=np.random.normal(loc=0., scale=0., size=len(x1)))
# dichotomize pnoisy -- sample 0/1 with probability pnoisy
y = np.random.binomial(1., pnoisy)

Then I run ADVI like this:

with pm.Model() as model: 
    # Define priors
    intercept = pm.Normal('Intercept', 0, sd=10)
    x1_coef = pm.Normal('x1', 0, sd=10)
    x2_coef = pm.Normal('x2', 0, sd=10)

    # Define likelihood
    p = pm.math.sigmoid(intercept + x1_coef * X[0] + x2_coef * X[1])
    likelihood = pm.Bernoulli('y', p, observed=y)

    # Fit with mean-field ADVI
    approx = pm.fit(90000, method='advi')
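
I then check the recovered coefficients by sampling from the approximation, for example (pm.summary is PyMC3's standard summary function):

draws = approx.sample(2_000)
print(pm.summary(draws))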

Unfortunately, no matter how many iterations I run, ADVI does not seem to be able to recover the original betas I defined ([-10., 2., 1.]), while MCMC works fine (as shown below).

[figure: posterior estimates, ADVI vs. MCMC]

Thanks for the help!

Vernice answered 28/9, 2018 at 15:50

This is an interesting question! The default 'advi' in PyMC3 is mean-field variational inference, which does not do a great job of capturing correlations. It turns out that the model you set up has an interesting correlation structure, which can be seen by pair-plotting the MCMC samples.
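
A minimal sketch of getting those samples, assuming the default NUTS sampler (the question's MCMC run is not shown):

with model:
    trace = pm.sample(2_000)  # default NUTS sampler

Then: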

import arviz as az

az.plot_pair(trace, figsize=(5, 5))

[figure: pair plot of the MCMC trace, showing strongly correlated samples]

PyMC3 has a built-in convergence checker - running the optimization for too long or too short can lead to funny results:

from pymc3.variational.callbacks import CheckParametersConvergence

with model:
    # Stop optimization early once the variational parameters stop changing
    fit = pm.fit(100_000, method='advi', callbacks=[CheckParametersConvergence()])

draws = fit.sample(2_000)

This stops after about 60,000 iterations for me. Now we can inspect the correlations and see that, as expected, ADVI fit axis-aligned Gaussians:

az.plot_pair(draws, figsize=(5, 5))

[figure: pair plot of the ADVI draws, axis-aligned and uncorrelated]

Finally, we can compare the fit from NUTS and (mean field) ADVI:

az.plot_forest([draws, trace])

[figure: forest plot comparing ADVI and NUTS estimates]

Note that ADVI underestimates the variance, but it is fairly close for the mean of each parameter. You can also set method='fullrank_advi' to better capture the correlations you are seeing.
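
For example, a minimal sketch of the full-rank variant (the same pm.fit API with a different method string; it fits a dense covariance matrix, so it is more expensive):

with model:
    # Full-rank ADVI: fits a dense covariance matrix instead of a diagonal one
    full_fit = pm.fit(100_000, method='fullrank_advi',
                      callbacks=[CheckParametersConvergence()])

full_draws = full_fit.sample(2_000)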

(note: arviz is soon to be the plotting library for PyMC3)

Taproom answered 29/9, 2018 at 0:57
Given how widespread correlated features are, isn't the mvnormal with diagonal covariance approximation... really bad in general? – Comrade
Totally. You'll find that a lot of the literature on variational inference focuses on this (legitimate!) worry. However, it turns a sampling problem into an optimization problem, which can handle tons of data and goes much faster. So if you don't expect to see correlations, it could be the only feasible approach. – Taproom
Right. Anyway, thank you SO MUCH for your answer -- I was seeing poor posterior predictive performance with ADVI, and I think it may come down to the fact that I have a lot of correlated features, just like the OP. I'll try MCMC and see if that works better. – Comrade
By the way, is this a problem for most variational inference algorithms, or just ADVI? – Comrade
It depends on the "flavor" of ADVI you use. Mean field uses a diagonal covariance matrix, while full rank fits a dense covariance matrix, which comes with its own problems. See nbviewer.jupyter.org/gist/ColCarroll/… for some comparisons of NUTS, mean-field, and full-rank ADVI. – Taproom
A lot of variational inference uses a mean-field assumption, because in large models working with a full-rank covariance matrix is intractable. If you expect correlations, you could roll your own variational solution (assuming it's tractable) or look at, for example, Pyro, which allows a more flexible factorisation. – Outstation
