Representing Parametric Survival Model in 'Counting Process' form in JAGS

I'm trying to build a survival model in JAGS that allows for time-varying covariates. I'd like it to be a parametric model — for example, assuming survival follows the Weibull distribution (but I'd like to allow the hazard to vary, so exponential is too simple). So, this is essentially a Bayesian version of what can be done in the flexsurv package, which allows for time-varying covariates in parametric models.

Therefore, I want to be able to enter the data in 'counting-process' form, where each subject has multiple rows, each corresponding to a time interval in which their covariates remained constant (as described in this pdf or here). This is the (start, stop] formulation that the survival and flexsurv packages allow.
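
For concreteness, here's a toy illustration of that layout (my own sketch, not from either package's documentation): one row per (start, stop] interval, with an event flag that is 1 only on the interval in which the event actually occurs.

library('survival')

## Hypothetical example: person 1 changes a covariate at t = 300 and dies at
## t = 487; person 2 is censored at t = 150 with no covariate change.
toy <- data.frame(person = c(1, 1, 2),
                  START  = c(0, 300, 0),
                  END    = c(300, 487, 150),
                  EVENT  = c(0, 1, 0),   # 1 only on the interval where the event happens
                  x      = c(0, 1, 0))   # the time-varying covariate
Surv(toy$START, toy$END, toy$EVENT)      # counting-process Surv object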

Unfortunately, every explanation of how to perform survival analysis in JAGS seems to assume one row per subject.

I attempted to take this simpler approach and extend it to the counting process format, but the model does not correctly estimate the distribution.

A Failed Attempt:

Here's an example. First we generate some data:

library('dplyr')
library('survival')

## Make the Data: -----
set.seed(3)
n_sub <- 1000
current_date <- 365*2

true_shape <- 2
true_scale <- 365

dat <- data_frame(person = 1:n_sub,
                  true_duration = rweibull(n = n_sub, shape = true_shape, scale = true_scale),
                  person_start_time = runif(n_sub, min= 0, max= true_scale*2),
                  person_censored = (person_start_time + true_duration) > current_date,
                  person_duration = ifelse(person_censored, current_date - person_start_time, true_duration)
)

  person person_start_time person_censored person_duration
   (int)             (dbl)           (lgl)           (dbl)
1      1          11.81416           FALSE        487.4553
2      2         114.20900           FALSE        168.7674
3      3          75.34220           FALSE        356.6298
4      4         339.98225           FALSE        385.5119
5      5         389.23357           FALSE        259.9791
6      6         253.71067           FALSE        259.0032
7      7         419.52305            TRUE        310.4770

Then we split the data into 2 observations per subject. I'm just splitting each subject at time = 300 (unless they didn't make it to time = 300, in which case they get just one observation).

## Split into multiple observations per person: --------
cens_point <- 300 # <----- try changing to 0 for no split; if so, model correctly estimates
dat_split <- dat %>%
  group_by(person) %>%
  do(data_frame(
    split = ifelse(.$person_duration > cens_point, cens_point, .$person_duration),
    START = c(0, split[1]),
    END = c(split[1], .$person_duration),
    TINTERVAL = c(split[1], .$person_duration - split[1]), 
    CENS = c(ifelse(.$person_duration > cens_point, 1, .$person_censored), .$person_censored), # <— edited original post here due to bug; but problem still present when fixing bug
    TINTERVAL_CENS = ifelse(CENS, NA, TINTERVAL),
    END_CENS = ifelse(CENS, NA, END)
  )) %>%
  filter(TINTERVAL != 0)

  person    split START      END TINTERVAL CENS TINTERVAL_CENS
   (int)    (dbl) (dbl)    (dbl)     (dbl) (dbl)        (dbl)
1      1 300.0000     0 300.0000 300.00000     1           NA
2      1 300.0000   300 487.4553 187.45530     0    187.45530
3      2 168.7674     0 168.7674 168.76738     1           NA
4      3 300.0000     0 300.0000 300.00000     1           NA
5      3 300.0000   300 356.6298  56.62979     0     56.62979
6      4 300.0000     0 300.0000 300.00000     1           NA

Now we can set up the JAGS model.

## Set-Up JAGS Model -------
dat_jags <- as.list(dat_split)
dat_jags$N <- length(dat_jags$TINTERVAL)
# JAGS needs initial values for the censored (NA) outcome nodes that are
# consistent with dinterval(), i.e. above the censoring limit:
inits <- replicate(n = 2, simplify = FALSE, expr = {
       list(TINTERVAL_CENS = with(dat_jags, ifelse(CENS, TINTERVAL + 1, NA)),
            END_CENS = with(dat_jags, ifelse(CENS, END + 1, NA)) )
})

model_string <- 
"
  model {
    # set priors on reparameterized version, as suggested
    # here: https://sourceforge.net/p/mcmc-jags/discussion/610036/thread/d5249e71/?limit=25#8c3b
    log_a ~ dnorm(0, .001) 
    log(a) <- log_a
    log_b ~ dnorm(0, .001)
    log(b) <- log_b
    # JAGS's dweib(nu, lambda) has survival S(t) = exp(-lambda * t^nu), so
    # nu = shape and lambda = scale^(-shape) match R's rweibull(shape = a, scale = b)
    nu <- a
    lambda <- (1/b)^a
    
    for (i in 1:N) {
      # Estimate Subject-Durations:
      CENS[i] ~ dinterval(TINTERVAL_CENS[i], TINTERVAL[i])
      TINTERVAL_CENS[i] ~ dweibull( nu, lambda )
    }
  }
"

library('runjags')
param_monitors <- c('a', 'b', 'nu', 'lambda') 
fit_jags <- run.jags(model = model_string,
                     burnin = 1000, sample = 1000, 
                     monitor = param_monitors,
                     n.chains = 2, data = dat_jags, inits = inits)
# estimates:
fit_jags
# actual:
c(a=true_shape, b=true_scale)

Depending on where the split point is, the model estimates very different parameters for the underlying distribution. It only gets the parameters right if the data isn't split into the counting process form. It seems like this is not the way to format the data for this kind of problem.
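
As a quick sanity check (my addition, using survreg from the survival package already loaded above), the unsplit one-row-per-person data does recover the generating values, which suggests the problem really is in how the split rows are being modeled:

## Frequentist check on the unsplit data. survreg() codes the event as
## 1 = observed, i.e. the opposite of person_censored.
fit_freq <- survreg(Surv(person_duration, !person_censored) ~ 1,
                    data = dat, dist = 'weibull')
1 / fit_freq$scale    # Weibull shape; should be close to true_shape = 2
exp(coef(fit_freq))   # Weibull scale; should be close to true_scale = 365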

If I am missing an assumption and my problem is less related to JAGS and more related to how I'm formulating the problem, suggestions are very welcome. I was starting to despair that time-varying covariates can't be used in parametric survival models at all (and only in semi-parametric models like the Cox model, which assumes proportional hazards and doesn't actually estimate the underlying distribution); however, as I mentioned above, the flexsurvreg function in the flexsurv package does accommodate the (start, stop] formulation in parametric models.
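
For reference, this is roughly how the split data would be handed to flexsurv (a sketch, under the assumption that flexsurvreg treats the start time in Surv(start, stop, event) as a left-truncation time, which is how its counting-process support is documented; note it wants 1 = event, so CENS has to be flipped):

library('flexsurv')

## Sketch: fit the split (start, stop] data as a left-truncated Weibull.
dat_split$EVENT <- as.integer(!dat_split$CENS)   # 1 = event observed at END
fit_flex <- flexsurvreg(Surv(START, END, EVENT) ~ 1,
                        data = dat_split, dist = 'weibull')
fit_flex   # shape and scale should be near 2 and 365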

If anyone knows how to build a model like this in another language (e.g. STAN instead of JAGS), that would be appreciated too.

Edit:

Chris Jackson provides some helpful advice via email:

I think the T() construct for truncation in JAGS is needed here. Essentially for each period (t[i], t[i+1]) where a person is alive but the covariate is constant, the survival time is left-truncated at the start of the period, and possibly also right-censored at the end. So you'd write something like y[i] ~ dweib(shape, scale[i])T(t[i], )
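
In other words (my reading of the suggestion, not a quote from the email): each row's likelihood contribution is conditional on the person having survived to the start of that row. For someone whose covariate changes at t1 and who then dies at t2, the first row contributes the right-censored term S(t1) and the second row the left-truncated term f(t2)/S(t1), and these multiply back to the ordinary single-row contribution, S(t1) * f(t2)/S(t1) = f(t2). So the splitting itself shouldn't bias the fit, provided the truncation is handled.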

I tried implementing this suggestion as follows:

model {
  # same as before
  log_a ~ dnorm(0, .01) 
  log(a) <- log_a
  log_b ~ dnorm(0, .01)
  log(b) <- log_b
  nu <- a
  lambda <- (1/b)^a
  
  for (i in 1:N) {
    # modified to include left-truncation
    CENS[i] ~ dinterval(END_CENS[i], END[i])
    END_CENS[i] ~ dweibull( nu, lambda )T(START[i],)
  }
}

Unfortunately this doesn't quite do the trick. With the old code, the model was mostly getting the scale parameter right, but doing a very bad job on the shape parameter. With this new code, it gets very close to the correct shape parameter, but consistently over-estimates the scale parameter. I have noticed that the degree of over-estimation is correlated with how late the split point comes. If the split-point is early (cens_point = 50), there's not really any over-estimation; if it's late (cens_point = 350), there is a lot.

I thought maybe the problem could be related to 'double-counting' the observations: if we see a censored observation at t=300 and then, from that same person, an uncensored observation at t=400, it seems intuitive to me that this person is contributing two data points to our inference about the Weibull parameters when really they should just be contributing one. I therefore tried incorporating a random effect for each person; however, this completely failed, with huge estimates (in the 50-90 range) for the nu parameter. I'm not sure why that is, but perhaps that's a question for a separate post. Since I'm not sure whether the problems are related, you can find the code for this whole post, including the JAGS code for that model, here.

Voight answered 20/4, 2016 at 4:8 Comment(9)
I'm wondering if this could be considered on-topic for CrossValidated.com. They have a rule that single-language questions should not be posed there, but perhaps this could be framed as how one could properly build a Bayesian algorithm? – Mighty
The sensitivity to the cut-point is to be expected. You are probably dealing with non-proportional hazards. Even with a semiparametric model like coxph, the hazard ratios will be sensitive to the choice of cut point under non-proportional hazards. As @Mighty suggested, maybe this could be a good question for CrossValidated. – Halidom
@eno gerguri I found this paper that references the counting process in JAGS, with some code in the appendix: ncbi.nlm.nih.gov/pmc/articles/PMC3998726. Is this what you are looking for? – Tomi
@teunbrand I think it's quite a leap from "this question is not about programming as defined by the help center" to "no questions about R are about programming as defined by the help center". – Syphilology
@teunbrand Sure, see the bullet-point list at the top of stackoverflow.com/help/on-topic. The question here doesn't seem to be about a specific programming aspect of R or the JAGS library, but rather about fitting models to expected output (e.g. "I'm not getting the results I want"). Maybe it's just unclear to me what it's about (I know some R but not any JAGS/Bayesian modeling), but it seems like too broad an application of the "there's code so it's a programming problem" mindset to me. – Syphilology
@teunbrand I suspect it could get a great in-depth answer on one of the statistics or data science stacks, given its clear popularity, but here it doesn't seem answerable in a way that's in scope for the site, given that it... hasn't been answered yet, despite massive popularity. Regardless, you're not required to vote to close if you don't agree, naturally. – Syphilology
@Syphilology Yes, I saw the on-topic bullets, but I guess it depends on the POV then, because one could perceive this question as enquiring about the effective use of modelling software algorithms, which would put it in bullet #2 of the on-topic list. Anyway, I'm sorry my comments have made the discussion off-topic, so I'll be deleting my previous ones after a few days to not pollute more relevant discussion. For posterity, I was disagreeing with the vote to close. – Nepil
@jwdink, I had to fit a similar model in JAGS recently. Give me a few days to type up a response. – Foin
The Aalen model supposes that the cumulative hazard H(t) for a subject can be expressed as a(t) + X B(t), where a(t) is a time-dependent intercept term, X is the vector of covariates for the subject (possibly time-dependent), and B(t) is a time-dependent matrix of coefficients. The estimates are inherently non-parametric; a fit of the model will normally be followed by one or more plots of the estimates. Extracted from: Aalen, O.O. (1989). A linear regression model for the analysis of life times. Statistics in Medicine. – Anthrax

You can use the rstanarm package, which is a wrapper around STAN. It allows you to use standard R formula notation to describe survival models. The stan_surv function accepts data in "counting process" form, and different baseline hazard functions, including Weibull, can be used to fit the model.

The survival part of rstanarm (the stan_surv function) is still not available on CRAN, so you should install the package directly from mc-stan.org:

install.packages("rstanarm", repos = c("https://mc-stan.org/r-packages/", getOption("repos")))

Please see the code below:

library(dplyr)
library(survival)
library(rstanarm)

## Make the Data: -----
set.seed(3)
n_sub <- 1000
current_date <- 365*2

true_shape <- 2
true_scale <- 365

dat <- data_frame(person = 1:n_sub,
                  true_duration = rweibull(n = n_sub, shape = true_shape, scale = true_scale),
                  person_start_time = runif(n_sub, min= 0, max= true_scale*2),
                  person_censored = (person_start_time + true_duration) > current_date,
                  person_duration = ifelse(person_censored, current_date - person_start_time, true_duration)
)

## Split into multiple observations per person: --------
cens_point <- 300 # <----- try changing to 0 for no split; if so, model correctly estimates
dat_split <- dat %>%
  group_by(person) %>%
  do(data_frame(
    split = ifelse(.$person_duration > cens_point, cens_point, .$person_duration),
    START = c(0, split[1]),
    END = c(split[1], .$person_duration),
    TINTERVAL = c(split[1], .$person_duration - split[1]), 
    CENS = c(ifelse(.$person_duration > cens_point, 1, .$person_censored), .$person_censored), # <— edited original post here due to bug; but problem still present when fixing bug
    TINTERVAL_CENS = ifelse(CENS, NA, TINTERVAL),
    END_CENS = ifelse(CENS, NA, END)
  )) %>%
  filter(TINTERVAL != 0)
# Surv() expects an event indicator (1 = event observed), whereas CENS above
# is a censoring indicator (1 = censored), so flip it:
dat_split$CENS <- as.integer(!(dat_split$CENS))


# Fit STAN survival model
mod_tvc <- stan_surv(
  formula = Surv(START, END, CENS) ~ 1,
  data = dat_split,
  iter = 1000,
  chains = 2,
  basehaz = "weibull-aft")

# Print fit coefficients
mod_tvc$coefficients[2]
unname(exp(mod_tvc$coefficients[1]))

Output, which is consistent with true values (true_shape <- 2; true_scale <- 365):

> mod_tvc$coefficients[2]
weibull-shape 
     1.943157 
> unname(exp(mod_tvc$coefficients[1]))
[1] 360.6058

You can also look at the generated STAN source using rstan::get_stanmodel(mod_tvc$stanfit) to compare the STAN code with the attempts you made in JAGS.

Polash answered 27/12, 2021 at 9:46 Comment(0)
