Fast and generic Hidden Markov Models

Guillaume Dalle

École Polytechnique Fédérale de Lausanne

2024-07-12

Introduction

About me

What is a Hidden Markov Model?

  • Statistical model for a temporal process
  • Hidden state \(X_t\) evolves through time
  • Observations \(Y_t\) depend on the state
  • Basically a noisy Markov chain

The first Markov chain (Von Hilgers and Langville 2006)

Applications

  • Speech processing
  • Bioinformatics
  • Predictive maintenance

The math

Some notation

  • \(N\) hidden states, \(M\) possible observation values
  • Initial probabilities \(p\)
  • Transition matrix \(A\)
  • Emission matrix \(B\)

\[\mathbb{P}(X_0 = i) = p_i\]

\[\mathbb{P}(X_t = j \mid X_{t-1} = i) = A_{i,j}\]

\[\mathbb{P}(Y_t = k \mid X_t = i) = B_{i,k}\]

Example: my JuliaCon attendance

stateDiagram-v2
  direction LR
  state "JuliaCon\nonline" as online
  state "JuliaCon\nin Europe" as europe
  state "JuliaCon\noverseas" as overseas
  online --> overseas: 1
  overseas --> online: 1
  overseas --> europe: 2
  europe --> overseas: 1
  overseas --> overseas: 3
  online --> online: 2

\[p = \begin{pmatrix} 0 & 0 & 1 \end{pmatrix}\]

\[A = \begin{pmatrix} 2/3 & 1/3 & 0 \\ 1/6 & 3/6 & 2/6 \\ 0 & 1 & 0 \end{pmatrix}\]

stateDiagram-v2
  state "JuliaCon\nonline" as online
  state "JuliaCon\nin Europe" as europe
  state "JuliaCon\noverseas" as overseas
  state "I'm attending" as attending
  state "I'm missing" as missing
  online --> attending: 3
  overseas --> missing: 6
  europe --> attending: 1
  europe --> missing: 1

\[B = \begin{pmatrix} 1 & 0 \\ 1/2 & 1/2 \\ 0 & 1 \end{pmatrix}\]

Main algorithms

Given an HMM with parameters \(\theta\), one may want to compute:

  • The observation likelihood \(\mathbb{P}_\theta(Y_{1:T})\) \(\implies\) Forward
  • The most likely state sequence \(\underset{X_{1:T}}{\mathrm{argmax}}\ \mathbb{P}_\theta(X_{1:T} | Y_{1:T})\) \(\implies\) Viterbi
  • The best parameters \(\underset{\theta}{\mathrm{argmax}}\ \mathbb{P}_\theta(Y_{1:T})\) \(\implies\) Baum-Welch

See the tutorial by Rabiner (1989) for details.
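
As a refresher, the forward algorithm boils down to a matrix-vector recursion. Below is a minimal sketch of my own for discrete observations, using the notation above (unscaled, so it underflows on long sequences; a serious implementation works with scaled or logarithmic quantities):

# unscaled forward recursion: returns P(Y₁, …, Y_T) for a discrete HMM
function forward_likelihood(p, A, B, obs_seq)
    α = p .* B[:, obs_seq[1]]              # α₁(i) = pᵢ B[i, y₁]
    for t in 2:length(obs_seq)
        α = (A' * α) .* B[:, obs_seq[t]]   # αₜ(j) = (Σᵢ αₜ₋₁(i) A[i,j]) B[j, yₜ]
    end
    return sum(α)                          # marginalize over the final state
end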

The code

Basics

Model creation:

using HiddenMarkovModels, Distributions, Random
rng = Random.default_rng()

p = [0.0, 0.0, 1.0]                        # initial probabilities
A = [2/3 1/3 0; 1/6 3/6 2/6; 0 1 0]        # transition matrix
B = [1 0; 1/2 1/2; 0 1]                    # emission matrix
dists = Categorical.(Vector.(eachrow(B)))  # one Categorical distribution per state
model = HMM(p, A, dists);

Simulation:

rand(rng, model, 5)
(state_seq = [3, 2, 2, 2, 3], obs_seq = [2, 2, 1, 1, 2])

Inference

JuliaCon attendance data:

state_seq = [2, 2, 2, 2, 3, 2, 1, 1, 1, 2, 3]
obs_seq = [2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1];

Likelihood:

logdensityof(model, obs_seq)
-6.954935049252169

State sequence:

best_state_seq, _ = viterbi(model, obs_seq)
([3, 2, 3, 2, 3, 2, 1, 1, 1, 2, 2], [-10.057409634808385])

Learning

long_obs_seq = rand(rng, model, 1000).obs_seq
p0, A0 = rand_prob_vec(3), rand_trans_mat(3)  # random initial guess
model_init = HMM(p0, A0, dists)

model_est, logL = baum_welch(
  model_init, long_obs_seq);
transition_matrix(model_est)
3×3 Matrix{Float64}:
 0.660929  0.339071  0.0
 0.203221  0.352171  0.444608
 0.0       1.0       0.0
probs.(obs_distributions(model_est))
3-element Vector{Vector{Float64}}:
 [1.0, 0.0]
 [0.5311920571169335, 0.4688079428830663]
 [0.0, 1.0]
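
As a sanity check (my own addition, not in the talk), the loglikelihood evolution returned by baum_welch should be non-decreasing, since each EM iteration improves the objective:

# EM guarantees monotone improvement (up to floating-point noise)
@assert issorted(logL)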

Why a new package?

Need for a generic, fast and reliable library.

More generality

Number types

  • I need Float64 everywhere or I get really stressed out
  • I can handle any real number they throw at me
import ForwardDiff

f(μ) = logdensityof(
  HMM(
    [0.9, 0.1], [0.7 0.3; 0.2 0.8],
    Normal.(μ, ones(2))
  ),
  [-1.0, -1.3, -0.6, 1.4, 1.2, 0.9]
)

ForwardDiff.gradient(f, [-2.0, 2.0])
2-element Vector{Float64}:
  2.9993145840880095
 -2.7406261032012837

Benefits

Variable precision, automatic differentiation, logarithmic storage.

Challenges

Parametric types everywhere, handling promotion correctly.
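
For instance, here is a sketch (my own, assuming the internals promote number types as advertised) that reuses the earlier model with BigFloat parameters:

# same model, higher precision: the loglikelihood comes back as a BigFloat
model_big = HMM(big.(p), big.(A), dists)
logdensityof(model_big, obs_seq)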

Matrix types

  • The only good matrix is a dense Matrix
  • Who’s afraid of the big bad SparseMatrixCSC?
using SparseArrays

model_sp = HMM(p, sparse(A), dists)
model_sp_est, _ = baum_welch(
  model_sp, long_obs_seq
)
transition_matrix(model_sp_est)
3×3 SparseMatrixCSC{Float64, Int64} with 6 stored entries:
 0.660929  0.339071   ⋅ 
 0.203221  0.352171  0.444608
  ⋅        1.0        ⋅ 

Benefits

Large state spaces, realistic transition structures.

Challenges

Generic updates during parameter estimation.
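
To give a feel for the scale this unlocks, here is a sketch (mine, not from the talk) of a nearest-neighbour transition structure over 1000 states:

using SparseArrays

N = 1000
# tridiagonal transitions: stay with probability 0.8, hop to a neighbor with 0.1
A_band = spdiagm(-1 => fill(0.1, N - 1), 0 => fill(0.8, N), 1 => fill(0.1, N - 1))
A_band[1, 1] += 0.1; A_band[N, N] += 0.1  # boundary rows must still sum to 1
model_band = HMM(ones(N) / N, A_band, Normal.(1:N, 1.0))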

Observation types

  • Distributions.jl is my whole life
  • Give me a sampler + loglikelihood and let’s go
struct MyDist
  length::Int
end
# sample a random alphanumeric string of the given length
Base.rand(
  rng::AbstractRNG, dist::MyDist
) = randstring(rng, dist.length)
model_str = HMM(p, A, MyDist.(1:3))
rand(rng, model_str, 5)
(state_seq = [3, 2, 3, 2, 3], obs_seq = ["wJi", "s1", "lDn", "Kk", "09G"])

Benefits

Arbitrary observations (strings, point processes, images).

Challenges

What is the correct interface for these distributions?
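
One plausible answer, and the one the conclusion comes back to, is DensityInterface.jl. Here is a hypothetical completion of MyDist with the loglikelihood half; the uniform density over fixed-length alphanumeric strings is my own toy choice:

import DensityInterface

DensityInterface.DensityKind(::MyDist) = DensityInterface.HasDensity()

# randstring draws uniformly from 62 alphanumeric characters,
# so a length-ℓ string has probability 62^(-ℓ)
DensityInterface.logdensityof(dist::MyDist, obs) =
    length(obs) == dist.length ? -dist.length * log(62) : -Inf

With both a sampler and a loglikelihood, inference routines such as logdensityof and viterbi should work on string observations out of the box.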

Automatic differentiation

  • There are mutations, Zygote.jl will never survive
  • My middle initial is C for ChainRules.jl
import Zygote

f(μ) = logdensityof(
  HMM(
    [0.9, 0.1], [0.7 0.3; 0.2 0.8],
    Normal.(μ, ones(2))
  ),
  [-1.0, -1.3, -0.6, 1.4, 1.2, 0.9]
)

Zygote.gradient(f, [-2.0, 2.0])
([2.999314584088009, -2.740626103201284],)

Benefits

Efficient gradient computations.

Challenges

Writing down the math.

Controls

  • My HMM never changes, it’s my rock
  • Bring in the exogenous variables!
import HiddenMarkovModels as HMMs

struct Dirac; val; end  # shadows Distributions.Dirac, fine for this demo
Base.rand(rng::AbstractRNG, dist::Dirac) = dist.val  # always emits its value

struct DiracHMM <: AbstractHMM; N::Int; end

HMMs.initialization(model::DiracHMM) = ones(model.N) / model.N
HMMs.transition_matrix(model::DiracHMM, control) = ones(model.N, model.N) / model.N
HMMs.obs_distributions(model::DiracHMM, control) = fill(Dirac(control), model.N)  # one per state
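
A quick sanity check (my own snippet, assuming rand accepts a control sequence whose length sets the time horizon):

control_seq = 1:5
rand(rng, DiracHMM(2), control_seq).obs_seq  # should be [1, 2, 3, 4, 5]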

Benefits

Temporal heterogeneity can be modeled.

Challenges

More complicated estimation methods.

More speed

Type stability

  • Essential for Just-In-Time compilation to work well
  • Tested with JET.jl for all major subroutines

No allocations

  • Crucial for performance in hot loops
  • Tested with @allocated for all major subroutines
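
The flavour of these checks, roughly (illustrative snippets, not the package's actual test suite):

using JET

# report type instabilities reachable from this call
JET.@report_opt logdensityof(model, obs_seq)

# allocation count of a warm call (the first call pays for compilation)
logdensityof(model, obs_seq)
@allocated logdensityof(model, obs_seq)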

Multithreading

  • Inference and estimation on multiple sequences are embarrassingly parallel
  • Implemented with Threads.@threads, potential for multithreaded reductions too
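
For instance, loglikelihoods of independent sequences parallelize with nothing fancier than this (a sketch using only functions shown earlier):

obs_seqs = [rand(rng, model, 100).obs_seq for _ in 1:8]
logLs = zeros(length(obs_seqs))
Threads.@threads for k in eachindex(obs_seqs)
    logLs[k] = logdensityof(model, obs_seqs[k])  # each sequence on its own thread
end
sum(logLs)  # total loglikelihood of the batch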

More reliability

Good practices in package development

Some encouraging benchmarks

Comparison against Python and Julia competitors

Conclusion

The role of interfaces

  • AbstractArray for transition matrices
  • DensityInterface.jl for observation distributions
  • AbstractHMM with a handful of functions to handle simulation, inference and learning

Formal specification

For interfaces with precise requirements, “the doc is the API” cannot suffice.

Publishing software

  • Paper in the Journal of Open Source Software (Dalle 2024)
  • Pleasant and productive open review process
  • Packages are valuable research contributions
  • They need to be cited and recognized

References

Dalle, Guillaume. 2024. "HiddenMarkovModels.jl: Generic, Fast and Reliable State Space Modeling." Journal of Open Source Software 9 (96): 6436. https://doi.org/10.21105/joss.06436.
Rabiner, L. R. 1989. “A Tutorial on Hidden Markov Models and Selected Applications in Speech Recognition.” Proceedings of the IEEE 77 (2): 257–86. https://doi.org/cswph2.
Von Hilgers, Philipp, and Amy N. Langville. 2006. “The Five Greatest Applications of Markov Chains.” In Proceedings of the Markov Anniversary Meeting.