Parallel RJAGS with convergence testing

I'm modifying an existing model using RJAGS. I'd like to run chains in parallel and periodically check the Gelman-Rubin convergence diagnostic to decide whether I need to keep running. The problem is that when I resume sampling based on the diagnostic value, the recompiled chains restart from their original initial values rather than from the position in parameter space where each chain stopped. If I do not recompile the model, RJAGS complains. Is there a way to store the positions of the chains when they stop so I can re-initialize from where I left off? Here is a very simplified example.

example1.bug:

model {
  for (i in 1:N) {
      x[i] ~ dnorm(mu,tau)
  }
  mu ~ dnorm(0,0.0001)
  tau <- pow(sigma,-2)
  sigma ~ dunif(0,100)
}

parallel_test.R:

#Make some fake data
N <- 1000
x <- rnorm(N,0,5)
write.table(x,
        file='example1.data',
        row.names=FALSE,
        col.names=FALSE)

library('rjags')
library('doParallel')
library('random')

nchains <- 4
c1 <- makeCluster(nchains)
registerDoParallel(c1)

jags=list()
for (i in 1:getDoParWorkers()){
  jags[[i]] <- jags.model('example1.bug',
                          data=list('x'=x,'N'=N))
}

# Function to combine multiple mcmc lists into a single one
mcmc.combine <- function( ... ){
  return( as.mcmc.list( sapply( list( ... ),mcmc ) ) )
}

#Start with some burn-in
jags.parsamples <- foreach( i=1:getDoParWorkers(),
                           .inorder=FALSE,
                           .packages=c('rjags','random'),
                           .combine='mcmc.combine',
                           .multicombine=TRUE) %dopar%
{
  jags[[i]]$recompile()

  update(jags[[i]],100)
  jags.samples <- coda.samples(jags[[i]],c('mu','tau'),100)

  return(jags.samples)
}   

#Check the diagnostic output
print(gelman.diag(jags.parsamples[,'mu']))

counter <- 0

#my model doesn't converge so quickly, so let's simulate doing
#this updating 5 times:
#while(gelman.diag(jags.parsamples[,'mu'])[[1]][[2]] > 1.04)
while(counter < 5)
{
  counter <- counter + 1
  jags.parsamples <- foreach(i=1:getDoParWorkers(),
                             .inorder=FALSE,
                             .packages=c('rjags','random'),
                             .combine='mcmc.combine',
                             .multicombine=TRUE) %dopar%
  {
    #Here I lose the progress I've made
    jags[[i]]$recompile()
    jags.samples <- coda.samples(jags[[i]],c('mu','tau'),100)
    return(jags.samples)
  }
}

print(gelman.diag(jags.parsamples[,'mu']))
print(summary(jags.parsamples))
stopCluster(c1)

In the output, I see:

Iterations = 1001:2000

where I know there should be > 5000 iterations. (cross-posted to stats.stackexchange.com, which may be the more appropriate venue)

Superphosphate answered 6/4, 2015 at 20:18 Comment(1)
I'm finding R2jags::jags.parallel works great for the parallelization part. Are the other parts of this (checking for convergence, picking up where the chain left off if it is not met) still the same, or are there new tools to do this? - Ezmeralda

Every time your JAGS model runs on the worker nodes, the coda samples are returned but the state of the model is lost. So the next time it recompiles, it restarts from the beginning, as you are seeing. To get around this, you need to get and return the state of the model in your function (on the worker nodes), like so:

 endstate <- jags[[i]]$state(internal=TRUE)

Then you need to pass this back to the worker node and re-generate the model within the worker function using jags.model() with inits=endstate (for the appropriate chain).
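As a rough sketch of that round trip (not tested against your larger model), each pass below rebuilds a one-chain model on the worker from the previous end state and returns both the new samples and the new state. The helper names `run_pass` and `psrf_max`, the explicit per-chain RNG seeds, the short re-adaptation on resume, and the cap of 5 passes are all assumptions made for illustration; the 1.04 threshold and 'example1.bug' are taken from the question.

```r
library(rjags)
library(doParallel)

# Assumptions: 'example1.bug' from the question is in the working directory;
# the data are simulated the same way as in the question.
N <- 1000
x <- rnorm(N, 0, 5)
dat <- list(x = x, N = N)

nchains <- 4
cl <- makeCluster(nchains)
registerDoParallel(cl)

# Hypothetical helper: one pass of every chain on the workers. Each worker
# builds a one-chain model from the supplied inits, samples, and returns the
# draws together with that chain's end state (values plus RNG state).
run_pass <- function(inits, n_iter, n_adapt, dat, model_file = 'example1.bug') {
  foreach(i = seq_along(inits), .packages = 'rjags') %dopar% {
    m <- jags.model(model_file, data = dat, inits = inits[[i]],
                    n.chains = 1, n.adapt = n_adapt, quiet = TRUE)
    s <- coda.samples(m, c('mu', 'tau'), n_iter)
    list(samples = as.matrix(s[[1]]),
         state = m$state(internal = TRUE)[[1]])
  }
}

# Hypothetical helper: largest point estimate of the Gelman-Rubin statistic.
psrf_max <- function(draws) {
  chains <- mcmc.list(lapply(draws, mcmc))
  max(gelman.diag(chains, multivariate = FALSE)$psrf[, 'Point est.'])
}

# First pass: give each chain its own RNG seed so the chains differ.
state <- lapply(seq_len(nchains), function(i)
  list(.RNG.name = 'base::Mersenne-Twister', .RNG.seed = i))
res <- run_pass(state, n_iter = 100, n_adapt = 1000, dat = dat)
draws <- lapply(res, `[[`, 'samples')
state <- lapply(res, `[[`, 'state')

# Later passes restart from the saved end states. Sampler tuning is not part
# of the saved state, so a short re-adaptation is run (a judgment call).
passes <- 1
while (psrf_max(draws) > 1.04 && passes < 5) {
  res <- run_pass(state, n_iter = 100, n_adapt = 100, dat = dat)
  draws <- Map(rbind, draws, lapply(res, `[[`, 'samples'))
  state <- lapply(res, `[[`, 'state')
  passes <- passes + 1
}
stopCluster(cl)

samples <- mcmc.list(lapply(draws, mcmc))
print(gelman.diag(samples))
```

The draws are kept as plain matrices and stacked with `rbind` because each rebuilt model restarts its iteration counter, so the per-pass `mcmc` objects would not line up if combined directly.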

I would actually recommend looking at the runjags package that does all this for you. For example:

library('runjags')
parsamples <- run.jags('example1.bug', data=list('x'=x,'N'=N), monitor=c('mu','tau'), sample=100, method='rjparallel')
summary(parsamples)
newparsamples <- extend.jags(parsamples, sample=100)
summary(newparsamples)
# etc

Or even:

parsamples <- autorun.jags('example1.bug', data=list('x'=x,'N'=N), monitor=c('mu','tau'), method='rjparallel')
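If you would rather apply your own stopping rule than leave it to autorun.jags, something along these lines may work. It is only a sketch: it assumes runjags version 2, where (as far as I know) `as.mcmc.list()` accepts a runjags object and `extend.jags()` appends the new samples to the existing ones by default; check `?extend.jags` for your installed version. The helper `psrf_max`, `n.chains = 4`, and the cap of 5 passes are illustrative choices, and the 1.04 threshold comes from the question.

```r
library(runjags)
library(coda)

# Assumptions: 'example1.bug' from the question is in the working directory;
# the data are simulated the same way as in the question.
N <- 1000
x <- rnorm(N, 0, 5)

res <- run.jags('example1.bug', data = list(x = x, N = N),
                monitor = c('mu', 'tau'), n.chains = 4,
                sample = 100, method = 'rjparallel')

# Hypothetical helper: largest point estimate of the Gelman-Rubin statistic.
psrf_max <- function(fit) {
  chains <- as.mcmc.list(fit)
  max(gelman.diag(chains, multivariate = FALSE)$psrf[, 'Point est.'])
}

# Keep extending from where the chains stopped until the diagnostic is
# below the threshold or the pass limit is reached.
passes <- 1
while (psrf_max(res) > 1.04 && passes < 5) {
  res <- extend.jags(res, sample = 100)
  passes <- passes + 1
}
summary(res)
```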

Version 2 of runjags will hopefully be uploaded to CRAN soon, but for now you can download binaries from: https://sourceforge.net/projects/runjags/files/runjags/

Matt

Condominium answered 7/4, 2015 at 5:17 Comment(5)
I couldn't quite get it to work passing init=endstate[[i]] on subsequent passes (I wound up seeing identical traces each time), but my simple example did work very well when I ran with runjags. It also appears that it will fit quite painlessly into my much larger and more complicated model. Is it acceptable to trust runjags v1.2.1 for now? - Superphosphate
I only remember fixing one bug (which isn't relevant to your model) relating to the rjparallel method with version 2, so it should be fine. The main advantage to upgrading is improved plotting and summary facilities which are additional/improved features rather than bug fixes. - Condominium
Getting an error on extend.jags after successfully running a model with run.jags: Error: unused argument(s) 'samples' (no unambiguous match in the 'extend.jags' or 'add.summary' functions) - Tobietobin
Are you sure it's not something simple like 'samples'->'sample' ? - Superphosphate
Yes the argument should be 'sample' - I fixed the typo in the answer and left a comment several days ago, but it seems the comment didn't post for some reason - sorry! - Condominium
