Dynamic Linear Models with Switching in Stan

In this post we try to reproduce Example 6.22 from the book “Time Series Analysis and Its Applications: With R Examples” by Robert H. Shumway and David S. Stoffer. We analyze the U.S. monthly pneumonia and influenza mortality using Dynamic Linear Models with Switching. This kind of model generalizes state space models to include the possibility of regime changes occurring over time. The idea is to use a hidden Markov model (HMM) to marginalize out the hidden discrete states, allowing the state space model to behave differently depending on the underlying unobserved state. The example in the book uses the exact likelihood, while in this post we express the model explicitly in Stan. For examples of state space models in Stan see Juho Kokkala’s blog post http://www.juhokokkala.fi/blog/posts/kalman-filter-style-recursion-to-marginalize-state-variables-to-speed-up-stan-inference/ or Jeffrey B. Arnold’s library https://jrnold.github.io/ssmodels-in-stan/filtering-and-smoothing.html, while an example of an HMM can be found in Stan’s user guide https://mc-stan.org/docs/stan-users-guide/hmms.html.

require(tidyverse)
require(rstan)
require(mvtnorm)
require(cmdstanr)
require(posterior)
library(astsa)

Let’s have a look at the data, representing the mortality rate from 1968 to 1978:

plot.default(flu, ylab="Mortality", xlab="", cex.axis=1.5, cex.lab=1.5, type="o")

The model consists of three structural components. The first component is an AR(2) process chosen to represent the periodic (seasonal) component of the data:

\[ x_{t1} = \alpha_{1}x_{t-1,1} + \alpha_{2}x_{t-2,1} + w_{t1} \] where \(w_{t1}\) is white noise with variance = \(\sigma_{1}^2\).
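
As a quick illustration of this component, one can simulate such an AR(2) in R; the coefficients and noise scale here are assumptions, chosen near the estimates we obtain further below:

# Simulate the AR(2) seasonal component (illustrative coefficients)
set.seed(42)
x1 <- arima.sim(model = list(ar = c(1.28, -0.36)), n = 132, sd = 0.02)
plot(x1, type = "o", ylab = "AR(2) component")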

The second component is an AR(1) process with a nonzero constant term, representing the sharp rise in the data during an epidemic:

\[ x_{t2} = \beta_{0} + \beta_{1}x_{t-1,2} + w_{t2} \] where \(w_{t2}\) is white noise with variance = \(\sigma_{2}^2\). The third component is a fixed trend component:

\[ x_{t3} = x_{t-1,3} + w_{t3} \]

In the book this component is assumed to be deterministic because its estimation was unstable; we will try to estimate it nonetheless. The model can be expressed in state-space form as:

\[ \begin{pmatrix} x_{t,1} \\ x_{t-1,1} \\ x_{t,2} \\ x_{t,3} \end{pmatrix} = \begin{bmatrix} \alpha_{1} & \alpha_{2} & 0 & 0 \\ 1 & 0 & 0 & 0 \\ 0 & 0 & \beta_{1} & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} \begin{pmatrix} x_{t-1,1} \\ x_{t-2,1} \\ x_{t-1,2} \\ x_{t-1,3} \end{pmatrix} + \begin{pmatrix} 0 \\ 0 \\ \beta_{0} \\ 0 \end{pmatrix} + \begin{pmatrix} w_{t1} \\ 0 \\ w_{t2} \\ 0 \end{pmatrix} \]

The observation equation is:

\[ y_{t} = A_{t}x_{t} + v_{t} \]

Periods of normal influenza mortality (regime 1) are modeled as:

\[ y_{t} = x_{t1} + x_{t3} + v_{t} \]

while during epidemics (regime 2) mortality is modeled as:

\[ y_{t} = x_{t1} + x_{t2} + x_{t3} + v_{t} \]

Therefore, the observation matrix \(A_t\) can take two forms depending on the regime: \([1, 0, 0, 1]\) when there is no epidemic, and \([1, 0, 1, 1]\) during an epidemic.
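
Before turning to Stan, it can help to simulate the full switching model in R to see how the pieces fit together. This is only an illustrative sketch: the parameter values below are assumptions, and we simulate the three components directly instead of the four-dimensional state vector.

# Generative sketch of the switching model (illustrative parameters)
set.seed(1)
N <- 132
alpha <- c(1.28, -0.36)          # AR(2) seasonal component
beta0 <- 0.3; beta1 <- 0.6       # AR(1) epidemic component
sigma <- c(0.02, 0.1)            # process noise scales
sigmav <- 0.01                   # observation noise scale
theta <- rbind(c(0.9, 0.1),      # transition probabilities from regime 1
               c(0.3, 0.7))      # transition probabilities from regime 2
x <- matrix(0, N, 3)             # columns: AR(2), AR(1), trend
x[1, 3] <- 0.25                  # baseline mortality level for the trend
r <- integer(N); r[1] <- 1       # regime sequence, starting in regime 1
for (t in 2:N) {
  x[t, 1] <- alpha[1] * x[t - 1, 1] +
             (if (t > 2) alpha[2] * x[t - 2, 1] else 0) + rnorm(1, 0, sigma[1])
  x[t, 2] <- beta0 + beta1 * x[t - 1, 2] + rnorm(1, 0, sigma[2])
  x[t, 3] <- x[t - 1, 3]         # trend kept deterministic, as in the book
  r[t] <- sample(1:2, 1, prob = theta[r[t - 1], ])
}
A <- rbind(c(1, 0, 1),           # regime 1: y = x1 + x3
           c(1, 1, 1))           # regime 2: y = x1 + x2 + x3
y_sim <- rowSums(x * A[r, ]) + rnorm(N, 0, sigmav)
plot(y_sim, type = "o", ylab = "Simulated mortality")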

We decided to model the process in Stan as follows:

// Model for the influenza mortality, made of 4 states and 2 regimes

data {
  int<lower=0> N;
  array[N] real y;
  vector[4] m0;
  matrix[4, 4] P0;
}
transformed data {
  int D = 4; // the number of states
  matrix[D, D] I; // the identity matrix
  I = diag_matrix(rep_vector(1, D));
}
parameters {
  vector[2] alpha;
  real<lower=0> beta0;
  real beta1;
  vector<lower=0>[2] sigma;
  real<lower=0> sigmav;
  array[2] simplex[2] theta; // regime transition probabilities
}
transformed parameters {
  matrix[D, D] Q; // process error matrix
  real R; // observation error matrix
  matrix[D, D] Z; // transition matrix
  matrix[D, D] Zt; // transition matrix transpose
  array[2, N] vector[D] m; // filtered states
  array[2, N] vector[D] m_pred; // predicted states
  array[2, N] matrix[D, D] P; // filtered covariance matrix
  array[2, N] matrix[D, D] P_pred; // predicted covariance matrix
  array[2, N] real S; // variance of the prediction errors
  array[2] row_vector[D] A;  // observation model matrices, depending on the regime
  vector[D] u; // intercepts of the process
  array[2, N] vector[D] K; // Kalman gains
  array[2, N] real v; // prediction errors
  array[2, N] real mu; 
  
  array[2] real acc; // HMM temporary probabilities
  array[N] vector[2] gamma; // HMM  forward values
  {
    row_vector[D] H; // Just a variable to store the different A matrices depending on the regime
    vector[D] Ht; // H transpose
    Z = [[alpha[1], alpha[2], 0, 0], 
         [1, 0, 0, 0], 
         [0, 0, beta1, 0],
         [0, 0, 0, 1]];
    Zt = Z';
    A[1] = [1, 0, 0, 1];
    A[2] = [1, 0, 1, 1];
    u = [0, 0, beta0, 0]';
    
    Q = diag_matrix([sigma[1] ^ 2, 0, sigma[2] ^ 2, 0]');
    R = sigmav ^ 2;
    
    // Initialize the states; we could also set these as parameters to estimate
    m_pred[1][1] = m0;
    m_pred[2][1] = m0;
    P_pred[1][1] = P0;
    P_pred[2][1] = P0;
    
    // Kalman filter see https://jrnold.github.io/ssmodels-in-stan/filtering-and-smoothing.html
    for (t in 1 : N) {
      for (k in 1 : 2) {
        if (t > 1) {
          m_pred[k][t] = Z * m[k][t - 1] + u;
          P_pred[k][t] = Z * P[k][t - 1] * Zt + Q;
        }
        H = A[k];
        Ht = H';
        v[k][t] = y[t] - H * m_pred[k][t];
        S[k][t] = H * P_pred[k][t] * Ht + R;
        K[k][t] = (P_pred[k][t] * Ht) / S[k][t];
        m[k][t] = m_pred[k][t] + K[k][t] * v[k][t];
        P[k][t] = (I - K[k][t] * H) * P_pred[k][t];
        mu[k][t] = H * m_pred[k][t];
      }
    }
  }
  
  // HMM marginalization see https://mc-stan.org/docs/stan-users-guide/hmms.html
  for (k in 1 : 2) 
    gamma[1, k] = normal_lpdf(y[1] | mu[k][1], sqrt(S[k][1]));
  for (t in 2 : N) {
    for (k in 1 : 2) {
      for (j in 1 : 2) {
        acc[j] = gamma[t - 1, j] + log(theta[j, k])
                 + normal_lpdf(y[t] | mu[j][t], sqrt(S[j][t]));
      }
      gamma[t, k] = log_sum_exp(acc);
    }
  }
}
model {
  sigma ~ exponential(1);
  sigmav ~ exponential(1);
  alpha ~ normal(0, 1);
  beta0 ~ exponential(1);
  beta1 ~ normal(1, 0.5);
  theta[1, 1] ~ beta(10, 2);
  theta[2, 2] ~ beta(10, 2);
  theta[2, 1] ~ beta(2, 10);
  theta[1, 2] ~ beta(2, 10);

  target += log_sum_exp(gamma[N]);
}
generated quantities {
  array[N] vector[2] regimes;
  for (t in 1 : N) {
    regimes[t] = softmax(gamma[t]);
  }
}

The code looks quite involved, but apart from the large number of variables declared, most of it is made up of the main loop for the Kalman filter and the HMM marginalization. Notice that the Kalman filter is run twice, once for each value of the matrix A representing the two possible regimes; therefore we need two values for each variable of the Kalman filter, as you can see in the declarations at the top. Later, during the HMM forward algorithm, we calculate the probability of each regime (the two filters differ only in the value of A).
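
To make the marginalization easier to follow, here is the same forward recursion as a standalone R function. This is only a sketch under an assumption: loglik is an N x 2 matrix whose (t, k) entry holds the one-step-ahead log density of y[t] produced by the Kalman filter of regime k, and theta is the 2 x 2 transition matrix, mirroring the indexing of the Stan loop above.

# Plain-R sketch of the log-space HMM forward recursion from the Stan model;
# `loglik` and `theta` are assumed inputs as described above
log_sum_exp <- function(x) { m <- max(x); m + log(sum(exp(x - m))) }
forward <- function(loglik, theta) {
  N <- nrow(loglik)
  gamma <- matrix(NA_real_, N, 2)
  gamma[1, ] <- loglik[1, ]
  for (t in 2:N) {
    for (k in 1:2) {
      acc <- gamma[t - 1, ] + log(theta[, k]) + loglik[t, ]
      gamma[t, k] <- log_sum_exp(acc)
    }
  }
  gamma  # gamma[t, k] = log p(y[1:t], regime at time t = k)
}

The total log likelihood added to target in the model block then corresponds to log_sum_exp(gamma[N, ]).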

m <- cmdstan_model("models/Misc/flu.stan")
fit <- m$sample(data = list(N = length(flu), y = c(0, diff(as.vector(flu))), m0 = rep(0, 4), P0 = diag(c(1, 1, 1, 1))), parallel_chains = 4, iter_warmup = 250, iter_sampling = 250)

Parameter estimates are roughly close to the ones in the book; in addition we obtain estimates of the transition probabilities:

fit$print(max_rows = 12, digits = 4)
##    variable     mean   median     sd    mad       q5      q95   rhat ess_bulk ess_tail
##  lp__       179.2173 179.5520 2.1490 2.1349 175.3534 182.1791 1.0099      336      583
##  alpha[1]     1.2787   1.2769 0.0552 0.0546   1.1880   1.3708 1.0097      323      282
##  alpha[2]    -0.3632  -0.3610 0.0449 0.0457  -0.4384  -0.2917 1.0056      313      336
##  beta0        0.8701   0.5959 0.9020 0.6077   0.0523   2.6551 1.0008      534      337
##  beta1        0.5458   0.5559 0.1520 0.1536   0.2873   0.7814 1.0033      557      483
##  sigma[1]     0.0210   0.0209 0.0020 0.0021   0.0179   0.0241 1.0041      513      374
##  sigma[2]     0.1671   0.1656 0.0244 0.0253   0.1331   0.2091 1.0014      678      635
##  sigmav       0.0034   0.0033 0.0022 0.0027   0.0003   0.0072 1.0192      285      388
##  theta[1,1]   0.9051   0.9095 0.0298 0.0270   0.8506   0.9455 1.0043      662      477
##  theta[2,1]   0.2264   0.2225 0.0622 0.0619   0.1319   0.3343 1.0055      572      434
##  theta[1,2]   0.0949   0.0905 0.0298 0.0270   0.0545   0.1494 1.0043      662      477
##  theta[2,2]   0.7736   0.7775 0.0622 0.0619   0.6657   0.8681 1.0055      572      434
## 
##  # showing 12 of 13011 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)

Regime estimates are reasonable, even if not as good as in the book:

y <- as.matrix(flu)
s <- fit$draws(variables = "regimes") %>% merge_chains() %>% colMeans() %>% matrix(nrow=nrow(y)) 
Time <- as.matrix(time(flu))
regime <- ifelse(s[,1] < s[,2], 1, 2);
plot(Time, y, type="n", ylab="", cex=2, cex.axis=2, xlab="")
grid(lty=2); lines(Time, y, col=gray(.7))
text(Time, t(y), col=regime, labels=regime, cex=2)

To obtain the final filtered state estimates, we take the filtered states for the two regimes and compute their average, weighted by the probability of each regime:

a <- fit$draws(variables = "m") %>% merge_chains() %>% colMeans()
# m is array[2, N] vector[4], i.e. 2 * 132 * 4 = 1056 values with the
# regime index varying fastest: odd entries regime 1, even entries regime 2
m1 <- matrix(a[seq(1, 1056, 2)], ncol=4)
m2 <- matrix(a[seq(2, 1056, 2)], ncol=4)
p1 <- matrix(rep(s[,1], 4), ncol=4)
p2 <- matrix(rep(s[,2], 4), ncol=4)
m3 <- m1 * p1 + m2 * p2
matplot(Time, m3[,-2], type="o", cex=1.5, ylab="", cex.axis=2)

The estimate of the trend component seems to be strongly affected by the epidemic oscillations, unlike the corresponding state obtained in the book. It is possible that additional care is needed in the definition of our model to better match those results. We finally show the one-month-ahead predictions with their variability (two times the innovations standard deviation):

mu <- fit$draws(variables = "mu") %>% merge_chains() %>% colMeans()
S <- fit$draws(variables = "S") %>% merge_chains() %>% colMeans()
# mu and S are array[2, N], i.e. 2 * 132 = 264 values: odd entries
# belong to regime 1, even entries to regime 2
mu1 <- mu[seq(1, 264, 2)]
mu2 <- mu[seq(2, 264, 2)]
mu3 <- mu1 * s[,1] + mu2 * s[,2]
S1 <- S[seq(1, 264, 2)]
S2 <- S[seq(2, 264, 2)]
S3 <- 2*sqrt(S1 * s[,1] + S2 * s[,2])
plot(Time, mu3, type="n", ylab="", xlab="", ylim=c(0,1), cex=2, cex.axis=2)
grid(lty=2);
points(Time, as.vector(y), pch=16, cex=1.5)
xx = c(Time, rev(Time))
yy = c(mu3-S3, rev(mu3+S3))
polygon(xx, yy, border=8, col=gray(.6, alpha=.3))
