Time dependency

Here, we demonstrate what to do when transition and observation laws depend on the current time. This time-dependent HMM is implemented as a particular case of a controlled HMM.

using Distributions
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using Random
using StableRNGs
using StatsAPI
rng = StableRNG(63);


We focus on the particular case of a periodic HMM with period L. It has only one initialization vector, but L transition matrices and L vectors of observation distributions. As in Custom HMM structures, we need to subtype AbstractHMM.

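struct PeriodicHMM{T<:Number,D,L} <: AbstractHMM
    init::Vector{T}                 # initial state distribution
    trans_per::NTuple{L,Matrix{T}}  # one transition matrix per phase
    dists_per::NTuple{L,Vector{D}}  # one vector of observation distributions per phase
end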

The interface definition is almost the same as in the homogeneous case, but we give the control variable (here the time) as an additional argument to transition_matrix and obs_distributions.

period(::PeriodicHMM{T,D,L}) where {T,D,L} = L

function HMMs.initialization(hmm::PeriodicHMM)
    return hmm.init
end

function HMMs.transition_matrix(hmm::PeriodicHMM, t::Integer)
    l = (t - 1) % period(hmm) + 1  # phase of time t, in 1:L
    return hmm.trans_per[l]
end

function HMMs.obs_distributions(hmm::PeriodicHMM, t::Integer)
    l = (t - 1) % period(hmm) + 1  # phase of time t, in 1:L
    return hmm.dists_per[l]
end
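The phase computation (t - 1) % period(hmm) + 1 maps 1-based times onto 1-based parameter indices. A minimal standalone check in plain Julia; the helper name phase is hypothetical and not part of HiddenMarkovModels.jl:

```julia
# Hypothetical helper mirroring the index computation above:
# time t (1-based) -> index of the periodic parameter set (1-based)
phase(t::Integer, L::Integer) = (t - 1) % L + 1

# With period 2, odd times use the first parameter set and even times the second
phases_L2 = [phase(t, 2) for t in 1:6]   # [1, 2, 1, 2, 1, 2]

# With period 3, the pattern cycles through 1, 2, 3
phases_L3 = [phase(t, 3) for t in 1:7]   # [1, 2, 3, 1, 2, 3, 1]
```

Subtracting and re-adding 1 is needed because Julia arrays are 1-based: t % L alone would map t = L to the invalid index 0.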


init = [0.6, 0.4]
trans_per = ([0.7 0.3; 0.2 0.8], [0.3 0.7; 0.8 0.2])
dists_per = ([Normal(-1.0), Normal(-2.0)], [Normal(+1.0), Normal(+2.0)])
hmm = PeriodicHMM(init, trans_per, dists_per);

Since the behavior of the model depends on control variables, we need to pass these to the simulation routine (instead of just the number of time steps T).

control_seq = 1:10
state_seq, obs_seq = rand(rng, hmm, control_seq);

The observations mostly alternate between positive and negative values, which is coherent with negative observation means at odd times and positive observation means at even times.

obs_seq'
1×10 adjoint(::Vector{Float64}) with eltype Float64:
 -0.840714  3.60004  -0.697595  1.07763  …  1.51008  -0.832302  0.348071
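To make the role of the control sequence concrete, here is a minimal plain-Julia sketch of how a periodic HMM with unit-variance Gaussian emissions can be sampled. It is not the library's rand implementation; the names phase, sample_index and simulate_periodic are hypothetical, and the sketch assumes that the move from time t to t+1 uses the transition matrix of phase(t), which is a convention chosen here, not one confirmed for the library.

```julia
using Random

# Hypothetical helper: index of the periodic parameter set used at time t
phase(t::Integer, L::Integer) = (t - 1) % L + 1

# Draw an index i with probability p[i] by inverting the cumulative sum
function sample_index(rng::AbstractRNG, p::AbstractVector{<:Real})
    u = rand(rng)
    c = 0.0
    for i in eachindex(p)
        c += p[i]
        u <= c && return i
    end
    return lastindex(p)  # guard against rounding error in sum(p)
end

# Sketch: sample states and observations of a periodic HMM with N(μ, 1) emissions.
# Assumed convention: the move from t to t+1 uses trans_per[phase(t, L)],
# and the emission at time t uses means_per[phase(t, L)].
function simulate_periodic(rng::AbstractRNG, init, trans_per, means_per, T::Integer)
    L = length(trans_per)
    states = Vector{Int}(undef, T)
    obs = Vector{Float64}(undef, T)
    for t in 1:T
        if t == 1
            states[t] = sample_index(rng, init)
        else
            P = trans_per[phase(t - 1, L)]
            states[t] = sample_index(rng, view(P, states[t - 1], :))
        end
        obs[t] = means_per[phase(t, L)][states[t]] + randn(rng)
    end
    return states, obs
end

rng_demo = MersenneTwister(0)
init_demo = [0.6, 0.4]
trans_demo = ([0.7 0.3; 0.2 0.8], [0.3 0.7; 0.8 0.2])
means_demo = ([-1.0, -2.0], [+1.0, +2.0])  # same means as dists_per above
states_demo, obs_demo = simulate_periodic(rng_demo, init_demo, trans_demo, means_demo, 10_000)
```

On a long sample like this one, observations at odd times sum to a negative value and those at even times to a positive value, the same sign pattern as in obs_seq above.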

We now generate several sequences of variable lengths, for inference and learning tasks.

control_seqs = [1:rand(rng, 100:200) for k in 1:1000]
obs_seqs = [rand(rng, hmm, control_seqs[k]).obs_seq for k in eachindex(control_seqs)];

obs_seq = reduce(vcat, obs_seqs)
control_seq = reduce(vcat, control_seqs)
seq_ends = cumsum(length.(obs_seqs));
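Concatenating the sequences loses their boundaries, which seq_ends keeps track of: entry k is the index of the last observation of sequence k. A plain-Julia sketch of recovering the first and last index (t1, t2) of each sequence; limits_from_ends is a hypothetical stand-in illustrating what HMMs.seq_limits(seq_ends, k) is used for in the fit! method of this tutorial, and the library's actual implementation may differ.

```julia
# Lengths of three toy sequences, concatenated end to end
lengths_demo = [3, 2, 4]
seq_ends_demo = cumsum(lengths_demo)   # [3, 5, 9]

# Hypothetical helper: first and last index of sequence k in the concatenation
function limits_from_ends(seq_ends::AbstractVector{<:Integer}, k::Integer)
    t1 = k == 1 ? 1 : seq_ends[k - 1] + 1
    t2 = seq_ends[k]
    return t1, t2
end

limits_demo = [limits_from_ends(seq_ends_demo, k) for k in eachindex(seq_ends_demo)]
# [(1, 3), (4, 5), (6, 9)]
```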


All three inference algorithms work in the same way, except that we need to provide the control sequence as the last positional argument.

best_state_seq, _ = viterbi(hmm, obs_seq, control_seq; seq_ends)
([2, 2, 1, 2, 1, 1, 1, 1, 2, 2  …  1, 2, 2, 1, 1, 1, 1, 2, 2, 1], [-327.4553030088472, -256.729891367137, -371.58047267702534, -329.32608569081737, -236.5895986924772, -174.80325979593096, -300.38318599112836, -294.86014440799255, -203.56580911377492, -268.1256974881177  …  -276.45400897666735, -258.9507164381684, -171.67438564667265, -200.5544421470556, -283.08920234552613, -231.8716989192954, -246.73175667426827, -341.54809613803695, -193.02631559761073, -279.25559847293437])

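For Viterbi, unsurprisingly, the most likely state sequence aligns with the magnitude of the observations: observations far from zero are assigned to state 2 (means -2 and +2), those close to zero to state 1 (means -1 and +1). The sign only reflects the parity of the time step.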

vcat(obs_seq', best_state_seq')
2×149134 Matrix{Float64}:
 -3.56529  2.89023  -0.931799  2.09868  …  -2.79803  2.96069  -0.969142
  2.0      2.0       1.0       2.0          2.0      2.0       1.0
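The inference recursions only change in that the transition matrix and emission parameters are looked up from the control at each step. As an illustration, here is a minimal log-space Viterbi sketch for one sequence of a periodic HMM with N(μ, 1) emissions, in plain Julia. It is not the library's implementation; the function names are hypothetical, and it assumes that the move from t to t+1 uses the transition matrix of phase(t).

```julia
# Hypothetical helper: index of the periodic parameter set used at time t
phase(t::Integer, L::Integer) = (t - 1) % L + 1

# Log-density of the unit-variance normal N(μ, 1) at y
normal_logpdf(y, μ) = -0.5 * log(2π) - 0.5 * (y - μ)^2

# Viterbi sketch for one sequence of a periodic HMM (times 1:T).
# Assumed convention: the move from t to t+1 uses trans_per[phase(t, L)].
function viterbi_periodic(init, trans_per, means_per, obs)
    T, N, L = length(obs), length(init), length(trans_per)
    δ = zeros(N, T)        # δ[j, t]: best log-score of a path ending in state j at time t
    ψ = zeros(Int, N, T)   # ψ[j, t]: best predecessor of state j at time t
    for j in 1:N
        δ[j, 1] = log(init[j]) + normal_logpdf(obs[1], means_per[phase(1, L)][j])
    end
    for t in 2:T
        logA = log.(trans_per[phase(t - 1, L)])
        for j in 1:N
            best, i = findmax(δ[:, t - 1] .+ logA[:, j])
            δ[j, t] = best + normal_logpdf(obs[t], means_per[phase(t, L)][j])
            ψ[j, t] = i
        end
    end
    # Backtrack from the best final state
    path = zeros(Int, T)
    best_logscore, last_state = findmax(δ[:, T])
    path[T] = last_state
    for t in (T - 1):-1:1
        path[t] = ψ[path[t + 1], t + 1]
    end
    return path, best_logscore
end

init_v = [0.6, 0.4]
trans_v = ([0.7 0.3; 0.2 0.8], [0.3 0.7; 0.8 0.2])
means_v = ([-1.0, -2.0], [+1.0, +2.0])
obs_v = [-0.8, 1.9, -2.2, 0.9, -1.1]
path_v, score_v = viterbi_periodic(init_v, trans_v, means_v, obs_v)
```

On such a short sequence the result can be checked against brute-force enumeration of all 2^5 state paths.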


When estimating parameters for a custom subtype of AbstractHMM, we have to override the fitting step that follows forward-backward; here it takes an additional control_seq positional argument. The key is to split the observations according to which periodic parameter they belong to.

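function StatsAPI.fit!(
    hmm::PeriodicHMM{T},
    fb_storage::HMMs.ForwardBackwardStorage,
    obs_seq::AbstractVector,
    control_seq::AbstractVector;
    seq_ends,
) where {T}
    (; γ, ξ) = fb_storage
    L, N = period(hmm), length(hmm)

    # Initialization and periodic transitions: accumulate posterior counts
    hmm.init .= zero(T)
    for l in 1:L
        hmm.trans_per[l] .= zero(T)
    end
    for k in eachindex(seq_ends)
        t1, t2 = HMMs.seq_limits(seq_ends, k)
        hmm.init .+= γ[:, t1]
        for l in 1:L
            # phase-l times of sequence k (each sequence starts at time 1 here)
            hmm.trans_per[l] .+= sum(ξ[(t1 + l - 1):L:t2])
        end
    end
    # ...then normalize into a probability vector and stochastic matrices
    hmm.init ./= sum(hmm.init)
    for l in 1:L, row in eachrow(hmm.trans_per[l])
        row ./= sum(row)
    end

    # Observation distributions: refit each state on the phase-l time steps only
    for l in 1:L
        times_l = Int[]
        for k in eachindex(seq_ends)
            t1, t2 = HMMs.seq_limits(seq_ends, k)
            append!(times_l, (t1 + l - 1):L:t2)
        end
        for i in 1:N
            HMMs.fit_in_sequence!(hmm.dists_per[l], i, obs_seq[times_l], γ[i, times_l])
        end
    end

    # Sanity check of the updated parameters for every phase
    for l in 1:L
        @assert HMMs.valid_hmm(hmm, l)
    end
    return nothing
end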
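The phase split can be checked in isolation. The following is a minimal plain-Julia sketch, assuming a single sequence that starts at time 1, state posteriors γ stored as an N×T matrix, transition posteriors ξ stored as a vector of N×N matrices (ξ[t] for the move from t to t+1), and unit-variance Gaussian emissions whose only parameter is the mean. The function name periodic_mstep is hypothetical, and unlike the fit! method above it re-estimates only the means, not full distributions.

```julia
# Sketch of a phase-wise M-step for one sequence starting at time 1.
# γ: N×T state posteriors; ξ: vector of N×N transition posteriors
# (ξ[t] for the move t -> t+1); obs: observations; L: period.
function periodic_mstep(γ::AbstractMatrix, ξ::AbstractVector, obs::AbstractVector, L::Integer)
    N, T = size(γ)
    trans_est = [zeros(N, N) for _ in 1:L]
    means_est = [zeros(N) for _ in 1:L]
    for l in 1:L
        times_l = l:L:T  # time steps belonging to phase l
        # Transitions: sum ξ over phase-l times (no move out of the last time), normalize rows
        for t in times_l
            t < T && (trans_est[l] .+= ξ[t])
        end
        trans_est[l] ./= sum(trans_est[l]; dims=2)
        # Emission means: γ-weighted average of the phase-l observations
        for i in 1:N
            w = γ[i, times_l]
            means_est[l][i] = sum(w .* obs[times_l]) / sum(w)
        end
    end
    return trans_est, means_est
end
```

With hard (0/1) posteriors, the result reduces to counting transitions and averaging observations separately at odd and even times, which makes the sketch easy to verify by hand.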

Now let's test our procedure with a reasonable guess.

init_guess = [0.7, 0.3]
trans_per_guess = ([0.6 0.4; 0.3 0.7], [0.4 0.6; 0.7 0.3])
dists_per_guess = ([Normal(-1.1), Normal(-2.1)], [Normal(+1.1), Normal(+2.1)])
hmm_guess = PeriodicHMM(init_guess, trans_per_guess, dists_per_guess);

Naturally, Baum-Welch also requires knowing control_seq.

hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends);
first(loglikelihood_evolution), last(loglikelihood_evolution)
(-227885.2961549863, -226890.71955893945)
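Each entry of loglikelihood_evolution is a log-likelihood of the observations given their control sequences, which is expected to increase over Baum-Welch iterations. For a single sequence, the kind of time-dependent forward recursion that produces such a value can be sketched in plain Julia as below. This is not the library's code: the names are hypothetical, emissions are assumed to be N(μ, 1), and the move from t to t+1 is assumed to use the transition matrix of phase(t).

```julia
# Hypothetical helper: index of the periodic parameter set used at time t
phase(t::Integer, L::Integer) = (t - 1) % L + 1

# Density of the unit-variance normal N(μ, 1) at y
normal_pdf(y, μ) = exp(-0.5 * (y - μ)^2) / sqrt(2π)

# Scaled forward recursion for one sequence of a periodic HMM.
# Assumed convention: the move from t to t+1 uses trans_per[phase(t, L)].
function loglikelihood_periodic(init, trans_per, means_per, obs)
    T, L = length(obs), length(trans_per)
    α = init .* normal_pdf.(obs[1], means_per[phase(1, L)])
    c = sum(α)
    α = α ./ c       # rescale so that α sums to 1 (avoids underflow)
    logL = log(c)    # log-likelihood = sum of the log scaling factors
    for t in 2:T
        P = trans_per[phase(t - 1, L)]
        α = (P' * α) .* normal_pdf.(obs[t], means_per[phase(t, L)])
        c = sum(α)
        α = α ./ c
        logL += log(c)
    end
    return logL
end

init_f = [0.6, 0.4]
trans_f = ([0.7 0.3; 0.2 0.8], [0.3 0.7; 0.8 0.2])
means_f = ([-1.0, -2.0], [+1.0, +2.0])
obs_f = [-0.8, 1.9, -2.2, 0.9, -1.1]
ll_f = loglikelihood_periodic(init_f, trans_f, means_f, obs_f)
```

For several sequences, the per-sequence values would simply be added, since the sequences are treated as independent.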

Did we do well?

cat(transition_matrix(hmm_est, 1), transition_matrix(hmm, 1); dims=3)
2×2×2 Array{Float64, 3}:
[:, :, 1] =
 0.693737  0.306263
 0.220124  0.779876

[:, :, 2] =
 0.7  0.3
 0.2  0.8
cat(transition_matrix(hmm_est, 2), transition_matrix(hmm, 2); dims=3)
2×2×2 Array{Float64, 3}:
[:, :, 1] =
 0.327157  0.672843
 0.797991  0.202009

[:, :, 2] =
 0.3  0.7
 0.8  0.2
map(mean, hcat(obs_distributions(hmm_est, 1), obs_distributions(hmm, 1)))
2×2 Matrix{Float64}:
 -1.00183  -1.0
 -2.01198  -2.0
map(mean, hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2)))
2×2 Matrix{Float64}:
 0.972284  1.0
 2.01909   2.0
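To summarize the comparison with a single number per parameter type, the largest absolute deviations can be computed from the values printed above. They are copied by hand, so they carry the display rounding.

```julia
# Estimated vs. true parameters, copied from the outputs above
trans1_est = [0.693737 0.306263; 0.220124 0.779876]
trans1_true = [0.7 0.3; 0.2 0.8]
trans2_est = [0.327157 0.672843; 0.797991 0.202009]
trans2_true = [0.3 0.7; 0.8 0.2]
means1_est, means1_true = [-1.00183, -2.01198], [-1.0, -2.0]
means2_est, means2_true = [0.972284, 2.01909], [1.0, 2.0]

# Largest absolute error over both phases
max_err_trans = max(maximum(abs.(trans1_est - trans1_true)), maximum(abs.(trans2_est - trans2_true)))
max_err_means = max(maximum(abs.(means1_est - means1_true)), maximum(abs.(means2_est - means2_true)))
# max_err_trans ≈ 0.027157, max_err_means ≈ 0.027716
```

Both maxima stay below 0.03 on this run.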

