STAN Introduction

Joseph Sartini

2025-01-31

Bayesian Modeling

\[P(\Theta|X) \propto P(X|\Theta)P(\Theta)\]

Further reading on choosing priors: https://svmiller.com/blog/2021/02/thinking-about-your-priors-bayesian-analysis/
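
As an illustrative example of this proportionality (not from the slides), a conjugate Beta prior paired with Bernoulli observations yields a closed-form posterior:

\[\Theta \sim \text{Beta}(a, b), \quad X_i \mid \Theta \overset{iid}{\sim} \text{Bernoulli}(\Theta) \;\Longrightarrow\; \Theta \mid X \sim \text{Beta}\Big(a + \textstyle\sum_{i=1}^{N} X_i,\; b + N - \textstyle\sum_{i=1}^{N} X_i\Big)\]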

Markov Chain Monte Carlo

  • Metropolis-Hastings (MH)

    • Proposal distribution

    • Acceptance probability (sketched in R after this list)

  • Gibbs Sampling

    • Special case of MH

    • Conditional distributions

    • \(P(\text{Acceptance}) = 1\)

  • Hamiltonian Monte Carlo

    • Special case of MH

    • Hamiltonian dynamics

    • \(P(\text{Acceptance}) = 1\)* (exactly 1 only if the dynamics are simulated exactly)
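
For concreteness, MH accepts a proposal \(\theta'\) drawn from \(q(\theta' \mid \theta)\) with probability \(\alpha = \min\left(1, \frac{P(\theta' \mid X)\,q(\theta \mid \theta')}{P(\theta \mid X)\,q(\theta' \mid \theta)}\right)\). A minimal random-walk MH sketch in R (illustrative only; the standard-normal target and proposal scale are assumptions, not from the slides):

# Random-walk Metropolis-Hastings sketch (illustrative target, not from the slides)
log_target = function(theta) dnorm(theta, mean = 0, sd = 1, log = TRUE)

mh_sample = function(n_iter, init = 0, proposal_sd = 1) {
  draws = numeric(n_iter)
  current = init
  for (i in seq_len(n_iter)) {
    proposal = rnorm(1, mean = current, sd = proposal_sd)  # symmetric proposal
    # Symmetric q cancels: accept with probability min(1, p(proposal) / p(current))
    if (log(runif(1)) < log_target(proposal) - log_target(current)) {
      current = proposal
    }
    draws[i] = current
  }
  draws
}

mh_draws = mh_sample(5000)  # correlated draws from the target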

Hamiltonian Monte Carlo

HMC Visualization: By Justinkunimune - github.com/jkunimune/hamiltonian-mc, CC0

Hamiltonian Monte Carlo Continued

  • MH with proposals generated by Hamiltonian dynamics

    • Randomly sample kinetic energy (momentum)

    • Negative log posterior acts as the potential energy field

    • Simulate the trajectory using the leapfrog integrator (sketched below)

    • Accept/reject the proposed stopping point

  • Lower correlation between samples

    • Physical model: travel further in the parameter space

    • Higher sampling efficiency

  • Energy conservation: high acceptance probability
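
Here the Hamiltonian is \(H(\theta, p) = U(\theta) + K(p)\), with potential \(U(\theta) = -\log P(\Theta \mid X)\) and kinetic energy \(K(p)\) from the sampled momentum \(p\). A bare-bones leapfrog step in R (a sketch assuming a user-supplied gradient function grad_U; not Stan's implementation):

# Leapfrog integrator sketch (grad_U is an assumed gradient of the potential)
leapfrog = function(theta, p, grad_U, eps, n_steps) {
  p = p - 0.5 * eps * grad_U(theta)           # half step for momentum
  for (s in seq_len(n_steps)) {
    theta = theta + eps * p                   # full step for position
    if (s < n_steps) p = p - eps * grad_U(theta)
  }
  p = p - 0.5 * eps * grad_U(theta)           # closing half step for momentum
  list(theta = theta, p = -p)                 # momentum flip keeps proposals reversible
}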

What is STAN?

  • Bayesian probabilistic programming language

  • Multiple posterior inference routines

    • Hamiltonian Monte Carlo

      • Adaptive step size

    • Variational Inference

    • Laplace approximation

  • Written in C++

  • Interfaces with Python, Julia, R, and Unix Shell

Structure of a STAN Script

functions {
  // ... function declarations and definitions ...
}
data {
  // ... declarations ...
}
transformed data {
   // ... declarations ... statements ...
}
parameters {
   // ... declarations ...
}
transformed parameters {
   // ... declarations ... statements ...
}
model {
   // ... declarations ... statements ...
}
generated quantities {
   // ... declarations ... statements ...
}

Section - functions

  • Repeated sub-tasks

    • Code brevity

    • Parallelization

  • Complex indexing

    • Sparsely observed data

  • Suffixes for particular functions

    • Containing RNG: “_rng”

    • Modifying target density: “_lp”

Section - data

  • Observed data

    • Corresponding indexing objects

  • All constants

    • Array extents

  • Commonly used linear transforms

Section - transformed data

  • Functions of data variables

    • e.g., eigenvalues of a matrix

  • Helpful for book-keeping

    • Simplify data inputs

  • Only evaluated once

    • Prior to sampling

Section - parameters

  • Specify sampled quantities

    • Variable names

    • Any constraints

  • Declarations only (no statements)

  • Can provide initial values via the interface (e.g., init in sampling())

Section - transformed parameters

  • Deterministic functions of parameters

  • Good for re-parameterization

    • Parameter expansion

    • Centered / Non-centered

  • Evaluated with each sample

    • Transformation vs. change of variables

    • Change of variables requires the inverse transform and the log absolute Jacobian determinant (formula below)
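
For reference, if \(x\) is the sampled parameter and the prior is placed on \(y = f(x)\), the change-of-variables formula in terms of the inverse transform is

\[p_Y(y) = p_X\big(f^{-1}(y)\big)\,\Big|\det J_{f^{-1}}(y)\Big|,\]

and the log absolute Jacobian determinant must be added to the target density by hand.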

Section - model

  • Define the target posterior

    • Log density

  • Prior distributions on (transformed) parameters

  • Data/model likelihood

  • Most of the computational expense

  • ORDER MATTERS: statements execute top to bottom

Section - generated quantities

  • Functions of model output

    • Predictions for new data

    • Simulations based on parameters

    • Extract posterior estimates

    • Calculate model fit criteria

  • Executed after samples are generated

Example Model - Simple GLM

data {
  int N;  // Number of observations
  int P;  // Number of fixed effect covariates
  
  array[N] int<lower=0, upper=1> Y;  // Binary outcomes
  matrix[N, P] X;                    // Fixed effects design matrix
}
parameters {
   vector[P] beta;  // Coefficients
}
model {
   Y ~ bernoulli_logit(X * beta);
   // target += bernoulli_logit_lpmf(Y | X * beta);
}

Running the Model in R

library(rstan)   # stan_model(), sampling(), diagnostics
library(dplyr)   # %>%, mutate(), case_when()

fit_df = mtcars %>%
  mutate(Efficient = case_when(mpg >= median(mpg) ~ 1,
                               TRUE ~ 0)) %>%
  mutate(am = as.factor(am))
fit_matrix = model.matrix(~cyl + disp + hp + drat + wt + am, fit_df)

data_list = list(N = nrow(mtcars), P = ncol(fit_matrix), 
                 Y = fit_df$Efficient, X = fit_matrix)

# Compile the Stan program shown above (file name assumed here)
first_model = stan_model("simple_glm.stan")

model = sampling(
  first_model, 
  data = data_list, 
  chains = 4, 
  iter = 1000, 
  warmup = 500, 
  # init = ,
  # control = list(...), 
  verbose = F,
  refresh = 0
)

Reminder: Hamiltonian Dynamics

HMC Visualization: By Justinkunimune - github.com/jkunimune/hamiltonian-mc, CC0

Runtime Options - Adaptive HMC

  • adapt_delta: target acceptance probability during adaptation

  • max_treedepth: caps the number of leapfrog steps per iteration

  • stepsize_jitter: random jitter applied to the adapted step size

  • All three are passed through sampling()'s control argument (example below)
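
An illustrative call setting all three options through rstan's control list (the values are examples, not recommendations):

model = sampling(
  first_model, 
  data = data_list, 
  chains = 4, iter = 1000, warmup = 500, 
  control = list(adapt_delta = 0.95,     # higher target acceptance -> smaller steps
                 max_treedepth = 12,     # at most 2^12 leapfrog steps per iteration
                 stepsize_jitter = 0.1)  # randomly jitter the adapted step size
)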

Diagnostics

check_hmc_diagnostics(model)

Divergences:
0 of 2000 iterations ended with a divergence.

Tree depth:
1192 of 2000 iterations saturated the maximum tree depth of 10 (59.6%).
Try increasing 'max_treedepth' to avoid saturation.

Energy:
E-BFMI indicated no pathological behavior.
summary(model)$summary[,"Rhat"]
 beta[1]  beta[2]  beta[3]  beta[4]  beta[5]  beta[6]  beta[7]     lp__ 
1.905472 1.636913 1.918515 1.921832 1.923458 1.429901 1.929289 1.018607 

Divergences

  • Simulated trajectory \(\neq\) true trajectory

  • Step size too coarse for the local posterior geometry

    • Leapfrog integration is only approximate

  • Hamiltonian departs from its initial value

    • Total energy (kinetic + potential)

    • Should be preserved along trajectory

  • Sampler may reject samples after a divergence (flagged draws can be counted as below)
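
Flagged divergent transitions can be counted directly from rstan's sampler parameters:

# Count post-warmup divergent transitions across all chains
sp = get_sampler_params(model, inc_warmup = FALSE)
sum(sapply(sp, function(chain) sum(chain[, "divergent__"])))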

Tree Depth Warnings

  • Tree depth controls number of simulation steps

    • \(\leq 2^{max\_treedepth}\) steps

    • Keeps simulated trajectories finite

  • Primarily an efficiency concern

  • Generally recommended not to increase it (saturation can be tallied directly, as below)

    • Saturation often signals model misspecification
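
Analogously to divergences, saturation can be counted from the sampler parameters (assuming the default max_treedepth of 10):

sp = get_sampler_params(model, inc_warmup = FALSE)
sum(sapply(sp, function(chain) sum(chain[, "treedepth__"] >= 10)))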

Estimated Bayesian Fraction of Missing Information

  • Posterior decomposes into energy equivalence classes

  • Low E-BFMI indicates getting “stuck” in energy sets (estimator below)

    • STAN monitors the Hamiltonian during sampling

    • Resampled kinetic energies do not vary enough between draws

  • Insufficient exploration of the posterior

    • e.g., tails that are too heavy
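
For reference, a common form of the estimator over a chain's energies \(E_1, \dots, E_N\) is

\[\widehat{\text{E-BFMI}} = \frac{\sum_{n=2}^{N} \left(E_n - E_{n-1}\right)^2}{\sum_{n=1}^{N} \left(E_n - \bar{E}\right)^2},\]

with values below roughly 0.2 to 0.3 conventionally flagged as problematic.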

Geometric Intuition

Return to our Example: Model Outputs

library(purrr)  # map() and list_rbind()

samples = extract(model)

beta_0 = map(1:dim(samples$beta)[1], function(x){
  return(data.frame(beta0 = samples$beta[x,1], 
                    Sample = x))
}) %>% list_rbind()

Model Outputs (2)

samples = extract(model, "beta", permuted = FALSE)  # keep the chain dimension

beta_0 = map(1:dim(samples)[1], function(x){
  return(data.frame(beta0 = samples[x,,1], 
                    Chain = 1:4, 
                    sample = x))
}) %>% list_rbind()

ShinySTAN Debugging

library(shinystan)

launch_shinystan(model)

ShinySTAN - Summary

ShinySTAN - Divergences

ShinySTAN - Treedepth

ShinySTAN - Energy

ShinySTAN - Autocorrelation

How to Update the Model

data {
  int N;  // Number of observations
  int P;  // Number of fixed effect covariates
  
  array[N] int<lower=0, upper=1> Y;  // Binary outcomes
  matrix[N, P] X;                    // Fixed effects design matrix
}
transformed data {
  matrix[N, P] Q_coef = qr_thin_Q(X) * sqrt(N-1);
  matrix[P, P] R_coef = qr_thin_R(X) / sqrt(N-1);
}
parameters {
  vector[P] beta;  // Coefficients
}
transformed parameters {
  vector[P] theta = R_coef * beta;
}
model {
  beta ~ normal(0, 10);
  Y ~ bernoulli_logit(Q_coef * theta);
}
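
The update applies the thin QR reparameterization: with \(X = QR\),

\[X\beta = \underbrace{\left(Q\sqrt{N-1}\right)}_{Q^{*}}\underbrace{\left(\tfrac{1}{\sqrt{N-1}}R\,\beta\right)}_{\theta \,=\, R^{*}\beta} = Q^{*}\theta,\]

so the likelihood is evaluated on the orthogonal, well-conditioned columns of \(Q^{*}\) rather than the correlated columns of \(X\).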

Running the Updated Model

model = sampling(
  second_model, 
  data = data_list, 
  chains = 4, 
  iter = 1000, 
  warmup = 500, 
  verbose = F,
  refresh = 0
)

Updated Model Performance

check_hmc_diagnostics(model)

Divergences:
0 of 2000 iterations ended with a divergence.

Tree depth:
0 of 2000 iterations saturated the maximum tree depth of 10.

Energy:
E-BFMI indicated no pathological behavior.
summary(model)$summary[8:15,"Rhat"]
 theta[1]  theta[2]  theta[3]  theta[4]  theta[5]  theta[6]  theta[7]      lp__ 
1.0012360 1.0001953 1.0028185 1.0002307 0.9999944 1.0034565 0.9997221 1.0044744 

Updated Model Visualization

Updated Geometry

BRMS - An Alternative

library(brms)  # brm(), stancode(), rhat()

brms_fit = brm(Efficient ~ cyl + disp + hp + drat + wt + am, 
                data = fit_df, family = bernoulli(), 
                silent = 2, refresh = 0, 
                cores = 4, chains = 4, iter = 1000)
stancode(brms_fit)
// generated with brms 2.22.0
functions {
}
data {
  int<lower=1> N;  // total number of observations
  array[N] int Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  int<lower=1> Kc;  // number of population-level effects after centering
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b;  // regression coefficients
  real Intercept;  // temporary intercept for centered predictors
}
transformed parameters {
  real lprior = 0;  // prior contributions to the log posterior
  lprior += student_t_lpdf(Intercept | 3, 0, 2.5);
}
model {
  // likelihood including constants
  if (!prior_only) {
    target += bernoulli_logit_glm_lpmf(Y | Xc, Intercept, b);
  }
  // priors including constants
  target += lprior;
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = Intercept - dot_product(means_X, b);
}

BRMS - Diagnostics

check_hmc_diagnostics(brms_fit$fit)

Divergences:
0 of 2000 iterations ended with a divergence.

Tree depth:
0 of 2000 iterations saturated the maximum tree depth of 10.

Energy:
E-BFMI indicated no pathological behavior.
rhat(brms_fit)
b_Intercept       b_cyl      b_disp        b_hp      b_drat        b_wt 
   1.016590    1.003937    1.046643    1.035852    1.037613    1.014859 
      b_am1   Intercept      lprior        lp__ 
   1.028110    1.045403    1.046189    1.040587 

BRMS - Updating Priors
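
Before overriding the defaults, the priors brms would assign to this model can be inspected with get_prior() (an aside, not from the slides; output omitted):

# Inspect brms's default priors for this formula/family
get_prior(Efficient ~ cyl + disp + hp + drat + wt + am, 
          data = fit_df, family = bernoulli())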

brms_update = brm(Efficient ~ cyl + disp + hp + drat + wt + am, 
                data = fit_df, family = bernoulli(), 
                silent = 2, refresh = 0, 
                cores = 4, chains = 4, iter = 1000, 
                prior = set_prior("normal(0, 10)", class = "b"))
stancode(brms_update)
// generated with brms 2.22.0
functions {
}
data {
  int<lower=1> N;  // total number of observations
  array[N] int Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  int<lower=1> Kc;  // number of population-level effects after centering
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b;  // regression coefficients
  real Intercept;  // temporary intercept for centered predictors
}
transformed parameters {
  real lprior = 0;  // prior contributions to the log posterior
  lprior += normal_lpdf(b | 0, 10);
  lprior += student_t_lpdf(Intercept | 3, 0, 2.5);
}
model {
  // likelihood including constants
  if (!prior_only) {
    target += bernoulli_logit_glm_lpmf(Y | Xc, Intercept, b);
  }
  // priors including constants
  target += lprior;
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = Intercept - dot_product(means_X, b);
}

BRMS - Updated Diagnostics

check_hmc_diagnostics(brms_update$fit)

Divergences:
0 of 2000 iterations ended with a divergence.

Tree depth:
0 of 2000 iterations saturated the maximum tree depth of 10.

Energy:
E-BFMI indicated no pathological behavior.
rhat(brms_update)
b_Intercept       b_cyl      b_disp        b_hp      b_drat        b_wt 
   1.001894    1.005495    1.007951    1.003834    1.006253    1.001785 
      b_am1   Intercept      lprior        lp__ 
   1.002961    1.005025    1.007246    1.005499 

Profiling STAN Models with CmdStanR (1)

data {
  int N;  // Number of observations
  int P;  // Number of fixed effect covariates
  
  array[N] int<lower=0, upper=1> Y;  // Binary outcomes
  matrix[N, P] X;                    // Fixed effects design matrix
}
transformed data {
  matrix[N, P] Q_coef = qr_thin_Q(X) * sqrt(N-1);
  matrix[P, P] R_coef = qr_thin_R(X) / sqrt(N-1);
}
parameters {
  vector[P] beta;  // Coefficients
}
transformed parameters {
  vector[P] theta;
  profile("Transform") {
    theta = R_coef * beta;
  }
}
model {
  profile("Priors") {
    beta ~ normal(0, 10);
  }
  profile("Likelihood") {
    target += bernoulli_logit_lpmf(Y | Q_coef * theta);
  }
}

Profiling STAN Models with CmdStanR (2)

library(cmdstanr)  # cmdstan_model(), $sample(), $profiles()

model = cmdstan_model("Profile_Mod.stan")
fit = model$sample(data = data_list, chains = 1)
fit$profiles()[[1]][,c(1,3,4,5,8)]
        name  total_time forward_time reverse_time autodiff_calls
1  Generated 0.000158059  0.000158059   0.00000000              0
2     Priors 0.025370400  0.022040400   0.00332996         173154
3 Likelihood 0.142509000  0.113486000   0.02902290         173154