Time dependency
Here, we demonstrate what to do when the transition and observation laws depend on the current time. This time-dependent HMM is implemented as a particular case of controlled HMM.
using Distributions
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using Random
using StableRNGs
using StatsAPI
rng = StableRNG(63);
Model
We focus on the particular case of a periodic HMM with period L. It has only one initialization vector, but L transition matrices and L vectors of observation distributions. As in Custom HMM structures, we need to subtype AbstractHMM.
struct PeriodicHMM{T<:Number,D,L} <: AbstractHMM
    init::Vector{T}
    trans_per::NTuple{L,Matrix{T}}
    dists_per::NTuple{L,Vector{D}}
end
The interface definition is almost the same as in the homogeneous case, but we give the control variable (here the time) as an additional argument to transition_matrix and obs_distributions.
period(::PeriodicHMM{T,D,L}) where {T,D,L} = L

function HMMs.initialization(hmm::PeriodicHMM)
    return hmm.init
end

function HMMs.transition_matrix(hmm::PeriodicHMM, t::Integer)
    l = (t - 1) % period(hmm) + 1  # phase of time t within the period
    return hmm.trans_per[l]
end

function HMMs.obs_distributions(hmm::PeriodicHMM, t::Integer)
    l = (t - 1) % period(hmm) + 1  # phase of time t within the period
    return hmm.dists_per[l]
end
Simulation
init = [0.6, 0.3, 0.1]
trans_per = (
    [ # l = 1 -> mostly switch to next state
        0.2 0.8 0.0
        0.0 0.2 0.8
        0.8 0.0 0.2
    ],
    [ # l = 2 -> mostly switch to previous state
        0.2 0.0 0.8
        0.8 0.2 0.0
        0.0 0.8 0.2
    ],
    [ # l = 3 -> mostly stay in current state
        0.8 0.1 0.1
        0.1 0.8 0.1
        0.1 0.1 0.8
    ],
)
dists_per = (
    [Normal(1.0), Normal(2.0), Normal(3.0)],
    [Normal(3.0), Normal(4.0), Normal(5.0)],
    [Normal(5.0), Normal(6.0), Normal(7.0)],
)
hmm = PeriodicHMM(init, trans_per, dists_per);
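As a quick sanity check of the interface (a snippet we added, not part of the original tutorial), the parameters wrap around after one period:

transition_matrix(hmm, 4) === transition_matrix(hmm, 1)  # phase of t = 4 is 1, so this is true
obs_distributions(hmm, 5) === obs_distributions(hmm, 2)  # phase of t = 5 is 2, so this is true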
Since the behavior of the model depends on control variables, we need to pass these to the simulation routine (instead of just the number of time steps T).
control_seq = 1:10
state_seq, obs_seq = rand(rng, hmm, control_seq);
The observations cycle through increasing levels within each period, which is coherent with the observation means: they lie between 1 and 3 at times t ≡ 1 (mod 3), between 3 and 5 at times t ≡ 2 (mod 3), and between 5 and 7 at times t ≡ 0 (mod 3).
obs_seq'
1×10 adjoint(::Vector{Float64}) with eltype Float64:
3.15929 4.60004 5.30241 2.07763 … 3.50298 2.51008 5.1677 0.348071
We now generate several sequences of variable lengths, for inference and learning tasks.
control_seqs = [1:rand(rng, 100:200) for k in 1:1000]
obs_seqs = [rand(rng, hmm, control_seqs[k]).obs_seq for k in eachindex(control_seqs)];
obs_seq = reduce(vcat, obs_seqs)
control_seq = reduce(vcat, control_seqs)
seq_ends = cumsum(length.(obs_seqs));
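Before moving on, here is a quick empirical check of our own (not in the original tutorial): grouping observations by phase, each group's mean should fall between the smallest and largest observation mean of that phase, i.e. roughly within [1, 3], [3, 5] and [5, 7].

phase = @. (control_seq - 1) % 3 + 1  # phase of each time step
[mean(obs_seq[phase .== l]) for l in 1:3]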
Inference
All three inference algorithms work in the same way, except that we need to provide the control sequence as the last positional argument.
best_state_seq, _ = viterbi(hmm, obs_seq, control_seq; seq_ends)
([1, 3, 3, 3, 2, 2, 3, 2, 2, 3 … 1, 1, 1, 3, 1, 2, 1, 1, 2, 1], [-338.1598788374384, -264.02201367838444, -363.3751457229987, -327.56205924276986, -241.77781132672078, -176.98431486613364, -307.55698780914526, -303.6807200516033, -215.49999838420771, -273.6631616955435 … -289.02560368764665, -248.35790041883394, -180.25383246901902, -201.31569791813394, -280.4615462353202, -227.57888520433457, -247.0328516674336, -345.2831898774707, -196.76362253507708, -284.2881473064425])
For Viterbi, unsurprisingly, the most likely state sequence is coherent with the observations, once the phase-dependent observation means are taken into account.
vcat(obs_seq', best_state_seq')
2×149134 Matrix{Float64}:
0.434712 4.89023 6.0682 4.09868 … 3.72124 3.20197 2.96069 3.03086
1.0 3.0 3.0 3.0 1.0 1.0 2.0 1.0
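For reference, the remaining inference routines take the control sequence in the same position; this is a sketch based on the package's generic signatures, and we do not reproduce the outputs here.

filtered_probs, logLs = forward(hmm, obs_seq, control_seq; seq_ends)  # filtered state marginals
smoothed_probs, logLs = forward_backward(hmm, obs_seq, control_seq; seq_ends)  # smoothed state marginals
logdensityof(hmm, obs_seq, control_seq; seq_ends)  # total loglikelihood of all sequences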
Learning
When estimating parameters for a custom subtype of AbstractHMM, we have to override the fitting procedure after forward-backward, with an additional control_seq positional argument. The key is to split the observations according to which periodic parameter they belong to. Note that baum_welch will call this fit! method automatically after each forward-backward pass.
function StatsAPI.fit!(
    hmm::PeriodicHMM{T},
    fb_storage::HMMs.ForwardBackwardStorage,
    obs_seq::AbstractVector,
    control_seq::AbstractVector;
    seq_ends,
) where {T}
    (; γ, ξ) = fb_storage
    L, N = period(hmm), length(hmm)
    # Accumulate sufficient statistics, starting from zero.
    hmm.init .= zero(T)
    for l in 1:L
        hmm.trans_per[l] .= zero(T)
    end
    for k in eachindex(seq_ends)
        t1, t2 = HMMs.seq_limits(seq_ends, k)
        hmm.init .+= γ[:, t1]
        for l in 1:L
            # ξ[t] counts toward trans_per[l] when time t + 1 has phase l.
            first_time_trans_l = if l > 1
                t1 + l - 2
            else
                t1 + l - 2 + L
            end
            hmm.trans_per[l] .+= sum(ξ[first_time_trans_l:L:t2])
        end
    end
    # Normalize the accumulated statistics into probability distributions.
    hmm.init ./= sum(hmm.init)
    for l in 1:L, row in eachrow(hmm.trans_per[l])
        row ./= sum(row)
    end
    # Re-estimate each phase's observation distributions from the
    # observations located at times with phase l.
    for l in 1:L
        times_l = Int[]
        for k in eachindex(seq_ends)
            t1, t2 = HMMs.seq_limits(seq_ends, k)
            append!(times_l, (t1 + l - 1):L:t2)
        end
        for i in 1:N
            HMMs.fit_in_sequence!(hmm.dists_per[l], i, obs_seq[times_l], γ[i, times_l])
        end
    end
    # Safety check on the updated parameters.
    for l in 1:L
        @assert HMMs.valid_hmm(hmm, l)
    end
    return nothing
end
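To make the time-splitting logic concrete, here is a standalone toy example we added: for a single sequence of length 10 starting at t1 = 1, the times handled by each phase l are

[collect(l:3:10) for l in 1:3]  # times with phase l = 1, 2, 3

which yields [[1, 4, 7, 10], [2, 5, 8], [3, 6, 9]].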
Now let's test our procedure with a reasonable guess.
init_guess = [0.4, 0.2, 0.3]
trans_per_guess = ntuple(
    _ -> [
        0.4 0.3 0.3
        0.3 0.4 0.3
        0.3 0.3 0.4
    ],
    Val(3),
)
dists_per_guess = (
    [Normal(1.5), Normal(2.2), Normal(2.5)],
    [Normal(3.5), Normal(4.2), Normal(4.5)],
    [Normal(5.5), Normal(6.2), Normal(6.5)],
)
hmm_guess = PeriodicHMM(init_guess, trans_per_guess, dists_per_guess);
Naturally, Baum-Welch also requires knowing control_seq.
hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends);
first(loglikelihood_evolution), last(loglikelihood_evolution)
(-255695.48690011754, -243058.7258113099)
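Since Baum-Welch is an EM algorithm, the loglikelihood should be nondecreasing across iterations (up to numerical tolerance); here is a small check we added:

all(diff(loglikelihood_evolution) .>= 0)  # expected to be true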
Did we do well?
cat(transition_matrix(hmm_est, 1), transition_matrix(hmm, 1); dims=3)
3×3×2 Array{Float64, 3}:
[:, :, 1] =
0.186413 0.773622 0.0399647
0.0299322 0.146951 0.823117
0.796874 0.0112494 0.191877
[:, :, 2] =
0.2 0.8 0.0
0.0 0.2 0.8
0.8 0.0 0.2
cat(transition_matrix(hmm_est, 2), transition_matrix(hmm, 2); dims=3)
3×3×2 Array{Float64, 3}:
[:, :, 1] =
0.177002 0.0286042 0.794393
0.860977 0.121398 0.0176248
0.0344084 0.789656 0.175936
[:, :, 2] =
0.2 0.0 0.8
0.8 0.2 0.0
0.0 0.8 0.2
cat(transition_matrix(hmm_est, 3), transition_matrix(hmm, 3); dims=3)
3×3×2 Array{Float64, 3}:
[:, :, 1] =
0.792613 0.109302 0.0980847
0.0997413 0.799007 0.101252
0.114763 0.0867425 0.798494
[:, :, 2] =
0.8 0.1 0.1
0.1 0.8 0.1
0.1 0.1 0.8
map(mean, hcat(obs_distributions(hmm_est, 1), obs_distributions(hmm, 1)))
3×2 Matrix{Float64}:
1.01396 1.0
1.94455 2.0
3.0178 3.0
map(mean, hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2)))
3×2 Matrix{Float64}:
2.99407 3.0
4.01272 4.0
4.98621 5.0
map(mean, hcat(obs_distributions(hmm_est, 3), obs_distributions(hmm, 3)))
3×2 Matrix{Float64}:
4.97559 5.0
6.0085 6.0
7.02123 7.0
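As a last check of our own (not part of the original page), the estimated model should fit this particular dataset at least as well as the true model that generated it, since Baum-Welch maximizes the likelihood:

logdensityof(hmm_est, obs_seq, control_seq; seq_ends) ≥ logdensityof(hmm, obs_seq, control_seq; seq_ends)  # typically true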
This page was generated using Literate.jl.