I'm trying to build a survival model in JAGS that allows for time-varying covariates. I'd like it to be a parametric model: for example, one assuming survival follows the Weibull distribution (but I'd like to allow the hazard to vary over time, so the exponential is too simple). So this is essentially a Bayesian version of what can be done in the flexsurv package, which allows for time-varying covariates in parametric models.
Therefore, I want to be able to enter the data in 'counting-process' form, where each subject has multiple rows, each corresponding to a time interval in which their covariates remained constant (as described in this pdf or here). This is the (start, stop] formulation that the survival and flexsurv packages allow.
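To make that concrete, here is a minimal sketch (my addition, with made-up numbers) of the (start, stop] form as survival::Surv encodes it; flexsurvreg accepts the same three-argument form, treating the start time as left-truncation:

library('survival')
## One subject whose covariate x changes at t = 300, with the event at t = 487.
## Each row covers an interval (start, stop] over which x is constant;
## only the final interval ends in the event.
cp <- data.frame(start = c(0, 300),
                 stop  = c(300, 487),
                 event = c(0, 1),
                 x     = c(0, 1))
Surv(cp$start, cp$stop, cp$event)  # a counting-process Surv object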
Unfortunately, every explanation of how to perform survival analysis in JAGS seems to assume one row per subject.
I attempted to take this simpler approach and extend it to the counting-process format, but the model does not correctly estimate the distribution.
A Failed Attempt:
Here's an example. First we generate some data:
library('dplyr')
library('survival')

## Make the Data: -----
set.seed(3)
n_sub <- 1000
current_date <- 365*2
true_shape <- 2
true_scale <- 365

## Subjects whose event falls after current_date are right-censored there:
dat <- data_frame(person = 1:n_sub,
                  true_duration = rweibull(n = n_sub, shape = true_shape, scale = true_scale),
                  person_start_time = runif(n_sub, min = 0, max = true_scale*2),
                  person_censored = (person_start_time + true_duration) > current_date,
                  person_duration = ifelse(person_censored, current_date - person_start_time, true_duration))
      person person_start_time person_censored person_duration
       (int)             (dbl)           (lgl)           (dbl)
    1      1          11.81416           FALSE        487.4553
    2      2         114.20900           FALSE        168.7674
    3      3          75.34220           FALSE        356.6298
    4      4         339.98225           FALSE        385.5119
    5      5         389.23357           FALSE        259.9791
    6      6         253.71067           FALSE        259.0032
    7      7         419.52305            TRUE        310.4770
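As a sanity check on the simulated data (my addition, not in the original post), a one-row-per-subject frequentist fit recovers the true parameters. survreg parameterizes the Weibull on the log scale, so shape = 1/fit$scale and scale = exp(intercept):

## Frequentist Weibull fit on the unsplit data:
fit_sr <- survreg(Surv(person_duration, !person_censored) ~ 1,
                  data = dat, dist = 'weibull')
c(shape = 1 / fit_sr$scale,    # should be near true_shape = 2
  scale = exp(coef(fit_sr)))   # should be near true_scale = 365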
Then we split the data into two observations per subject. I'm just splitting each subject at time = 300 (unless they didn't make it to time = 300, in which case they get just one observation).
## Split into multiple observations per person: --------
cens_point <- 300 # <----- try changing to 0 for no split; if so, model correctly estimates
dat_split <- dat %>%
  group_by(person) %>%
  do(data_frame(
    split = ifelse(.$person_duration > cens_point, cens_point, .$person_duration),
    START = c(0, split[1]),
    END = c(split[1], .$person_duration),
    TINTERVAL = c(split[1], .$person_duration - split[1]),
    CENS = c(ifelse(.$person_duration > cens_point, 1, .$person_censored),
             .$person_censored), # <-- edited original post here due to bug; problem still present when fixing bug
    TINTERVAL_CENS = ifelse(CENS, NA, TINTERVAL),
    END_CENS = ifelse(CENS, NA, END)
  )) %>%
  filter(TINTERVAL != 0)
      person    split START      END TINTERVAL  CENS TINTERVAL_CENS
       (int)    (dbl) (dbl)    (dbl)     (dbl) (dbl)          (dbl)
    1      1 300.0000     0 300.0000 300.00000     1             NA
    2      1 300.0000   300 487.4553 187.45530     0      187.45530
    3      2 168.7674     0 168.7674 168.76738     1             NA
    4      3 300.0000     0 300.0000 300.00000     1             NA
    5      3 300.0000   300 356.6298  56.62979     0       56.62979
    6      4 300.0000     0 300.0000 300.00000     1             NA
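One quick check worth running here (my addition): within each person, the split intervals should partition the full duration, so TINTERVAL should sum back to person_duration:

## Verify the intervals reconstruct each person's total duration:
dat_split %>%
  group_by(person) %>%
  summarise(total = sum(TINTERVAL)) %>%
  left_join(dat, by = 'person') %>%
  with(all.equal(total, person_duration))  # should be TRUE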
Now we can set up the JAGS model.
## Set-Up JAGS Model -------
dat_jags <- as.list(dat_split)
dat_jags$N <- length(dat_jags$TINTERVAL)

## dinterval() requires the latent censored times to be initialized
## above their censoring bounds, or JAGS errors at the first iteration:
inits <- replicate(n = 2, simplify = FALSE, expr = {
  list(TINTERVAL_CENS = with(dat_jags, ifelse(CENS, TINTERVAL + 1, NA)),
       END_CENS = with(dat_jags, ifelse(CENS, END + 1, NA)))
})
model_string <-
"
model {
  # Set priors on a reparameterized (log) scale, as suggested here:
  # https://sourceforge.net/p/mcmc-jags/discussion/610036/thread/d5249e71/?limit=25#8c3b
  log_a ~ dnorm(0, .001)
  log(a) <- log_a
  log_b ~ dnorm(0, .001)
  log(b) <- log_b

  # Convert to JAGS's dweibull(nu, lambda) parameterization:
  nu <- a
  lambda <- (1/b)^a

  for (i in 1:N) {
    # Estimate subject-durations, treating censored intervals as latent:
    CENS[i] ~ dinterval(TINTERVAL_CENS[i], TINTERVAL[i])
    TINTERVAL_CENS[i] ~ dweibull( nu, lambda )
  }
}
"
library('runjags')

param_monitors <- c('a', 'b', 'nu', 'lambda')
fit_jags <- run.jags(model = model_string,
                     burnin = 1000, sample = 1000,
                     monitor = param_monitors,
                     n.chains = 2, data = dat_jags, inits = inits)

# estimates:
fit_jags
# actual:
c(a = true_shape, b = true_scale)
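If you want the posterior summaries as numbers rather than printed output, summary() on the runjags object returns a matrix of quantiles and diagnostics, one row per monitored parameter (a hedged example, my addition; column names may differ across runjags versions):

## Extract posterior summaries for comparison with the true values:
summ <- summary(fit_jags)
summ[c('a', 'b'), c('Lower95', 'Median', 'Upper95')]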
Depending on where the split point is, the model estimates very different parameters for the underlying distribution. It only recovers the parameters correctly if the data aren't split into counting-process form, so this is evidently not the way to format the data for this kind of problem.
If I am missing an assumption and my problem is less related to JAGS and more related to how I'm formulating the problem, suggestions are very welcome. I might be despairing that time-varying covariates can't be used in parametric survival models (and can only be used in models like the Cox model, which assumes proportional hazards and doesn't actually estimate the underlying distribution). However, as I mentioned above, the flexsurvreg function in R does accommodate the (start, stop] formulation in parametric models.
If anyone knows how to build a model like this in another language (e.g. Stan instead of JAGS), that would be appreciated too.
Edit:
Chris Jackson provides some helpful advice via email:
I think the T() construct for truncation in JAGS is needed here. Essentially for each period (t[i], t[i+1]) where a person is alive but the covariate is constant, the survival time is left-truncated at the start of the period, and possibly also right-censored at the end. So you'd write something like
y[i] ~ dweib(shape, scale[i])T(t[i], )
I tried implementing this suggestion as follows:
model {
  # same priors as before
  log_a ~ dnorm(0, .01)
  log(a) <- log_a
  log_b ~ dnorm(0, .01)
  log(b) <- log_b

  nu <- a
  lambda <- (1/b)^a

  for (i in 1:N) {
    # modified to left-truncate each interval at its start time
    CENS[i] ~ dinterval(END_CENS[i], END[i])
    END_CENS[i] ~ dweibull( nu, lambda )T(START[i],)
  }
}
Unfortunately this doesn't quite do the trick. With the old code, the model was mostly getting the scale parameter right but doing a very bad job on the shape parameter. With this new code, it gets very close to the correct shape parameter but consistently over-estimates the scale parameter. I have noticed that the degree of over-estimation is correlated with how late the split point comes: if the split point is early (cens_point = 50), there's not really any over-estimation; if it's late (cens_point = 350), there is a lot.
I thought maybe the problem could be related to 'double-counting' the observations: if we see a censored observation at t = 300 and then, from that same person, an uncensored observation at t = 400, it seems intuitive that this person is contributing two data points to our inference about the Weibull parameters when really they should just be contributing one. I therefore tried incorporating a random effect for each person; however, this completely failed, with huge estimates (in the 50-90 range) for the nu parameter. I'm not sure why that is, but perhaps that's a question for a separate post. Since I'm not sure whether the problems are related, you can find the code for this whole post, including the JAGS code for that model, here.
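For what it's worth, the double-counting worry can be checked against the counting-process likelihood itself: an interval (start, stop] with event indicator d contributes h(stop)^d * S(stop) / S(start), and within a subject these factors telescope (e.g. S(300) * h(487) * S(487)/S(300) = h(487) * S(487)), which is exactly the single-row contribution. JAGS has no built-in distribution for this interval likelihood, but it can be coded directly with the standard 'zeros trick'. Below is a minimal sketch of that idea (my addition, untested against the data above); zeros must be supplied as a data vector of N zeros, and C is any constant large enough to keep the Poisson means positive:

model {
  log_a ~ dnorm(0, .001)
  log(a) <- log_a
  log_b ~ dnorm(0, .001)
  log(b) <- log_b
  nu <- a
  lambda <- (1/b)^a

  C <- 10000  # large constant for the zeros trick
  for (i in 1:N) {
    # log-likelihood of interval (START, END]; CENS[i] = 1 means censored:
    #   (1 - CENS) * log h(END) + log S(END) - log S(START)
    log_haz[i] <- log(nu * lambda) + (nu - 1) * log(END[i])
    log_surv[i] <- -lambda * pow(END[i], nu) + lambda * pow(START[i], nu)
    loglik[i] <- (1 - CENS[i]) * log_haz[i] + log_surv[i]
    zeros[i] ~ dpois(C - loglik[i])  # each zero contributes exp(loglik[i] - C)
  }
}

With this formulation the telescoping happens in the likelihood itself, so split and unsplit data should yield the same posterior (add dat_jags$zeros <- rep(0, dat_jags$N) before running).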
Comments:
"With coxph, the hazard-ratios will be sensitive to the choice of the cut point under non-proportional hazards. As @Mighty suggested, maybe this could be a good question for Cross Validated." – Halidom
"See Aalen, O.O. (1989). A linear regression model for the analysis of life times. Statistics in Medicine." – Anthrax