pymc3: Dirichlet with multidimensional concentration factor
I am struggling with implementing a model where the concentration factor of the Dirichlet variable is dependent on another variable.

The situation is the following:

A system fails due to faulty components (there are three components, only one fails at each test/observation).

The probability of failure of the components is dependent on the temperature.

Here is a (commented) short implementation of the situation:

import numpy as np
import pymc3 as pm
import theano.tensor as tt


# Temperature data : 3 cold temperatures and 3 warm temperatures
T_data = np.array([10, 12, 14, 80, 90, 95])

# Data of failures of 3 components : [0,0,1] means component 3 failed
F_data = np.array([[0, 0, 1],
                   [0, 0, 1],
                   [0, 0, 1],
                   [1, 0, 0],
                   [1, 0, 0],
                   [1, 0, 0]])

n_component = 3

# When temperature is cold : Component 1 fails
# When temperature is warm : Component 3 fails
# Component 2 never fails

# Number of observations :
n_obs = len(F_data)


# The failures can be modeled as a Multinomial F ~ M(n, p) with parameters:
# -  n : number of trials per observation (fixed; here 1, since exactly one component fails at each test)
# -  p : probability of failure of each component (shape (n_obs, 3))

# The probability of failure of the components follows a Dirichlet distribution p ~ Dir(alpha) with parameter:
# -  alpha : concentration (shape (n_obs, 3))
# The Dirichlet distribution ensures the probabilities sum to 1

# The alpha parameters (and thus the failure probabilities) depend on the temperature: alpha = a + b * T
# - a : bias term (shape (1,3))
# - b : describes temperature dependency of alpha (shape (1,3))


# The prior on "a" is a normal distributions with mean 1/2 and std 0.001
# a ~ N(1/2, 0.001)

# The prior on "b" is a normal distribution zith mean 0 and std 0.001
# b ~ N(0, 0.001)


# Coding it all with pymc3

with pm.Model() as model:
    a = pm.Normal('a', 1/2, 1/(0.001**2), shape = n_component)
    b = pm.Normal('b', 0, 1/(0.001**2), shape = n_component)

    # I generate 3 alphas values (corresponding to the 3 components) for each of the 6 temperatures
    # I tried different ways to compute alpha but nothing worked out
    alphas = pm.Deterministic('alphas', a + b * tt.stack([T_data, T_data, T_data], axis=1))
    #alphas = pm.Deterministic('alphas', a + b[None, :] * T_data[:, None])
    #alphas = pm.Deterministic('alphas', a + tt.outer(T_data,b))


    # I think I should get 3 probabilities (corresponding to the 3 components) for each of the 6 temperatures
    #p = pm.Dirichlet('p', alphas, shape = n_component)
    p = pm.Dirichlet('p', alphas, shape = (n_obs,n_component))

    # Multinomial is observed and take values from F_data
    F = pm.Multinomial('F', 1, p, observed = F_data)


with model:
    trace = pm.sample(5000)

I get the following error in the sample function:


RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/anaconda3/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 73, in run
    self._start_loop()
  File "/anaconda3/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 113, in _start_loop
    point, stats = self._compute_point()
  File "/anaconda3/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 139, in _compute_point
    point, stats = self._step_method.step(self._point)
  File "/anaconda3/lib/python3.6/site-packages/pymc3/step_methods/arraystep.py", line 247, in step
    apoint, stats = self.astep(array)
  File "/anaconda3/lib/python3.6/site-packages/pymc3/step_methods/hmc/base_hmc.py", line 117, in astep
    'might be misspecified.' % start.energy)
ValueError: Bad initial energy: inf. The model might be misspecified.
"""

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
ValueError: Bad initial energy: inf. The model might be misspecified.

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-5-121fdd564b02> in <module>()
      1 with model:
      2     #start = pm.find_MAP()
----> 3     trace = pm.sample(5000)

/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
    438             _print_step_hierarchy(step)
    439             try:
--> 440                 trace = _mp_sample(**sample_args)
    441             except pickle.PickleError:
    442                 _log.warning("Could not pickle model, sampling singlethreaded.")

/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, use_mmap, **kwargs)
    988         try:
    989             with sampler:
--> 990                 for draw in sampler:
    991                     trace = traces[draw.chain - chain]
    992                     if trace.supports_sampler_stats and draw.stats is not None:

/anaconda3/lib/python3.6/site-packages/pymc3/parallel_sampling.py in __iter__(self)
    303 
    304         while self._active:
--> 305             draw = ProcessAdapter.recv_draw(self._active)
    306             proc, is_last, draw, tuning, stats, warns = draw
    307             if self._progress is not None:

/anaconda3/lib/python3.6/site-packages/pymc3/parallel_sampling.py in recv_draw(processes, timeout)
    221         if msg[0] == 'error':
    222             old = msg[1]
--> 223             six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
    224         elif msg[0] == 'writing_done':
    225             proc._readable = True

/anaconda3/lib/python3.6/site-packages/six.py in raise_from(value, from_value)

RuntimeError: Chain 1 failed.

Any suggestions?

Mendacious answered 20/9, 2018 at 16:36 Comment(1)
It looks like you're entering the standard deviation of the normals as though it were a precision, but the signature for those functions expects that argument to be simply the standard deviation. I would try explicitly using sd=0.001 or tau=1/0.001**2. – Masterson

Misspecified model. Under your current parameterization the alphas can take nonpositive values, whereas the Dirichlet distribution requires its concentration parameters to be strictly positive; that is what makes the model misspecified.
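A quick way to confirm this (a minimal sketch, assuming the model object from your question is in scope and that your PyMC3 version provides Model.check_test_point):

# Log-probability of each variable at the model's starting point.
# A -inf entry identifies the term that makes the initial energy
# infinite (the Dirichlet, whenever any alpha <= 0).
print(model.check_test_point())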

In Dirichlet-Multinomial regression, one uses an exponential link function to mediate between the range of the linear model and the domain of the Dirichlet-Multinomial, namely,

alpha = exp(beta*X)

There are details on this in the MGLM package documentation.

Dirichlet-Multinomial Regression Model

If we implement this model we can achieve decent model convergence and sampling.

import numpy as np
import pymc3 as pm
import theano
import theano.tensor as tt
from sklearn.preprocessing import scale

T_data = np.array([10,12,14,80,90,95])

# standardize the data for better sampling
T_data_z = scale(T_data)

# transform to theano tensor, so it works with tt.outer
T_data_z = theano.shared(T_data_z)

F_data = np.array([
    [0,0,1],
    [0,0,1],
    [0,0,1],
    [1,0,0],
    [1,0,0],
    [1,0,0],
])

# N = num_obs, K = num_components
N, K = F_data.shape

with pm.Model() as dmr_model:
    a = pm.Normal('a', mu=0, sd=1, shape=K)
    b = pm.Normal('b', mu=0, sd=1, shape=K)

    alpha = pm.Deterministic('alpha', pm.math.exp(a + tt.outer(T_data_z, b)))

    p = pm.Dirichlet('p', a=alpha, shape=(N, K))

    F = pm.Multinomial('F', 1, p, observed=F_data)

    trace = pm.sample(5000, tune=10000, target_accept=0.9)

Model Outcomes

The sampling in this model isn't perfect. For example, there are still a number of divergences even with the increased target acceptance rate and additional tuning.

There were 501 divergences after tuning. Increase target_accept or reparameterize.

There were 477 divergences after tuning. Increase target_accept or reparameterize.

The acceptance probability does not match the target. It is 0.5858954056820339, but should be close to 0.8. Try to increase the number of tuning steps.

The number of effective samples is smaller than 10% for some parameters.
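If you want to tally these divergences yourself, a minimal sketch using the sampler statistics PyMC3 stores on the trace:

# Count divergent transitions recorded by NUTS across all chains.
divergent = trace['diverging']
print('Divergences: %d of %d samples' % (divergent.sum(), len(divergent)))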

Trace Plots

We can see the traces for a and b look good, and the mean locations make sense with the data.

[Trace plots for a and b]

Pair Plot

While correlation is less of a problem for NUTS, having uncorrelated posterior sampling is ideal. For the most part we're seeing low correlation, with some slight structure within the a components.

[Pair plot of the posterior samples]

Posterior Plots

Finally, we can look at the posterior plots of p and confirm they make sense with the data.

[Posterior plots of p]
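For reference, a sketch of how plots like these can be produced; per the comments below I used ArviZ, though the exact calls here are reconstructions rather than the originals:

import arviz as az

# Trace plot for the regression coefficients
az.plot_trace(trace, var_names=['a', 'b'])

# Pair plot to inspect posterior correlations
az.plot_pair(trace, var_names=['a', 'b'])

# Posterior plot for the failure probabilities
az.plot_posterior(trace, var_names=['p'])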


Alternative Model

The advantage of the Dirichlet-Multinomial is handling overdispersion. It might be worth trying the simpler Multinomial Logistic Regression / Softmax Regression, since it runs significantly faster and doesn't exhibit any of the sampling problems that come up in the DMR model.

In the end, you could run both and perform model comparison to see whether the Dirichlet-Multinomial really adds explanatory value; a comparison sketch follows the softmax model below.

Model

with pm.Model() as softmax_model:
    a = pm.Normal('a', mu=0, sd=1, shape=K)
    b = pm.Normal('b', mu=0, sd=1, shape=K)

    p = pm.Deterministic('p', tt.nnet.softmax(a + tt.outer(T_data_z, b)))

    F = pm.Multinomial('F', 1, p, observed=F_data)

    trace_sm = pm.sample(5000, tune=10000)
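As mentioned above, here is a model-comparison sketch. It assumes both traces are available and uses the dict-based pm.compare API from PyMC3 versions of this era (check your version's signature):

# Compare the DMR and softmax models; pm.compare reports WAIC by
# default, and lower WAIC suggests better out-of-sample fit.
comparison = pm.compare({dmr_model: trace, softmax_model: trace_sm})
print(comparison)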

Posterior Plots

[Posterior plots for the softmax model]

Masterson answered 20/9, 2018 at 23:46 Comment(3)
What command did you use to plot those figures (pair plot and posterior plot)? – Havenot
@Havenot I used ArviZ for those plots. – Masterson
The 'p' variables aren't really the parameters of interest here, since the concentration parameters are the ones being modeled. Perhaps convoluted, but it's possible to implement a PyMC3 distribution that represents the marginal Multinomial-Dirichlet (what we have here, but analytically marginalized over 'p'). That would get rid of all the sampling issues, don't you think? – Readership
