How to make a truncated normal prior: converting PyMC2 to PyMC3

In PyMC3, how does one configure a truncated normal prior? In PyMC2 it's pretty straightforward (below), but PyMC3 does not seem to have a truncated normal distribution available.

PyMC2:

TruncatedNormal('gamma_own_%i_' % i, mu=go, tau=v_gamma_inv, value=0, a=-np.inf, b=0)

PyMC3: ?
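For reference, here is what such a prior computes, independent of PyMC: a normal density with mean mu and precision tau (PyMC2's tau, equal to 1/sigma^2), set to zero outside [a, b] and rescaled by the probability mass inside [a, b] so that it still integrates to 1. This is a minimal stdlib-only sketch assuming Python 3.8+ (for statistics.NormalDist); truncnorm_logpdf is a hypothetical helper written for illustration, not an API of PyMC2 or PyMC3.

```python
import math
from statistics import NormalDist


def truncnorm_logpdf(x, mu, tau, a=-math.inf, b=math.inf):
    """Log-density of a normal with mean mu and precision tau, truncated to [a, b]."""
    if not (a <= x <= b):
        return -math.inf  # no probability mass outside the bounds
    sigma = 1.0 / math.sqrt(tau)  # precision tau = 1 / sigma**2
    # Untruncated normal log-density, written out to avoid log(0) far in the tails
    log_normal = -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)
    # Renormalize by the probability mass the untruncated normal puts inside [a, b]
    base = NormalDist(mu, sigma)
    mass = base.cdf(b) - base.cdf(a)
    return log_normal - math.log(mass)


# Same shape as the PyMC2 prior in the question: upper bound 0, no lower bound.
print(truncnorm_logpdf(-1.0, mu=0.0, tau=1.0, a=-math.inf, b=0.0))
print(truncnorm_logpdf(0.5, mu=0.0, tau=1.0, a=-math.inf, b=0.0))  # outside the bounds, prints -inf
```

In the question's call, mu=go, tau=v_gamma_inv, a=-np.inf and b=0 play the roles of mu, tau, a and b here.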

Poky asked 18/9, 2015 at 3:49
It looks like TruncatedNormal has been added to PyMC3 since this question was asked. I'll add a full answer, but in case it helps others who have this question (as I just did), you can now use pm.TruncatedNormal('n', mu=0, tau=10, lower=0, upper=1) in PyMC3. (comment by Johnsiejohnson)

In PyMC3 you can truncate any distribution using pm.Bound. First construct the bounded distribution (here called BoundedNormal), then create a variable from it, passing the usual parameters of the underlying distribution:

import pymc3 as pm

with pm.Model() as model:
    BoundedNormal = pm.Bound(pm.Normal, lower=0, upper=1)  # Normal restricted to [0, 1]
    n = BoundedNormal('n', mu=0, tau=10)                   # tau is the precision, 1/sigma**2
    tr = pm.sample(2000, pm.NUTS())

The resulting distribution looks like this:

[Figure: KDE and trace of the bounded normal distribution]
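One caveat worth knowing: pm.Bound only rules out values outside the bounds (their log-probability becomes minus infinity). As far as I know it does not divide by the probability mass inside the bounds, and the PyMC3 docstring for Bound says the resulting distribution is not normalized. The plain-Python sketch below (bounded_logp and truncated_logp are hypothetical helpers that mimic the idea, not PyMC functions) shows why that is harmless in this example: with mu and tau fixed, the two log-densities differ by the same constant at every point, and MCMC only needs the density up to a constant.

```python
import math
from statistics import NormalDist


def normal_logpdf(x, mu, tau):
    """Log-density of an untruncated normal with precision tau."""
    sigma = 1.0 / math.sqrt(tau)
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def bounded_logp(x, mu, tau, lower, upper):
    """Mimics the idea behind pm.Bound: keep the normal log-density, -inf outside."""
    if not (lower <= x <= upper):
        return -math.inf
    return normal_logpdf(x, mu, tau)


def truncated_logp(x, mu, tau, lower, upper):
    """Properly renormalized truncated-normal log-density."""
    logp = bounded_logp(x, mu, tau, lower, upper)
    if logp == -math.inf:
        return logp
    base = NormalDist(mu, 1.0 / math.sqrt(tau))
    return logp - math.log(base.cdf(upper) - base.cdf(lower))


# With mu and tau fixed, the gap is the same constant at every x inside the bounds.
for x in (0.1, 0.5, 0.9):
    gap = truncated_logp(x, 0.0, 10.0, 0.0, 1.0) - bounded_logp(x, 0.0, 10.0, 0.0, 1.0)
    print(x, round(gap, 6))
```

The constant depends on mu and tau, though, so it would no longer drop out if those were themselves random variables in the model.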

Peckham answered 18/9, 2015 at 14:48
For people finding this now: if you get AttributeError: 'bool' object has no attribute 'any', specify a testval when instantiating the BoundedNormal. (comment by Chloro)

The following code works in PyMC3 version 3.0:

import numpy as np
import pymc3 as pm

a, b = np.float32(0.0), np.float32(10.0)  # mu and tau (precision) of the underlying normal
K_lo, K_hi = 0.0, 1.0                      # truncation bounds

BoundedNormal = pm.Bound(pm.Normal, lower=K_lo, upper=K_hi)

with pm.Model() as model:
    n = BoundedNormal('n', mu=a, tau=b)
    tr = pm.sample(2000, pm.NUTS())

pm.traceplot(tr)

Reserved answered 27/4, 2017 at 14:31

Here is the full code for a TruncatedNormal, similar to the previous solutions from Meysam Hashemi and Kiudee, for PyMC3 version 3.6:

import pymc3 as pm

with pm.Model() as model:
    n = pm.TruncatedNormal('n', mu=0, tau=10, lower=0, upper=1)
    tr = pm.sample(2000, pm.NUTS())
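To sanity-check what this prior should produce, one can draw from the same truncated normal (mu=0, tau=10, bounds [0, 1]) outside PyMC with inverse-CDF sampling: draw a uniform number between the CDF values at the two bounds and map it back through the inverse CDF. This is a rough sketch assuming Python 3.8+ (statistics.NormalDist) and bounds that are not far out in the tails; sample_truncnorm is a hypothetical helper, and this is not how PyMC3 itself samples (NUTS does that).

```python
import math
import random
from statistics import NormalDist, fmean


def sample_truncnorm(mu, tau, lower, upper, size, seed=0):
    """Draw `size` values from a normal(mu, precision tau) truncated to [lower, upper]."""
    rng = random.Random(seed)
    base = NormalDist(mu, 1.0 / math.sqrt(tau))  # precision -> standard deviation
    lo_cdf, hi_cdf = base.cdf(lower), base.cdf(upper)
    draws = []
    for _ in range(size):
        # A uniform draw between the CDF values of the bounds...
        u = rng.uniform(lo_cdf, hi_cdf)
        while not 0.0 < u < 1.0:  # inv_cdf rejects probabilities of exactly 0 or 1
            u = rng.uniform(lo_cdf, hi_cdf)
        # ...mapped through the inverse CDF lands inside [lower, upper];
        # clamp to guard against floating-point rounding at the edges
        draws.append(min(max(base.inv_cdf(u), lower), upper))
    return draws


draws = sample_truncnorm(mu=0.0, tau=10.0, lower=0.0, upper=1.0, size=5000)
print(min(draws), max(draws), fmean(draws))
```

The sample mean should land near the analytical truncated-normal mean, about 0.25 for these parameters, which can be compared with tr['n'].mean() from the PyMC3 trace.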
Johnsiejohnson answered 17/1, 2019 at 4:45
