Sampling from a joint distribution in Pyro

I understand how to sample from a multidimensional categorical or a multivariate normal distribution (with dependence within each column). For example, a multivariate categorical can be sampled as follows:

import pyro as p
import pyro.distributions as d
import torch as t
p.sample("obs1", d.Categorical(logits=logit_pobs1).independent(1), obs=t.t(obs1))

My question is: how can we do the same when there are multiple distributions? For example, the following is not what I want, because obs1, obs2 and obs3 are independent of each other.

p.sample("obs1", d.Categorical(logits=logit_pobs1).independent(1), obs=t.t(obs1))
p.sample("obs2", d.Normal(loc=mu_obs2, scale=t.ones(mu_obs2.shape)).independent(1), obs=t.t(obs2))
p.sample("obs3", d.Bernoulli(logits=logit_pobs3).independent(1), obs=obs3)
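
To make the independence concrete, here is a minimal sketch using plain torch.distributions rather than Pyro itself. The shapes and random parameters are made up for illustration, and td.Independent(..., 1) stands in for .independent(1) (which, as far as I know, newer Pyro releases spell .to_event(1)). With fixed parameters, the joint log-density of the three statements is just the sum of three separate terms, so no term can express dependence between obs1, obs2 and obs3.

```python
import torch
import torch.distributions as td

torch.manual_seed(0)

# Hypothetical shapes: 5 rows, 4 columns, 3 categories for obs1.
n, k, c = 5, 4, 3
logit_pobs1 = torch.randn(n, k, c)
mu_obs2 = torch.randn(n, k)
logit_pobs3 = torch.randn(n, k)

# Independent(..., 1) turns the last batch dimension into an event dimension,
# so log_prob sums over the k columns of each row.
d1 = td.Independent(td.Categorical(logits=logit_pobs1), 1)
d2 = td.Independent(td.Normal(mu_obs2, torch.ones_like(mu_obs2)), 1)
d3 = td.Independent(td.Bernoulli(logits=logit_pobs3), 1)

# Each variable is drawn without looking at the other two.
obs1, obs2, obs3 = d1.sample(), d2.sample(), d3.sample()

# The joint log-density factorizes into a sum of three unrelated terms.
joint_log_prob = d1.log_prob(obs1) + d2.log_prob(obs2) + d3.log_prob(obs3)
print(joint_log_prob.shape)
```

If the parameters themselves were functions of a shared latent variable, the three observations would be dependent after marginalizing that variable out, but they would still be conditionally independent given it.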

I would like to do something like

p.sample("obs", d.joint(d.Bernoulli(...), d.Normal(...), d.Bernoulli(...)).independent(1), obs=obs)
Swordsman answered 13/10, 2018 at 17:21

Comments (3)

Hi alpaca. Did you ever figure this problem out? I'm also trying to develop a 3-token language that has some probabilistic structure with order, and I think the joint probability distribution is what I need. - Continuative

I am really interested as well. Any updates? - Lennon

Well, sampling from a joint distribution depends pretty strongly on assumptions about dependence, so what exactly are you assuming about the dependencies among obs1, obs2, and obs3? - Weakling
