Control dependency

Here, we give an example of a controlled HMM (also called an input-output HMM), in the special case of a Markov switching regression.

using Distributions
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using LinearAlgebra
using Random
using StableRNGs
using StatsAPI
rng = StableRNG(63);

Model

A Markov switching regression is like a classical regression, except that the weights depend on the unobserved state of an HMM. We can represent it with the following subtype of AbstractHMM (see Custom HMM structures), which has one vector of coefficients $\beta_i$ per state.

"""
    ControlledGaussianHMM{R}

Markov switching regression model: `init` holds the initial state probabilities,
`trans` the state transition matrix, and `dist_coeffs[i]` the regression coefficient
vector used for observations emitted in state `i`.
"""
struct ControlledGaussianHMM{R} <: AbstractHMM
    init::Vector{R}
    trans::Matrix{R}
    dist_coeffs::Vector{Vector{R}}
end

In state $i$ with a vector of controls $u$, our observation is given by the linear model $y \sim \mathcal{N}(\beta_i^\top u, 1)$. Controls must be passed to both transition_matrix and obs_distributions, even if only one of them actually uses them.

# The initial state distribution is stored as-is in the `init` field.
HMMs.initialization(hmm::ControlledGaussianHMM) = hmm.init

# The control is accepted (and ignored) so the signature matches the controlled-HMM
# interface; transitions here are the same at every time step.
HMMs.transition_matrix(hmm::ControlledGaussianHMM, ::AbstractVector) = hmm.trans

function HMMs.obs_distributions(hmm::ControlledGaussianHMM, control::AbstractVector)
    # Linear predictor βᵢᵀu for every state i, then one unit-variance Gaussian per state.
    means = [dot(hmm.dist_coeffs[i], control) for i in 1:length(hmm)]
    return Normal.(means, 1.0)
end

In this case, the transition matrix does not depend on the control.

Simulation

d = 3  # dimension of every control vector
init = [0.6, 0.4]
trans = [0.7 0.3; 0.2 0.8]
# State 1 has every coefficient equal to -1, state 2 every coefficient equal to +1.
dist_coeffs = [fill(-1.0, d), fill(1.0, d)]
hmm = ControlledGaussianHMM(init, trans, dist_coeffs);

Simulation requires a vector of controls, each of which is itself a vector of the right dimension ($d$ here).

Let us build several sequences of variable lengths.

# Every random draw, including the sequence lengths, goes through `rng`. The example
# is then fully reproducible, which is the reason for seeding a StableRNG above.
control_seqs = [[randn(rng, d) for t in 1:rand(rng, 100:200)] for k in 1:100];
obs_seqs = [rand(rng, hmm, control_seq).obs_seq for control_seq in control_seqs];

# Concatenate all sequences; `seq_ends[k]` is the index of the last step of sequence k.
obs_seq = reduce(vcat, obs_seqs)
control_seq = reduce(vcat, control_seqs)
seq_ends = cumsum(length.(obs_seqs));

Inference

Not much changes from the case with simple time dependency.

# Most likely state path for all sequences at once; the second output is discarded.
best_state_seq, _ = viterbi(hmm, obs_seq, control_seq; seq_ends)
([1, 1, 1, 1, 1, 1, 1, 1, 1, 1  …  1, 1, 2, 2, 2, 2, 1, 1, 1, 1], [-202.79725491293115, -351.463887396619, -240.70481689143045, -306.1876965193646, -205.83403144660204, -344.20844574635, -275.5017410086822, -375.18392177084576, -358.0138760764081, -332.34837807761323  …  -313.82869802225645, -259.8444424628965, -296.430378773534, -317.6218357761033, -225.05333031619512, -356.2040082887087, -312.96399725767657, -347.48372196251154, -343.47569478312835, -291.11447346891066])

Learning

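Once more, we override the fit! function. The state-related parameters are estimated in the standard way. Meanwhile, the observation coefficients of each state $i$ are given by the formula for weighted least squares, where time step $t$ is weighted by the posterior probability $\gamma_i(t)$ of being in state $i$.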

function StatsAPI.fit!(
    hmm::ControlledGaussianHMM{T},
    fb_storage::HMMs.ForwardBackwardStorage,
    obs_seq::AbstractVector,
    control_seq::AbstractVector;
    seq_ends::AbstractVector{Int},
) where {T}
    # M-step: re-estimate every parameter of `hmm` in place from the posteriors held in
    # `fb_storage`. γ is indexed [state, time]; ξ holds one matrix per time step.
    γ = fb_storage.γ
    ξ = fb_storage.ξ

    # Initial distribution and transition matrix: accumulate expected counts over all
    # sequences, then normalize so that `init` and each row of `trans` sum to one.
    fill!(hmm.init, 0)
    fill!(hmm.trans, 0)
    for k in eachindex(seq_ends)
        t_start, t_stop = HMMs.seq_limits(seq_ends, k)
        hmm.init .+= γ[:, t_start]
        # NOTE(review): summing ξ up to and including `t_stop` assumes the entry at the
        # last step of each sequence carries no transition mass — confirm upstream.
        hmm.trans .+= sum(ξ[t_start:t_stop])
    end
    hmm.init ./= sum(hmm.init)
    foreach(r -> (r ./= sum(r)), eachrow(hmm.trans))

    # Emission coefficients: per-state weighted least squares with weights γ[i, t],
    # solved as ordinary least squares on rows scaled by the square-root weights.
    design = reduce(hcat, control_seq)'  # one row per time step
    for i in 1:length(hmm)
        sqrt_weights = sqrt.(Diagonal(γ[i, :]))
        hmm.dist_coeffs[i] = (sqrt_weights * design) \ (sqrt_weights * obs_seq)
    end
    return nothing
end

Now we put it to the test.

# Deliberately perturbed starting point for Baum-Welch.
init_guess = [0.5, 0.5]
trans_guess = [0.6 0.4; 0.3 0.7]
dist_coeffs_guess = [fill(-1.1, d), fill(1.1, d)]
hmm_guess = ControlledGaussianHMM(init_guess, trans_guess, dist_coeffs_guess);
hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends)
# First and last recorded log-likelihood values.
loglikelihood_evolution[begin], loglikelihood_evolution[end]
(-27281.514867506998, -26898.8371374637)

How did we perform?

# Slice 1: estimated transition matrix; slice 2: true transition matrix.
cat(hmm_est.trans, hmm.trans; dims=3)
2×2×2 Array{Float64, 3}:
[:, :, 1] =
 0.696406  0.303594
 0.203845  0.796155

[:, :, 2] =
 0.7  0.3
 0.2  0.8
[hmm_est.dist_coeffs[1] hmm.dist_coeffs[1]]  # estimated vs. true coefficients, state 1
3×2 Matrix{Float64}:
 -0.994923  -1.0
 -0.986503  -1.0
 -0.9994    -1.0
[hmm_est.dist_coeffs[2] hmm.dist_coeffs[2]]  # estimated vs. true coefficients, state 2
3×2 Matrix{Float64}:
 0.987409  1.0
 0.994877  1.0
 1.00376   1.0

This page was generated using Literate.jl.