Plotting posterior parameter estimates from multiple models with bayesplot

I am using the great plotting library bayesplot to visualize posterior probability intervals from models I am estimating with rstanarm. I want to graphically compare draws from different models by getting the posterior intervals for coefficients onto the same plot.

Imagine, for instance, that I have 1000 draws from the posterior for three parameters beta1, beta2, beta3 for two different models:

# load the plotting library
library(bayesplot)
#> This is bayesplot version 1.6.0
#> - Online documentation and vignettes at mc-stan.org/bayesplot
#> - bayesplot theme set to bayesplot::theme_default()
#>    * Does _not_ affect other ggplot2 plots
#>    * See ?bayesplot_theme_set for details on theme setting
library(ggplot2)

# generate fake posterior draws from model1
fdata <- matrix(rnorm(1000 * 3), ncol = 3)
colnames(fdata) <- c('beta1', 'beta2', 'beta3')

# fake posterior draws from model 2
fdata2 <- matrix(rnorm(1000 * 3, 1, 2), ncol = 3)
colnames(fdata2) <- c('beta1', 'beta2', 'beta3')

bayesplot makes fantastic visualizations of the draws from a single model, and since it uses ggplot2 under the hood, I can customize them as I please:

# a nice plot of model 1
color_scheme_set("orange")
mcmc_intervals(fdata) + theme_minimal() + ggtitle("Model 1")

# a nice plot of model 2
color_scheme_set("blue")
mcmc_intervals(fdata2) + ggtitle("Model 2")

But what I would like to achieve is to plot these two models together on the same plot, such that for each coefficient I have two intervals and can tell which interval is which by mapping color to the model. However, I can't figure out how to do this. Some things that don't work:

# doesn't work
mcmc_intervals(fdata) + mcmc_intervals(fdata2)
#> Error: Don't know how to add mcmc_intervals(fdata2) to a plot

# appears to pool the draws from both models into one interval per parameter
mcmc_intervals(list(fdata, fdata2))

Any ideas on how I could do this? Or how to do it manually given the matrices of posterior draws?

Created on 2018-10-18 by the reprex package (v0.2.1)
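
One way to answer the "manually" part of the question is sketched below in base R. It is a minimal sketch under stated assumptions: the helper name interval_table() is hypothetical (not part of bayesplot), it reduces each matrix of draws to per-parameter quantiles using the same column names that bayesplot::mcmc_intervals_data() returns (ll, l, m, h, hh, with m the median), and it uses quantile()'s default method, which may differ slightly from bayesplot's own computation. Each table is tagged with its model before stacking.

```r
# Hypothetical helper: summarise a draws matrix (iterations x parameters)
# into one row per parameter, with column names mirroring
# bayesplot::mcmc_intervals_data() (ll, l, m, h, hh)
interval_table <- function(draws, prob = 0.5, prob_outer = 0.9) {
  probs <- c((1 - prob_outer) / 2, (1 - prob) / 2, 0.5,
             1 - (1 - prob) / 2, 1 - (1 - prob_outer) / 2)
  # apply() gives a 5 x p matrix; t() makes it one row per parameter
  q <- t(apply(draws, 2, quantile, probs = probs, names = FALSE))
  data.frame(
    parameter = colnames(draws),
    ll = q[, 1], l = q[, 2], m = q[, 3], h = q[, 4], hh = q[, 5],
    row.names = NULL
  )
}

# fake posterior draws, as in the question
set.seed(1)
fdata <- matrix(rnorm(1000 * 3), ncol = 3,
                dimnames = list(NULL, c("beta1", "beta2", "beta3")))
fdata2 <- matrix(rnorm(1000 * 3, 1, 2), ncol = 3,
                 dimnames = list(NULL, c("beta1", "beta2", "beta3")))

# tag each table with its model before stacking, so the labels
# cannot get out of step with the rows
combined <- rbind(
  cbind(interval_table(fdata), model = "Model 1"),
  cbind(interval_table(fdata2), model = "Model 2")
)
```

The result has one row per parameter per model, so it can be passed to ggplot() in the same way as the output of mcmc_intervals_data() in the answers below.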

Sprout asked 18/10, 2018 at 13:54

Just so the answer is also posted here, I have expanded on the code at the link from @Manny T (https://github.com/stan-dev/bayesplot/issues/232)

# simulate having posteriors for two different models each with parameters beta[1],..., beta[4]
posterior_1 <- matrix(rnorm(4000), 1000, 4)
posterior_2 <- matrix(rnorm(4000), 1000, 4)
colnames(posterior_1) <- colnames(posterior_2) <- paste0("beta[", 1:4, "]")

# use bayesplot::mcmc_intervals_data() function to get intervals data in format easy to pass to ggplot
library(bayesplot)
combined <- rbind(mcmc_intervals_data(posterior_1), mcmc_intervals_data(posterior_2))
combined$model <- rep(c("Model 1", "Model 2"), each = ncol(posterior_1))

# make the plot using ggplot 
library(ggplot2)
theme_set(bayesplot::theme_default())
pos <- position_nudge(y = ifelse(combined$model == "Model 2", 0, 0.1))
ggplot(combined, aes(x = m, y = parameter, color = model)) + 
  geom_linerange(aes(xmin = l, xmax = h), position = pos, size=2)+
  geom_linerange(aes(xmin = ll, xmax = hh), position = pos)+
  geom_point(position = pos, color="black")
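
The ifelse() nudge above handles exactly two models: every row whose model is not "Model 2" is shifted by 0.1, so a third model would be drawn on top of Model 1. A minimal sketch of a more general offset, using a hypothetical base-R helper model_offsets(), spaces k models evenly and symmetrically around each parameter's row:

```r
# Hypothetical helper: one vertical offset per row, spacing the k models
# evenly and symmetrically around each parameter's position on the y axis
model_offsets <- function(model, spacing = 0.15) {
  f <- factor(model)             # models in sorted level order
  idx <- as.integer(f)           # 1, 2, ..., k
  k <- nlevels(f)
  (idx - (k + 1) / 2) * spacing  # centred on 0
}

model_offsets(c("Model 1", "Model 2", "Model 1"))  # -0.075, 0.075, -0.075
model_offsets(c("Model 1", "Model 2", "Model 3"))  # -0.15, 0, 0.15
```

In the code above this would replace the nudge line, e.g. pos <- position_nudge(y = model_offsets(combined$model)). The total spread, (k - 1) * spacing, should stay well below 1 so the models do not run into the neighbouring parameter's row; position_dodge(), as used in the last answer, is an alternative that computes the offsets for you.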

If you are like me, you will want 80% and 90% credible intervals (instead of the default 50% inner interval), and you might want the coordinates flipped. Let's also add a dashed line at 0 (where the model estimates no change). You can do that like this:

# use bayesplot::mcmc_intervals_data() function to get intervals data in format easy to pass to ggplot
library(bayesplot)
combined <- rbind(
  mcmc_intervals_data(posterior_1, prob = 0.8, prob_outer = 0.9),
  mcmc_intervals_data(posterior_2, prob = 0.8, prob_outer = 0.9)
)
combined$model <- rep(c("Model 1", "Model 2"), each = ncol(posterior_1))

# make the plot using ggplot 
library(ggplot2)
theme_set(bayesplot::theme_default())
pos <- position_nudge(y = ifelse(combined$model == "Model 2", 0, 0.1))
ggplot(combined, aes(x = m, y = parameter, color = model)) + 
  geom_linerange(aes(xmin = l, xmax = h), position = pos, size=2)+
  geom_linerange(aes(xmin = ll, xmax = hh), position = pos)+
  geom_point(position = pos, color="black")+
  coord_flip()+
  geom_vline(xintercept=0,linetype="dashed")

A few things to note on this last one. I added prob_outer = 0.9 even though that is the default, just to show how you might change the outer credible interval. The dashed line is drawn with geom_vline() and xintercept = 0 even though it appears horizontal: coord_flip() swaps the axes only when the plot is drawn, so 0 is still a value of the x aesthetic (m). The same geom_vline(xintercept = 0) call is therefore also correct if you don't flip the axes; the line simply appears vertical.
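
A minimal way to check the geom_vline()/coord_flip() point, assuming ggplot2 3.3.0 or later (needed for horizontal geom_linerange()): the sketch builds the same plot with and without coord_flip() and reads back the reference line's data with ggplot2::layer_data(). The interval table is not from a fitted model; it holds the exact 80% and 90% central intervals of N(0, 1) and N(1, 2^2), the distributions the question uses for its fake draws, laid out with the ll/l/m/h/hh columns that mcmc_intervals_data() returns.

```r
library(ggplot2)

# Exact 80% (l, h) and 90% (ll, hh) central intervals of N(0, 1) and
# N(1, 2^2), in the column layout of bayesplot::mcmc_intervals_data()
mu <- c(0, 1)
sigma <- c(1, 2)
combined <- data.frame(
  parameter = "beta1",
  model = c("Model 1", "Model 2"),
  ll = qnorm(0.05, mu, sigma), l = qnorm(0.10, mu, sigma), m = mu,
  h = qnorm(0.90, mu, sigma), hh = qnorm(0.95, mu, sigma)
)

# Same layers in both plots; layer 2 is the dashed reference line at m = 0
p_plain <- ggplot(combined, aes(x = m, y = parameter, color = model)) +
  geom_linerange(aes(xmin = ll, xmax = hh),
                 position = position_dodge(width = 0.3)) +
  geom_vline(xintercept = 0, linetype = "dashed")
p_flipped <- p_plain + coord_flip()

# In both cases the line sits at 0 on the x aesthetic (m);
# only its on-screen direction changes
layer_data(p_plain, 2)$xintercept
layer_data(p_flipped, 2)$xintercept
```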

Henshaw answered 11/12, 2020 at 17:51

I asked this question on the bayesplot page on GitHub and got a response (Issue #232).

Viniculture answered 25/7, 2020 at 22:11
Hi, welcome to SO. Nice follow-up on stan-dev. You might edit your answer to put the suggested solution here as well, keeping the issue attribution. - Inaccurate
This is great. A bunch of people seem to have looked at this question over the years. Why don't you paste the solution in, and I'll accept the answer? - Sprout

I blew more time than I'd like to admit writing this, so I might as well post it here. Here's a function that incorporates the suggestions above and (for the moment) works with rstanarm and brms model objects.

# Needs ggplot2 attached, plus rstanarm and/or brms for fixef()
compare_posteriors <- function(..., dodge_width = 0.5) {
  # named list of models, named after the arguments (e.g. "mod1")
  dots <- rlang::dots_list(..., .named = TRUE)
  draws <- lapply(dots, function(x) {
    if (inherits(x, "stanreg")) {
      # rstanarm: keep only the population-level (fixed-effect) coefficients
      posterior::subset_draws(posterior::as_draws(x$stanfit),
        variable = names(fixef(x))
      )
    } else if (inherits(x, "brmsfit")) {
      # brms stores the same coefficients as b_<name>
      brm_draws <- posterior::subset_draws(posterior::as_draws(x$fit),
        variable = paste0("b_", rownames(fixef(x)))
      )
      # strip only the leading "b_" so names containing "_" stay intact
      posterior::variables(brm_draws) <- sub("^b_", "", posterior::variables(brm_draws))
      # use rstanarm's name for the intercept so the models line up
      posterior::rename_variables(brm_draws, `(Intercept)` = Intercept)
    } else {
      stop(class(x)[1], " objects not supported.")
    }
  })
  # one interval table per model, stacked with a "model" id column
  intervals <- lapply(draws, bayesplot::mcmc_intervals_data)
  combined <- dplyr::bind_rows(intervals, .id = "model")
  ggplot(combined, aes(x = m, y = parameter, color = model, group = model)) +
    geom_linerange(aes(xmin = l, xmax = h), size = 2, position = position_dodge(dodge_width)) +
    geom_linerange(aes(xmin = ll, xmax = hh), position = position_dodge(dodge_width)) +
    geom_point(color = "black", position = position_dodge(dodge_width)) +
    geom_vline(xintercept = 0, linetype = "dashed")
}

Usage:

compare_posteriors(mod1, mod2, mod3)
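
The brms branch of compare_posteriors() has to turn variable names such as b_Intercept back into plain coefficient names. Splitting on every "_" and keeping the second piece (what str_split(..., simplify = TRUE)[, 2] does) truncates coefficient names that themselves contain underscores, whereas removing only a leading b_ keeps them intact. A small base-R illustration, with hypothetical variable names:

```r
# hypothetical brms population-level variable names (b_<coefficient>)
vars <- c("b_Intercept", "b_age", "b_age_group")

# keep the second piece after splitting on every "_"
# (base-R analogue of str_split(vars, "_", simplify = TRUE)[, 2])
split_second <- vapply(strsplit(vars, "_", fixed = TRUE), `[`, character(1), 2)

# remove only a leading "b_"
prefix_removed <- sub("^b_", "", vars)

split_second    # "Intercept" "age" "age"
prefix_removed  # "Intercept" "age" "age_group"
```
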
Rigger answered 20/2, 2022 at 21:25
