I'm modifying an existing model using rjags. I'd like to run chains in parallel and periodically check the Gelman-Rubin convergence diagnostic to see whether I need to keep running. The problem is that when I resume running based on the diagnostic value, the recompiled chains restart from their original initial values instead of from the position in parameter space where each chain stopped. If I do not recompile the model, rjags complains. Is there a way to store the positions of the chains when they stop so I can re-initialize from where I left off? Here is a very simplified example.
example1.bug:
```
model {
  for (i in 1:N) {
    x[i] ~ dnorm(mu,tau)
  }
  mu ~ dnorm(0,0.0001)
  tau <- pow(sigma,-2)
  sigma ~ dunif(0,100)
}
```
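Within a single R process, an rjags model object does keep its place: successive `update()` and `coda.samples()` calls continue from where the previous call stopped, and the model's `$state(internal = TRUE)` member returns each chain's current values together with its RNG state, in a form that `jags.model()` accepts back as `inits`. A minimal serial sketch of that round trip, assuming JAGS and rjags are installed; reading the model from a string via `textConnection()` and using two chains are illustrative choices:

```r
library(rjags)  # assumes the JAGS library itself is installed

# Same model as example1.bug, supplied as a string
model_string <- "
model {
  for (i in 1:N) {
    x[i] ~ dnorm(mu, tau)
  }
  mu ~ dnorm(0, 0.0001)
  tau <- pow(sigma, -2)
  sigma ~ dunif(0, 100)
}"

set.seed(1)
N <- 1000
x <- rnorm(N, 0, 5)

m <- jags.model(textConnection(model_string),
                data = list(x = x, N = N),
                n.chains = 2, quiet = TRUE)
update(m, 100)                              # burn-in
s1 <- coda.samples(m, c("mu", "tau"), 100)
s2 <- coda.samples(m, c("mu", "tau"), 100)  # continues where s1 stopped

# Current values of the unobserved stochastic nodes (mu, sigma) in each
# chain; internal = TRUE also includes .RNG.name and .RNG.state
saved <- m$state(internal = TRUE)

# A fresh model started from the saved state; n.adapt = 0 skips the
# adaptation phase so the chains are not moved before sampling resumes
m2 <- jags.model(textConnection(model_string),
                 data = list(x = x, N = N),
                 inits = saved, n.chains = 2,
                 n.adapt = 0, quiet = TRUE)
s3 <- coda.samples(m2, c("mu", "tau"), 100)
```

Restarting with `n.adapt = 0` is a judgment call: it avoids moving the chains during a second adaptation phase, but the samplers start again from their default tuning, and rjags may print a note that adaptation was stopped.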
parallel_test.R:
```r
#Make some fake data
N <- 1000
x <- rnorm(N,0,5)
write.table(x,
            file='example1.data',
            row.names=FALSE,
            col.names=FALSE)

library('rjags')
library('doParallel')
library('random')

nchains <- 4
c1 <- makeCluster(nchains)
registerDoParallel(c1)

jags <- list()
for (i in 1:getDoParWorkers()){
  jags[[i]] <- jags.model('example1.bug',
                          data=list('x'=x,'N'=N))
}

# Function to combine multiple mcmc lists into a single one
mcmc.combine <- function( ... ){
  return( as.mcmc.list( sapply( list( ... ),mcmc ) ) )
}

#Start with some burn-in
jags.parsamples <- foreach( i=1:getDoParWorkers(),
                            .inorder=FALSE,
                            .packages=c('rjags','random'),
                            .combine='mcmc.combine',
                            .multicombine=TRUE) %dopar%
{
  jags[[i]]$recompile()
  update(jags[[i]],100)
  jags.samples <- coda.samples(jags[[i]],c('mu','tau'),100)
  return(jags.samples)
}

#Check the diagnostic output
print(gelman.diag(jags.parsamples[,'mu']))

counter <- 0
#my model doesn't converge so quickly, so let's simulate doing
#this updating 5 times:
#while(gelman.diag(jags.parsamples[,'mu'])[[1]][[2]] > 1.04)
while(counter < 5)
{
  counter <- counter + 1
  jags.parsamples <- foreach(i=1:getDoParWorkers(),
                             .inorder=FALSE,
                             .packages=c('rjags','random'),
                             .combine='mcmc.combine',
                             .multicombine=TRUE) %dopar%
  {
    #Here I lose the progress I've made
    jags[[i]]$recompile()
    jags.samples <- coda.samples(jags[[i]],c('mu','tau'),100)
    return(jags.samples)
  }
}

print(gelman.diag(jags.parsamples[,'mu']))
print(summary(jags.parsamples))
stopCluster(c1)
```
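In the script above, each `%dopar%` worker receives a serialized copy of `jags[[i]]`, so whatever that copy learns while sampling stays on the worker; only the returned samples come back. A minimal sketch of one workaround, assuming JAGS, rjags, foreach and doParallel are installed: each worker builds a one-chain model, samples, and returns its draws together with the chain's final `$state(internal = TRUE)`, which the next round passes back as `inits`. The helper `run_chunk()`, the explicit `.RNG.seed` values and the dispersed starting values are hypothetical choices, not part of the original script:

```r
library(rjags)
library(doParallel)   # also attaches foreach and parallel

model_string <- "
model {
  for (i in 1:N) {
    x[i] ~ dnorm(mu, tau)
  }
  mu ~ dnorm(0, 0.0001)
  tau <- pow(sigma, -2)
  sigma ~ dunif(0, 100)
}"

set.seed(1)
N <- 1000
dat <- list(x = rnorm(N, 0, 5), N = N)
nchains <- 4

# Hypothetical helper: run one chain for n_iter iterations starting from
# `inits`, then return the draws and the chain's final state
run_chunk <- function(model_string, dat, inits, n_iter, n_adapt) {
  m <- rjags::jags.model(textConnection(model_string), data = dat,
                         inits = inits, n.chains = 1,
                         n.adapt = n_adapt, quiet = TRUE)
  s <- rjags::coda.samples(m, c("mu", "tau"), n_iter)
  list(draws = as.matrix(s[[1]]),
       state = m$state(internal = TRUE)[[1]])
}

cl <- makeCluster(nchains)
registerDoParallel(cl)

# Round 1: a distinct RNG seed and starting point for each chain
inits <- lapply(seq_len(nchains), function(i)
  list(.RNG.name = "base::Mersenne-Twister", .RNG.seed = 100 * i,
       mu = rnorm(1, 0, 10), sigma = runif(1, 1, 50)))
res <- foreach(i = seq_len(nchains), .packages = "rjags") %dopar%
  run_chunk(model_string, dat, inits[[i]], n_iter = 200, n_adapt = 500)

# Round 2: resume each chain from the state returned by round 1
# (n_adapt = 0, so no unrecorded iterations run before sampling)
states <- lapply(res, `[[`, "state")
res2 <- foreach(i = seq_len(nchains), .packages = "rjags") %dopar%
  run_chunk(model_string, dat, states[[i]], n_iter = 200, n_adapt = 0)

stopCluster(cl)

# Stitch the two rounds together chain by chain
chains <- lapply(seq_len(nchains), function(i)
  coda::mcmc(rbind(res[[i]]$draws, res2[[i]]$draws)))
samples <- do.call(coda::mcmc.list, chains)
print(coda::gelman.diag(samples[, "mu"]))
```

The sketch keeps foreach's default `.inorder = TRUE`, so `states[[i]]` is always fed back to the chain that produced it. Explicit per-chain seeds are there because separate worker processes that each build a one-chain model may otherwise end up with the same default JAGS seed.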
In the output, I see:

```
Iterations = 1001:2000
```

whereas I know there should be more than 5000 iterations. (Cross-posted to stats.stackexchange.com, which may be the more appropriate venue.)
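Independently of the restart problem, the loop assigns each round's result to `jags.parsamples`, replacing the previous round, so the final `summary()` can only describe the last chunk of draws. Keeping all rounds means appending each new chunk to the existing draws chain by chain. A minimal sketch using only coda, where `append_chunks()` is a hypothetical helper and the simulated chunks stand in for successive `coda.samples()` results:

```r
library(coda)

# Hypothetical helper: append the chains in `new` to the matching chains in
# `old`, keeping the starting iteration and thinning interval of `old`
append_chunks <- function(old, new) {
  stopifnot(nchain(old) == nchain(new))
  chains <- lapply(seq_len(nchain(old)), function(k) {
    mcmc(rbind(as.matrix(old[[k]]), as.matrix(new[[k]])),
         start = start(old[[k]]), thin = thin(old[[k]]))
  })
  do.call(mcmc.list, chains)
}

# Simulated 100-iteration chunks of 2 chains each
set.seed(1)
fake_chunk <- function(first_iter) {
  do.call(mcmc.list, lapply(1:2, function(k)
    mcmc(matrix(rnorm(200), ncol = 2, dimnames = list(NULL, c("mu", "tau"))),
         start = first_iter)))
}

all_draws <- fake_chunk(1001)
all_draws <- append_chunks(all_draws, fake_chunk(1101))
summary(all_draws)   # reports Iterations = 1001:1200
```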
`R2jags::jags.parallel` works great for the parallelization part. Are the other parts of this (checking for convergence, and picking up where the chains left off if it has not been reached) still the same, or are there new tools to do this? – Ezmeralda
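For the convergence check, `gelman.diag(x)[[1]][[2]]` in the commented-out `while` condition reads the upper confidence limit of the potential scale reduction factor. Indexing the `psrf` matrix by column name says the same thing explicitly and extends to several parameters. A minimal coda-only sketch, where `is_converged()` is a hypothetical helper, the 1.04 target comes from the question, and simulated chains stand in for JAGS output:

```r
library(coda)

# Hypothetical helper: TRUE when the upper confidence limit of the
# Gelman-Rubin PSRF is at or below `target` for every listed parameter
is_converged <- function(samples, params, target = 1.04) {
  ucl <- sapply(params, function(p)
    gelman.diag(samples[, p])$psrf[, "Upper C.I."])
  all(ucl <= target)
}

set.seed(1)
make_chains <- function(means) {
  do.call(mcmc.list, lapply(means, function(m)
    mcmc(matrix(rnorm(1000, mean = m), ncol = 1,
                dimnames = list(NULL, "mu")))))
}
good <- make_chains(c(0, 0, 0, 0))   # four chains sampling the same target
bad  <- make_chains(c(0, 0, 0, 5))   # one chain stuck somewhere else

is_converged(good, "mu")   # TRUE
is_converged(bad, "mu")    # FALSE
```

In a resume loop, the helper would take the place of the commented-out condition, for example `while (!is_converged(samples, "mu"))`, with `samples` holding all rounds drawn so far. Note that `gelman.diag()` uses only the second half of each chain by default (`autoburnin = TRUE`).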