Eight Schools in NumPyro

numpyro
Bayesian
Robust HMC inference using variable reparameterisation
Published June 22, 2022

This is an introduction to NumPyro, using the Eight Schools model as an example.

Here we demonstrate the effects of model reparameterisation. Reparameterisation is especially important in hierarchical models, whose joint densities tend to have regions of very high curvature.

import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

rng_key = random.PRNGKey(0)

Here we are using the classic Eight Schools dataset from Gelman et al. For each of eight schools we have the estimated effect of a coaching programme on test scores, together with its standard error. The goal is to determine whether some schools have done better than others. Note that since we are working with estimates and standard errors rather than raw student scores, we are actually modeling the results of other people's statistical analyses: this is essentially a meta-analysis problem.

J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

Visualize the data.

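import os

import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import norm

sns.set_theme()
sns.set_palette("Set2")

fig, ax = plt.subplots(figsize=(8, 4))
x = np.linspace(-50, 70, 1000)
# One Normal curve per school: reported mean y_n with standard error sigma_n
for n, (mean, sd) in enumerate(zip(y, sigma)):
    ax.plot(x, norm.pdf(x, mean, sd), label=f'School {n}')

ax.set_xlabel('Score')
ax.set_ylabel('Density')
ax.set_title('Test Score from Eight Schools')
ax.legend(ncol=2, fontsize='small')

os.makedirs('fig', exist_ok=True)  # savefig fails if the directory does not exist
plt.savefig('fig/eight.png')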

The baseline model

The model we are building is a standard hierarchical model. We assume that each observed school mean \(y_n\) is the school's true mean \(\theta_n\) corrupted by Normal noise with known standard deviation \(\sigma_n\), and that the true means are themselves drawn from a district-level distribution with mean \(\mu\) and standard deviation \(\tau\), which are in turn given suitable prior distributions. Essentially it's Gaussian all the way up (apart from the prior on \(\tau\)), and we have three levels to consider: the observations (whose noise reflects student-level variation), the schools, and the district-level population.

\[ \begin{align*} y_n &\sim \text{N} (\theta_n, \sigma_n) \\ \theta_n &\sim \text{N} (\mu, \tau) \\ \mu &\sim \text{N} (0, 5) \\ \tau &\sim \text{HalfCauchy} (5). \end{align*} \]
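Because both levels are Gaussian, \(\theta_n\) can be integrated out analytically: given \(\mu\) and \(\tau\), \(y_n \sim \text{N}(\mu, \sqrt{\sigma_n^2 + \tau^2})\), so the between-school spread \(\tau\) and the measurement error \(\sigma_n\) add in variance. A minimal Monte Carlo sketch of this, assuming illustrative values \(\mu = 2\) and \(\tau = 4\) (arbitrary choices, not estimates), simulates the two levels with numpyro.distributions and compares the empirical variance of the simulated observations with \(\sigma_n^2 + \tau^2\):

```python
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random

sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
mu, tau = 2.0, 4.0  # illustrative hyperparameter values (assumptions)
n = 20_000          # number of simulated datasets

key_theta, key_y = random.split(random.PRNGKey(1))
# School-level true means: theta_n ~ N(mu, tau), shape (n, 8)
theta = dist.Normal(mu, tau).sample(key_theta, (n, sigma.shape[0]))
# Observed estimates: y_n ~ N(theta_n, sigma_n); sigma broadcasts over the school axis
y_sim = dist.Normal(theta, sigma).sample(key_y)

# Ratio of empirical to analytic marginal variance, one per school
ratio = jnp.var(y_sim, axis=0) / (sigma**2 + tau**2)
print(jnp.round(ratio, 2))  # each ratio should be close to 1
```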

In NumPyro, models are written as ordinary Python functions.

def es_0(J, sigma):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))

    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma))

Note that the code and the mathematical model are almost identical, except that they go in different directions. In the mathematical model, we start with the observed data, and reason backward to determine how they might be generated. In the code we start with the hyperparameters, and move forward to generate the observed data.

J and sigma are inputs we use to build the model, but they are not random variables: they are not assigned any probability distribution. y, on the other hand, is the central variable of the model, and is represented by the sample site named obs.

numpyro.plate denotes that the variables inside it are conditionally independent (here, across the J schools given \(\mu\) and \(\tau\)), and broadcasts them to the size of the plate. Probability distributions are the building blocks of Bayesian models, and NumPyro provides many of them in numpyro.distributions. In NumPyro, distributions are JAX-based objects that can both draw samples (using JAX's random number generators) and evaluate log densities; the numpyro.sample primitive turns them into named random variables in a model.
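To make this concrete, here is a small sketch using the \(\text{N}(0, 5)\) prior on \(\mu\) on its own, outside any model: sampling needs an explicit JAX PRNG key, log_prob evaluates the log density (checked here against the closed-form Normal log density at 0), and expand shows the batch shape a size-8 plate would broadcast the distribution to.

```python
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random

prior_mu = dist.Normal(0.0, 5.0)

# Sampling requires an explicit PRNG key; sample_shape adds leading dimensions
draws = prior_mu.sample(random.PRNGKey(0), sample_shape=(1000,))
print(draws.shape)  # (1000,)

# log_prob evaluates the log density; compare with -0.5*log(2*pi) - log(scale) at x = 0
lp = prior_mu.log_prob(0.0)
closed_form = -0.5 * jnp.log(2 * jnp.pi) - jnp.log(5.0)
print(lp, closed_form)

# numpyro.plate broadcasts site distributions to its size;
# expand((8,)) produces the same batch shape explicitly
print(prior_mu.expand((8,)).batch_shape)  # (8,)
```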

Note, however, that numpyro.sample is used to declare named random variables (sample sites) in the model; simply calling the model function is not how we draw samples. To draw samples from the model, we use numpyro.infer.Predictive. In NumPyro each model defines a joint distribution over its sample sites, and since it is a probability distribution, we can draw samples from it. Because we haven't conditioned on any data, the samples we draw are from the prior distribution.
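Predictive works by running the model under NumPyro's effect handlers. A minimal sketch of the same idea for a single draw: seed supplies the PRNG key that each numpyro.sample statement consumes, and trace records every site (its name, type and value), so one forward run of the model is one joint draw from the prior. The model is repeated here so the block runs on its own.

```python
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import seed, trace

J = 8
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def es_0(J, sigma):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma))

# seed() provides the random key; trace() records each site as the model runs forward
model_trace = trace(seed(es_0, rng_seed=0)).get_trace(J, sigma)

for name, site in model_trace.items():
    print(name, site['type'], np.shape(site['value']))
```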

Sampling directly from the prior distribution and inspecting the samples is a good way to check whether the model is correctly specified.

es_prior_predictive_0 = Predictive(es_0, num_samples=1000)
es_prior_samples_0 = es_prior_predictive_0(rng_key, J, sigma)

def print_stats(name, samples):
    print(f"Variable: {name}")
    print(f"  Shape: {samples.shape}")
    print(f"  Mean: {jnp.mean(samples, axis=0)}")
    print(f"  Variance: {jnp.var(samples, axis=0)}\n")

# Print statistics for each variable
print_stats('mu', es_prior_samples_0['mu'])
print_stats('tau', es_prior_samples_0['tau'])
print_stats('theta', es_prior_samples_0['theta'])
print_stats('obs', es_prior_samples_0['obs'])
Variable: mu
  Shape: (1000,)
  Mean: -0.16857339441776276
  Variance: 26.169570922851562

Variable: tau
  Shape: (1000,)
  Mean: 19.337650299072266
  Variance: 10626.4833984375

Variable: theta
  Shape: (1000, 8)
  Mean: [-1.2455076   2.9482472  -5.281853   -0.06498987  2.3573008   0.31068215
  3.729895   -1.5531777 ]
  Variance: [11408.184   8016.4224 36956.      8181.7744  7636.8813  9586.139
  6544.6753 26078.693 ]

Variable: obs
  Shape: (1000, 8)
  Mean: [-1.7466605   2.9530644  -4.942445    0.10836948  2.5812457   0.34208587
  3.1065757  -2.218382  ]
  Variance: [11653.146   8119.2095 37439.16    8326.833   7699.3613  9800.354
  6728.329  26469.115 ]
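The sample mean and variance reported for tau deserve caution: a half-Cauchy distribution has neither a finite mean nor a finite variance, so these summaries are dominated by a few extreme draws and can change a lot from run to run. Quantiles are more stable. For a half-Cauchy with scale \(s\), \(P(\tau < s) = 1/2\), so the median of HalfCauchy(5) is exactly 5. A quick sketch comparing this with a large sample:

```python
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random

tau_prior = dist.HalfCauchy(5.0)
tau_draws = tau_prior.sample(random.PRNGKey(0), sample_shape=(100_000,))

# The mean does not exist for a half-Cauchy, so the sample mean is unstable
print('sample mean  :', jnp.mean(tau_draws))
# The median is well defined and equals the scale parameter
print('sample median:', jnp.median(tau_draws))  # close to 5
print('95% quantile :', jnp.quantile(tau_draws, 0.95))
```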

When some variables are observed, we can condition on their observed values and infer the conditional (posterior) distribution of the other variables. This process is called inference, and is commonly done using MCMC methods. Conditioning is done with the numpyro.handlers.condition handler, which takes the model and a dict mapping site names to observed values.
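An equivalent and very common NumPyro idiom is to pass the observations into the model and attach them to the sample statement through its obs argument. The sketch below (es_obs is a hypothetical name used only for this comparison) checks that both forms give the same joint log density at an arbitrary parameter value, using numpyro.infer.util.log_density.

```python
import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import condition
from numpyro.infer.util import log_density

J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def es_0(J, sigma):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma))

# Hypothetical variant: the data are passed in and attached with obs=
def es_obs(J, sigma, y):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

# An arbitrary point in parameter space at which to compare the two models
params = {'mu': 4.0, 'tau': 3.0, 'theta': jnp.full(J, 5.0)}

lp_condition, _ = log_density(condition(es_0, data={'obs': y}), (J, sigma), {}, params)
lp_obs, _ = log_density(es_obs, (J, sigma, y), {}, params)
print(lp_condition, lp_obs)  # the two log densities agree
```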

Since we have a GPU available, we will also run the 8 chains in vectorized mode (chain_method='vectorized'), which batches them on a single device.

from numpyro.handlers import condition

mcmc_args = {
    'num_warmup': 1000,
    'num_samples': 5000,
    'num_chains': 8,
    'progress_bar': True,
    'chain_method': 'vectorized'
}

es_conditioned_0 = condition(es_0, data={'obs': y})
es_nuts_0 = NUTS(es_conditioned_0)
es_mcmc_0 = MCMC(es_nuts_0, **mcmc_args)

es_mcmc_0.run(rng_key, J, sigma)

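The NUTS (No-U-Turn Sampler) is an adaptive variant of the Hamiltonian Monte Carlo (HMC) sampler that tunes its trajectory length automatically; HMC is a powerful tool for sampling from complex, high-dimensional probability distributions. Its output should always be checked with convergence diagnostics. The two most common, both reported by print_summary below, are the effective sample size (n_eff), which estimates how many independent draws the autocorrelated chains are worth, and the Gelman-Rubin statistic (r_hat), which compares between-chain and within-chain variance and should be close to 1 when the chains have converged. HMC-based samplers additionally report divergent transitions, which flag regions of the posterior that the numerical integrator cannot follow accurately.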
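Both statistics are available in numpyro.diagnostics and operate on arrays laid out as (chains, draws). The sketch below uses synthetic draws rather than the eight-schools output: eight well-mixed chains drawn from the same distribution, and eight chains whose means are offset from one another, mimicking chains stuck in different regions. On the real run, the same functions could be applied to es_mcmc_0.get_samples(group_by_chain=True)['tau'].

```python
import numpy as np
from numpyro.diagnostics import effective_sample_size, split_gelman_rubin

rng = np.random.default_rng(0)
num_chains, num_draws = 8, 2000

# Well mixed: every chain draws independently from the same N(0, 1)
mixed = rng.normal(size=(num_chains, num_draws))
# Poorly mixed: each chain is centred at a different value
offsets = np.linspace(-2.0, 2.0, num_chains)[:, None]
stuck = rng.normal(size=(num_chains, num_draws)) + offsets

for label, draws in [('mixed', mixed), ('stuck', stuck)]:
    r_hat = float(split_gelman_rubin(draws))
    n_eff = float(effective_sample_size(draws))
    print(f"{label:6s} r_hat = {r_hat:.3f}  n_eff = {n_eff:.0f}")
```

The well-mixed chains give r_hat close to 1, while the offset chains give a value far above the usual 1.01 threshold and a much smaller effective sample size.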

es_mcmc_0.print_summary()

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      4.26      3.34      4.25     -1.24      9.72   3163.83      1.00
       tau      3.91      3.17      3.03      0.45      8.08   2151.46      1.01
  theta[0]      6.26      5.81      5.64     -2.93     15.04   5795.74      1.00
  theta[1]      4.86      4.78      4.74     -2.97     12.48   6468.43      1.00
  theta[2]      3.71      5.40      3.97     -4.63     12.37   6937.61      1.00
  theta[3]      4.66      4.94      4.55     -3.62     12.34   6752.22      1.00
  theta[4]      3.41      4.71      3.64     -4.15     10.92   5415.63      1.00
  theta[5]      3.86      4.97      4.01     -4.12     11.83   6923.98      1.00
  theta[6]      6.36      5.27      5.75     -2.32     14.38   4727.40      1.00
  theta[7]      4.72      5.50      4.59     -3.78     13.50   7525.65      1.00

Number of divergences: 1327

The effective sample size for \(\tau\) is low (about 2,150 effective draws out of 40,000), and its r_hat of 1.01 suggests the chains may not have fully converged. More importantly, the 1327 divergent transitions are a sign that the sampler is struggling with the geometry of the posterior. This is a prominent problem in hierarchical modeling, known as Neal's funnel (after Radford Neal): when \(\tau\) is small, the \(\theta_n\) are squeezed into a narrow region around \(\mu\), and when \(\tau\) is large they can spread widely, so the posterior takes a funnel shape whose curvature changes by orders of magnitude. No single HMC step size suits both the narrow neck and the wide mouth of the funnel, and the divergences are the symptom of the integrator failing in the high-curvature neck.
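The funnel can be made quantitative. For a single school, the conditional log density \(\log \text{N}(\theta_n \mid \mu, \tau)\) has second derivative \(-1/\tau^2\) with respect to \(\theta_n\), so the curvature the sampler must cope with changes by a factor of \(10^4\) as \(\tau\) moves from 0.1 to 10. The sketch computes this with JAX autodiff at a few illustrative values of tau (the value of mu is arbitrary, since the curvature does not depend on it):

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

mu = 0.0  # illustrative; the second derivative does not depend on mu or theta

def log_p_theta(theta, tau):
    # Conditional log density of one school's true mean given mu and tau
    return dist.Normal(mu, tau).log_prob(theta)

# Second derivative with respect to theta (argument 0)
d2_log_p = jax.grad(jax.grad(log_p_theta, argnums=0), argnums=0)

for tau in [0.1, 1.0, 10.0]:
    print(f"tau = {tau:5.1f}   d2 log p / d theta2 = {float(d2_log_p(1.0, tau)):.4g}")
```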

Fortunately, the issue can be largely alleviated using a non-centered parameterization.

Manual reparameterisation

The remedy we are proposing is quite simple: replacing

\[ \theta_n \sim \text{N} (\mu, \tau) \]

with

\[ \begin{align*} \theta_n &= \mu + \tau \theta_{0,n} \\ \theta_{0,n} &\sim \text{N} (0, 1). \end{align*} \]

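In essence, instead of drawing \(\theta\) from a Normal distribution whose parameters are themselves random variables in the model, we draw auxiliary variables from the unit Normal distribution and shift and scale them to get the variable we want. By doing so we decouple the sampling of \(\theta\) from that of \(\mu\) and \(\tau\): the latent variables the sampler explores are a priori independent of the hyperparameters, which largely removes the funnel-shaped dependence from the space HMC has to traverse.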
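The two parameterisations define the same distribution for \(\theta_n\) given \(\mu\) and \(\tau\); only the variables the sampler works with change. A quick numerical check, assuming the illustrative values \(\mu = 4\) and \(\tau = 3\) (arbitrary choices):

```python
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random

mu, tau = 4.0, 3.0  # illustrative values (assumptions)
n = 100_000
key_c, key_nc = random.split(random.PRNGKey(0))

# Centred: draw theta directly from N(mu, tau)
theta_centred = dist.Normal(mu, tau).sample(key_c, (n,))
# Non-centred: draw theta_0 from N(0, 1), then shift and scale
theta_0 = dist.Normal(0.0, 1.0).sample(key_nc, (n,))
theta_noncentred = mu + tau * theta_0

for label, t in [('centred', theta_centred), ('non-centred', theta_noncentred)]:
    print(f"{label:12s} mean = {float(jnp.mean(t)):.3f}  std = {float(jnp.std(t)):.3f}")
```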

def es_1(J, sigma):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))

    with numpyro.plate('J', J):
        theta_0 = numpyro.sample('theta_0', dist.Normal(0, 1))
        theta = numpyro.deterministic('theta', mu + theta_0 * tau)
        numpyro.sample('obs', dist.Normal(theta, sigma))

Here we use another primitive, numpyro.deterministic, to register the transformed variable, so that its values can be stored and used later.
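Because theta is registered with numpyro.deterministic, Predictive (and MCMC) return it alongside the latent variables. A small sketch, with es_1 repeated so the block runs on its own, checks that the stored values satisfy theta = mu + tau * theta_0 for every prior draw:

```python
import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import Predictive

J = 8
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def es_1(J, sigma):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        theta_0 = numpyro.sample('theta_0', dist.Normal(0, 1))
        theta = numpyro.deterministic('theta', mu + theta_0 * tau)
        numpyro.sample('obs', dist.Normal(theta, sigma))

prior = Predictive(es_1, num_samples=1000)(random.PRNGKey(0), J, sigma)
print(sorted(prior.keys()))  # includes the deterministic site 'theta'

# Recompute theta from the latent draws, broadcasting mu and tau over schools
theta_check = prior['mu'][:, None] + prior['theta_0'] * prior['tau'][:, None]
print(jnp.allclose(prior['theta'], theta_check, atol=1e-3))
```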

es_prior_predictive_1 = Predictive(es_1, num_samples=1000)
es_prior_samples_1 = es_prior_predictive_1(rng_key, J, sigma)

print_stats('theta_0', es_prior_samples_1['theta_0'])
Variable: theta_0
  Shape: (1000, 8)
  Mean: [ 0.035466    0.00418693  0.04043639 -0.01104304  0.03547993  0.0202828
  0.01902334  0.01258929]
  Variance: [1.0026257 0.9806318 0.9726549 0.9744275 0.9668162 0.9936419 1.0036482
 1.0627806]

Condition on the observed data and run inference.

es_conditioned_1 = condition(es_1, data={'obs': y})
es_nuts_1 = NUTS(es_conditioned_1)
es_mcmc_1 = MCMC(es_nuts_1, **mcmc_args)

es_mcmc_1.run(rng_key, J, sigma)
es_mcmc_1.print_summary(exclude_deterministic=False)

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      4.39      3.30      4.40     -1.00      9.83  35585.46      1.00
       tau      3.62      3.23      2.77      0.00      7.78  28589.16      1.00
  theta[0]      6.23      5.57      5.66     -2.33     14.72  35893.88      1.00
  theta[1]      4.91      4.64      4.80     -2.49     12.31  42347.88      1.00
  theta[2]      3.92      5.30      4.16     -4.44     12.15  34873.20      1.00
  theta[3]      4.76      4.77      4.72     -2.90     12.33  42643.90      1.00
  theta[4]      3.61      4.66      3.86     -3.58     11.25  39974.71      1.00
  theta[5]      4.08      4.77      4.21     -3.52     11.68  39644.57      1.00
  theta[6]      6.28      5.07      5.79     -1.62     14.28  40117.39      1.00
  theta[7]      4.85      5.25      4.76     -3.41     13.11  37310.81      1.00
theta_0[0]      0.32      0.99      0.34     -1.27      1.97  40503.70      1.00
theta_0[1]      0.10      0.93      0.10     -1.41      1.65  43437.47      1.00
theta_0[2]     -0.08      0.97     -0.08     -1.71      1.47  41132.58      1.00
theta_0[3]      0.06      0.95      0.06     -1.55      1.58  45070.05      1.00
theta_0[4]     -0.17      0.93     -0.17     -1.69      1.36  40592.80      1.00
theta_0[5]     -0.07      0.94     -0.07     -1.61      1.48  44577.57      1.00
theta_0[6]      0.35      0.96      0.36     -1.16      2.01  38809.45      1.00
theta_0[7]      0.08      0.98      0.08     -1.58      1.63  45373.39      1.00

Number of divergences: 5

This looks much better. For the hyperparameter tau, the effective sample size has increased more than tenfold (from about 2,150 to about 28,600), r_hat is now 1.00, and the number of divergences has dropped from 1327 to 5. However, the fact that some divergences remain tells us that reparameterisation can improve the geometry of the posterior, but there is no guarantee that it will completely eliminate the problem. When doing Bayesian inference, especially with models of complex dependency structure such as hierarchical models, good techniques are never a sufficient replacement for good thinking.
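When a few divergences remain, two common next steps are to look at where they occur and to make the sampler more conservative. The sketch below re-runs the non-centred model with illustrative settings (a single short chain, and target_accept_prob=0.95 instead of the NUTS default of 0.8), reads the per-draw diverging flags via get_extra_fields(), and compares tau at divergent draws with tau overall. A higher target acceptance probability makes step-size adaptation choose smaller steps, which usually reduces divergences at the cost of slower sampling; it does not guarantee zero.

```python
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.handlers import condition
from numpyro.infer import MCMC, NUTS

J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def es_1(J, sigma):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        theta_0 = numpyro.sample('theta_0', dist.Normal(0, 1))
        theta = numpyro.deterministic('theta', mu + theta_0 * tau)
        numpyro.sample('obs', dist.Normal(theta, sigma))

model = condition(es_1, data={'obs': y})
kernel = NUTS(model, target_accept_prob=0.95)  # the default is 0.8
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000, num_chains=1, progress_bar=False)
mcmc.run(random.PRNGKey(1), J, sigma)

diverging = np.asarray(mcmc.get_extra_fields()['diverging'])  # one boolean flag per draw
tau = np.asarray(mcmc.get_samples()['tau'])
print('divergences:', int(diverging.sum()))

# Divergences in hierarchical models tend to cluster at small tau (the funnel's neck)
if diverging.any():
    print('median tau at divergences:', float(np.median(tau[diverging])))
print('median tau overall:', float(np.median(tau)))
```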

Using numpyro’s reparameterisation handler

Since this reparameterisation is so widely used, NumPyro already implements it. And since reparameterisation in general is so important in probabilistic modelling, NumPyro provides a whole suite of strategies in numpyro.infer.reparam (for example TransformReparam, used below, and LocScaleReparam).

In probabilistic modeling it is good practice to separate modeling from inference, but it's not always easy to do so. As we have seen, how we formulate the model can have a significant impact on inference performance. When building the model, not only do we need to configure the variable transformations, but we also need to tell the inference engine how to handle the transformed variables. This is where the numpyro.handlers.reparam handler comes in: given a config dict mapping site names to reparameterisation strategies, it rewrites those sample sites so that the sampler explores auxiliary variables, while the original site is recorded as a deterministic function of them.
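Because reparam is an ordinary effect handler, it can also wrap an existing model from the outside. A sketch, assuming the centred model es_0 is reused unchanged: LocScaleReparam(centered=0) from numpyro.infer.reparam rewrites the Normal site theta into a fully non-centred form (the same idea as the manual version), without touching the model code.

```python
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.handlers import reparam, seed, trace
from numpyro.infer import Predictive
from numpyro.infer.reparam import LocScaleReparam

J = 8
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def es_0(J, sigma):  # the centred model, unchanged
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma))

# centered=0 means fully non-centred; centered=1 would leave the site unchanged
es_0_nc = reparam(es_0, config={'theta': LocScaleReparam(centered=0)})

# Inspect the rewritten model: theta is now a deterministic site
tr = trace(seed(es_0_nc, rng_seed=0)).get_trace(J, sigma)
for name, site in tr.items():
    print(name, site['type'])

prior = Predictive(es_0_nc, num_samples=1000, return_sites=['theta'])(random.PRNGKey(0), J, sigma)
print(prior['theta'].shape)  # (1000, 8)
```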

from numpyro.handlers import reparam
from numpyro.infer.reparam import TransformReparam
from numpyro.distributions.transforms import AffineTransform
from numpyro.distributions import TransformedDistribution

def es_2(J, sigma):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))

    with numpyro.plate('J', J):
        with reparam(config={'theta': TransformReparam()}):
            theta = numpyro.sample(
                'theta',
                TransformedDistribution(dist.Normal(0., 1.), AffineTransform(mu, tau)))
        numpyro.sample('obs', dist.Normal(theta, sigma))

The process of reparameterisation goes as follows:

  1. Start with a standard Normal distribution dist.Normal(0., 1.),
  2. Transform it using the affine transformation AffineTransform(mu, tau), i.e. x ↦ mu + tau * x,
  3. Combine the two into a TransformedDistribution,
  4. Register the transformed variable theta using numpyro.sample,
  5. Wrap the sample statement in reparam with TransformReparam, so that the sampler works with the untransformed standard Normal draws and theta is recorded as a deterministic function of them.
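To see what step 5 does, one can trace the model. Under TransformReparam, the latent variable the sampler explores is a new site, named theta_base by NumPyro's TransformReparam, drawn from the standard Normal, while theta itself is recorded as a deterministic site. A sketch, with es_2 repeated so the block runs on its own:

```python
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import TransformedDistribution
from numpyro.distributions.transforms import AffineTransform
from numpyro.handlers import reparam, seed, trace
from numpyro.infer.reparam import TransformReparam

J = 8
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def es_2(J, sigma):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        with reparam(config={'theta': TransformReparam()}):
            theta = numpyro.sample(
                'theta',
                TransformedDistribution(dist.Normal(0., 1.), AffineTransform(mu, tau)))
        numpyro.sample('obs', dist.Normal(theta, sigma))

# Record every site produced by one seeded forward run
tr = trace(seed(es_2, rng_seed=0)).get_trace(J, sigma)
for name, site in tr.items():
    print(f"{name:12s} {site['type']}")
```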

Proceed with prior predictive sampling.

es_prior_predictive_2 = Predictive(es_2, num_samples=1000)
es_prior_samples_2 = es_prior_predictive_2(rng_key, J, sigma)

print_stats('theta', es_prior_samples_2['theta'])
Variable: theta
  Shape: (1000, 8)
  Mean: [-1.2455076   2.9482472  -5.281853   -0.06498987  2.3573008   0.31068215
  3.729895   -1.5531777 ]
  Variance: [11408.184   8016.4224 36956.      8181.7744  7636.8813  9586.139
  6544.6753 26078.693 ]

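The prior samples of theta are identical to those of the baseline model es_0: with the same random key, both models draw the same standard Normal noise and then shift and scale it by mu and tau, so the reparameterised model defines the same prior. Next, condition on the observed data and run inference.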

es_conditioned_2 = condition(es_2, data={'obs': y})
es_nuts_2 = NUTS(es_conditioned_2)
es_mcmc_2 = MCMC(es_nuts_2, **mcmc_args)

es_mcmc_2.run(rng_key, J, sigma)
es_mcmc_2.print_summary()

                   mean       std    median      5.0%     95.0%     n_eff     r_hat
           mu      4.39      3.30      4.40     -1.00      9.83  35585.46      1.00
          tau      3.62      3.23      2.77      0.00      7.78  28589.16      1.00
theta_base[0]      0.32      0.99      0.34     -1.27      1.97  40503.70      1.00
theta_base[1]      0.10      0.93      0.10     -1.41      1.65  43437.47      1.00
theta_base[2]     -0.08      0.97     -0.08     -1.71      1.47  41132.58      1.00
theta_base[3]      0.06      0.95      0.06     -1.55      1.58  45070.05      1.00
theta_base[4]     -0.17      0.93     -0.17     -1.69      1.36  40592.80      1.00
theta_base[5]     -0.07      0.94     -0.07     -1.61      1.48  44577.57      1.00
theta_base[6]      0.35      0.96      0.36     -1.16      2.01  38809.45      1.00
theta_base[7]      0.08      0.98      0.08     -1.58      1.63  45373.39      1.00

Number of divergences: 5

The results are consistent with the manual reparameterisation. Using the reparameterisation handler may seem unnecessarily complicated in this simple example. In more complex models, however, especially those with many layers of dependencies, the handler lets you keep the model written in its natural centred form and change how individual sites are parameterised without rewriting the model body, which greatly simplifies the model-building process.