Control dependency
Here, we give an example of a controlled HMM (also called an input-output HMM), in the special case of Markov switching regression.
using Distributions
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using LinearAlgebra
using Random
using StableRNGs
using StatsAPI
rng = StableRNG(63);
Model
A Markov switching regression is like a classical regression, except that the weights depend on the unobserved state of an HMM. We can represent it with the following subtype of AbstractHMM (see Custom HMM structures), which has one vector of coefficients $\beta_i$ per state.
struct ControlledGaussianHMM{T} <: AbstractHMM
    init::Vector{T}                   # initial state distribution
    trans::Matrix{T}                  # transition matrix (control-independent here)
    dist_coeffs::Vector{Vector{T}}    # one coefficient vector βᵢ per state
end
In state $i$ with a vector of controls $u$, our observation is given by the linear model $y \sim \mathcal{N}(\beta_i^\top u, 1)$. Controls must be provided to both transition_matrix and obs_distributions, even if they are only used by one.
function HMMs.initialization(hmm::ControlledGaussianHMM)
    return hmm.init
end

function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control::AbstractVector)
    return hmm.trans
end

function HMMs.obs_distributions(hmm::ControlledGaussianHMM, control::AbstractVector)
    return [Normal(dot(hmm.dist_coeffs[i], control), 1.0) for i in 1:length(hmm)]
end
In this case, the transition matrix does not depend on the control.
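If the transitions should also react to the controls, one standard choice is a multinomial logistic (softmax) link per row. The following is only a sketch of such a variant: the LogisticTransitionHMM type and its trans_coeffs field are hypothetical, not part of this tutorial's model, and the remaining interface methods would be defined analogously.

# Hypothetical variant (not part of this tutorial's model): each row of the
# transition matrix is a softmax of scores that depend linearly on the control.
struct LogisticTransitionHMM{T} <: AbstractHMM
    init::Vector{T}
    trans_coeffs::Array{T,3}          # trans_coeffs[i, j, :] scores the jump i -> j
    dist_coeffs::Vector{Vector{T}}
end

function HMMs.transition_matrix(hmm::LogisticTransitionHMM, control::AbstractVector)
    N = size(hmm.trans_coeffs, 1)
    scores = [dot(hmm.trans_coeffs[i, j, :], control) for i in 1:N, j in 1:N]
    expw = exp.(scores .- maximum(scores; dims=2))   # row-wise stabilized softmax
    return expw ./ sum(expw; dims=2)
end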
Simulation
d = 3
init = [0.6, 0.4]
trans = [0.7 0.3; 0.2 0.8]
dist_coeffs = [-ones(d), ones(d)]
hmm = ControlledGaussianHMM(init, trans, dist_coeffs);
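As a quick sanity check, we can evaluate the interface functions defined above on a fixed control vector:

u = ones(d)                       # dummy control, just for illustration
HMMs.transition_matrix(hmm, u)    # constant: our transitions ignore the control
HMMs.obs_distributions(hmm, u)    # Normal(-3, 1) and Normal(3, 1), since dot(±ones(d), u) = ±3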
Simulation requires a vector of controls, each being a vector itself with the right dimension.
Let us build several sequences of variable lengths.
control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:1000];
obs_seqs = [rand(rng, hmm, control_seq).obs_seq for control_seq in control_seqs];
obs_seq = reduce(vcat, obs_seqs);
control_seq = reduce(vcat, control_seqs);
seq_ends = cumsum(length.(obs_seqs));
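Here, seq_ends records the index of the last observation of each sequence inside the concatenated vectors. As an illustration, the helper HMMs.seq_limits (which also appears in fit! below) recovers the range of a given sequence:

t1, t2 = HMMs.seq_limits(seq_ends, 1)   # index range of the first sequence
obs_seq[t1:t2] == obs_seqs[1]           # true by construction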
Inference
Not much changes from the case with simple time dependency.
best_state_seq, _ = viterbi(hmm, obs_seq, control_seq; seq_ends)
([1, 1, 2, 2, 2, 2, 2, 2, 2, 2 … 1, 1, 1, 1, 1, 1, 1, 1, 2, 2], [-264.53078164120643, -313.58808733243774, -256.75478343926085, -227.82892060129916, -282.8800770276695, -306.6552069996733, -196.07142048722181, -273.4570074581193, -324.77996086102576, -306.4672793317822 … -262.92800021685736, -240.004468851676, -293.6457092655249, -317.93357458165605, -313.5325464733667, -243.9660817310352, -186.18103561888216, -272.81260791498477, -231.25232201415807, -193.76726331060343])
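The other inference functions take the controls in the same position. For instance, to compute the loglikelihood of the observations under the true model (forward and forward_backward accept the same arguments):

logdensityof(hmm, obs_seq, control_seq; seq_ends)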
Learning
Once more, we override the fit! function. The state-related parameters are estimated in the standard way. Meanwhile, the observation coefficients are given by the formula for weighted least squares.
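Concretely, let $U$ be the matrix whose $t$-th row is the control $u_t$, let $y$ be the vector of observations, and let $\Gamma_i = \operatorname{diag}(\gamma_{i,1}, \dots, \gamma_{i,T})$ hold the state marginals of state $i$. The weighted least squares estimate is $\hat{\beta}_i = (U^\top \Gamma_i U)^{-1} U^\top \Gamma_i y$. The code below obtains it by solving the rescaled ordinary least squares problem $W U \beta \approx W y$ with $W = \Gamma_i^{1/2}$, which is exactly what the backslash operator does.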
function StatsAPI.fit!(
    hmm::ControlledGaussianHMM{T},
    fb_storage::HMMs.ForwardBackwardStorage,
    obs_seq::AbstractVector,
    control_seq::AbstractVector;
    seq_ends,
) where {T}
    (; γ, ξ) = fb_storage
    N = length(hmm)

    # Accumulate sufficient statistics for the state-related parameters.
    hmm.init .= 0
    hmm.trans .= 0
    for k in eachindex(seq_ends)
        t1, t2 = HMMs.seq_limits(seq_ends, k)
        hmm.init .+= γ[:, t1]
        hmm.trans .+= sum(ξ[t1:t2])
    end
    hmm.init ./= sum(hmm.init)
    for row in eachrow(hmm.trans)
        row ./= sum(row)
    end

    # Weighted least squares for the observation coefficients of each state,
    # with weights given by the state marginals γ.
    U = reduce(hcat, control_seq)'
    y = obs_seq
    for i in 1:N
        W = sqrt.(Diagonal(γ[i, :]))
        hmm.dist_coeffs[i] = (W * U) \ (W * y)
    end
end
Now we put it to the test.
init_guess = [0.5, 0.5]
trans_guess = [0.6 0.4; 0.3 0.7]
dist_coeffs_guess = [-1.1 * ones(d), 1.1 * ones(d)]
hmm_guess = ControlledGaussianHMM(init_guess, trans_guess, dist_coeffs_guess);
hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends)
first(loglikelihood_evolution), last(loglikelihood_evolution)
(-264718.563864405, -261242.07851449447)
How did we perform?
cat(hmm_est.trans, hmm.trans; dims=3)
2×2×2 Array{Float64, 3}:
[:, :, 1] =
0.702158 0.297842
0.201116 0.798884
[:, :, 2] =
0.7 0.3
0.2 0.8
hcat(hmm_est.dist_coeffs[1], hmm.dist_coeffs[1])
3×2 Matrix{Float64}:
-0.999216 -1.0
-1.00168 -1.0
-1.00624 -1.0
hcat(hmm_est.dist_coeffs[2], hmm.dist_coeffs[2])
3×2 Matrix{Float64}:
0.996911 1.0
1.00333 1.0
1.00181 1.0
This page was generated using Literate.jl.