Matrix with simplex columns in stan

Is there a way to construct a matrix with simplex columns in Stan? The model I want to construct is similar to the following, where I model counts as Dirichlet-multinomial:

data {
  int g;
  int c;
  int<lower=0> counts[g, c];
}

parameters {
  simplex [g] p;
}

model {
  for (j in 1:c) {
    p ~ dirichlet(rep_vector(1.0, g));
    counts[, j] ~ multinomial(p);
  }
}

However, I would like to use a latent [g, c] matrix for further layers of a hierarchical model, similar to the following:

parameters {
  // simplex_matrix would have columns which are each a simplex.
  simplex_matrix[g, c] p;
}
model {
  for (j in 1:c) {
    p[, j] ~ dirichlet(rep_vector(1.0, g));
    counts[, j] ~ multinomial(p[, j]);
  }
}

If there's another way to construct this latent variable, that would of course also be great! I'm not massively familiar with Stan, having only implemented a few hierarchical models.

Zelazny answered 1/10, 2019 at 20:3

To answer the question you asked, you can declare an array of simplexes in the parameters block of a Stan program and use them to fill a matrix. For example:

parameters {
  simplex[g] p[c];
}
model {
  // local matrix whose i-th column is the simplex p[i]
  matrix[g, c] col_stochastic_matrix;
  for (i in 1:c) col_stochastic_matrix[, i] = p[i];
}
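If the matrix is needed by later layers of the model, or should appear in the sampler output, one option is to build it in a transformed parameters block rather than as a local variable in the model block (local model-block variables are not saved). The following is a minimal sketch, assuming the data block from the question and the array syntax of Stan 2.26 and later; the name `col_stochastic_matrix` is only illustrative.

```stan
data {
  int<lower=1> g;
  int<lower=1> c;
  array[g, c] int<lower=0> counts;
}
parameters {
  // one simplex per column of the count table
  array[c] simplex[g] p;
}
transformed parameters {
  // column j is the simplex p[j]; saved with the draws
  matrix[g, c] col_stochastic_matrix;
  for (j in 1:c) {
    col_stochastic_matrix[:, j] = p[j];
  }
}
model {
  for (j in 1:c) {
    p[j] ~ dirichlet(rep_vector(1.0, g));
    // later layers can use the matrix columns directly
    counts[, j] ~ multinomial(col_stochastic_matrix[:, j]);
  }
}
```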

However, you do not actually need to form a column stochastic matrix in the example you gave, since you can write the multinomial-Dirichlet model by indexing an array of simplexes:

data {
  int g;
  int c;
  int<lower=0> counts[g, c];
}
parameters {
  simplex [g] p[c];
}
model {
  for (j in 1:c) {
    p[j] ~ dirichlet(rep_vector(1.0, g));
    counts[, j] ~ multinomial(p[j]);
  }
}

Finally, you do not actually need to declare an array of simplexes at all, since they can be integrated out of the posterior distribution and recovered in the generated quantities block of the Stan program. See Wikipedia for details, but the essence of it is given by this Stan function:

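functions {
  // Dirichlet-multinomial log pmf with the simplex integrated out.
  // The multinomial coefficient is omitted because it does not depend on alpha.
  real DM_lpmf(int[] n, vector alpha) {
    int N = sum(n);
    real A = sum(alpha);
    return lgamma(A) - lgamma(N + A)
           + sum(lgamma(to_vector(n) + alpha))
           - sum(lgamma(alpha));
  }
}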
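Putting the pieces together, here is a sketch of a complete program in which each column's simplex is integrated out and then recovered in generated quantities. It assumes a hypothetical hierarchical layer, a shared concentration vector `alpha` with an `exponential(1)` prior; that prior is an illustrative choice, not part of the original answer (with `alpha` fixed at 1 as in the question, there would be no parameters left to sample). It uses the array syntax of Stan 2.26 and later, since the older `int counts[g, c]` form was removed in Stan 2.33. The recovery step relies on conjugacy: given the counts n_j and alpha, p_j is distributed as Dirichlet(alpha + n_j).

```stan
functions {
  // Dirichlet-multinomial log pmf, multinomial coefficient omitted
  real DM_lpmf(array[] int n, vector alpha) {
    int N = sum(n);
    real A = sum(alpha);
    return lgamma(A) - lgamma(N + A)
           + sum(lgamma(to_vector(n) + alpha))
           - sum(lgamma(alpha));
  }
}
data {
  int<lower=1> g;
  int<lower=1> c;
  array[g, c] int<lower=0> counts;
}
parameters {
  // hypothetical hierarchical layer: concentration shared by all columns
  vector<lower=0>[g] alpha;
}
model {
  // illustrative prior only; replace with your own upper layers
  alpha ~ exponential(1);
  for (j in 1:c) {
    counts[, j] ~ DM(alpha);
  }
}
generated quantities {
  // recover each column's simplex from its conjugate posterior
  matrix[g, c] p;
  for (j in 1:c) {
    p[:, j] = dirichlet_rng(alpha + to_vector(counts[, j]));
  }
}
```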
Innuendo answered 1/10, 2019 at 20:37
That is an absolutely amazing answer, thank you so much. I figured that integrating out the simplexes would be the ultimate answer, but given that I'm at the prototyping stage I may stick with simplexes for the time being.Zelazny
Just returned to this and realised I had misunderstood your final point. The model is identical and about 6 times faster. Thanks again!Zelazny
