# Gradient Descent in HMMs
#=
In this tutorial we explore two ways to use gradient descent when fitting HMMs:
1. Fitting parameters of an observation model that do not have closed-form updates
(e.g., GLMs, neural networks, etc.), inside the EM algorithm.
2. Fitting the entire HMM with gradient-based optimization by leveraging automatic
differentiation.
We will explore both approaches below.
=#
using ADTypes
using ComponentArrays
using DensityInterface
using ForwardDiff
using HiddenMarkovModels
using LinearAlgebra
using Optim
using Random
using StableRNGs
using StatsAPI
# Fixed seed so every run of the tutorial is reproducible.
rng = StableRNG(42)
#=
For both parts of this tutorial we use a simple HMM with Gaussian observations.
Using gradient-based optimization here is overkill, but it keeps the tutorial
simple while illustrating the relevant methods.
We begin by defining a Normal observation model.
=#
# Univariate Gaussian observation model with unconstrained parameters:
# the mean `μ` and the *log* standard deviation `logσ` (σ = exp(logσ) > 0
# automatically, so optimizers can work over all of ℝ²).
# `T` is the scalar type — Float64 normally; ForwardDiff duals under AD.
# Mutable because `fit!` updates the fields in place during EM.
mutable struct NormalModel{T}
μ::T
logσ::T # unconstrained parameterization; σ = exp(logσ)
end
# Return the mean parameter μ of the observation model.
function model_mean(mod::NormalModel)
    return mod.μ
end
# Return the standard deviation σ = exp(logσ); positive by construction.
function stddev(mod::NormalModel)
    return exp(mod.logσ)
end
#=
We have defined a simple probability model with two parameters: the mean and the
log of the standard deviation. Using `logσ` is intentional so we can optimize over
all real numbers without worrying about the positivity constraint on `σ`.
Next, we provide the minimal interface expected by HiddenMarkovModels.jl:
`(logdensityof, rand, fit!)`.
=#
# Gaussian log-density at `obs`:
#   log p(obs) = -½·log(2π) − log σ − ½·z²,  where z = (obs − μ)/σ.
function DensityInterface.logdensityof(mod::NormalModel, obs::T) where {T<:Real}
    σ = stddev(mod)
    z = (obs - model_mean(mod)) / σ
    return -log(2π) / 2 - log(σ) - z^2 / 2
end
# Declare that NormalModel carries a density, as required by the
# DensityInterface trait contract that HiddenMarkovModels.jl relies on.
DensityInterface.DensityKind(::NormalModel) = DensityInterface.HasDensity()
# Draw one sample via the location-scale transform of a standard normal:
# x = μ + σ·z with z ~ N(0, 1).
function Random.rand(rng::AbstractRNG, mod::NormalModel{T}) where {T}
    z = randn(rng, T)
    return model_mean(mod) + stddev(mod) * z
end
#=
Because we are fitting a Gaussian (and the variance can collapse to ~0), we add
weak priors to regularize the parameters. We use:
- A weak Normal prior on `μ`
- A moderate-strength Normal prior on `logσ` that pulls `σ` toward ~1
=#
# Weak Normal(0, 10) prior on μ, and a Normal(0, 0.5) prior on logσ that
# pulls σ toward 1 — regularization against variance collapse during EM.
const μ_prior = NormalModel(0.0, log(10.0))
const logσ_prior = NormalModel(log(1.0), log(0.5))
"""
    neglogpost(μ, logσ, data, weights, μ_prior, logσ_prior)

Negative log-posterior of a `NormalModel(μ, logσ)` given weighted `data`:
the weighted negative log-likelihood plus the negative log-densities of the
two priors. The `weights` are the state responsibilities supplied by EM.
"""
function neglogpost(
    μ::T,
    logσ::T,
    data::AbstractVector{<:Real},
    weights::AbstractVector{<:Real},
    μ_prior::NormalModel,
    logσ_prior::NormalModel,
) where {T<:Real}
    candidate = NormalModel(μ, logσ)
    # Weighted negative log-likelihood, accumulated left-to-right.
    nlp = mapreduce(+, eachindex(data, weights)) do i
        -weights[i] * logdensityof(candidate, data[i])
    end
    # Add the prior penalties on both unconstrained parameters.
    nlp -= logdensityof(μ_prior, μ)
    nlp -= logdensityof(logσ_prior, logσ)
    return nlp
end
# Vector-argument convenience method: θ = [μ, logσ], the packed form that
# Optim.jl works with. Delegates to the scalar-argument method.
function neglogpost(
    θ::AbstractVector{T},
    data::AbstractVector{<:Real},
    weights::AbstractVector{<:Real},
    μ_prior::NormalModel,
    logσ_prior::NormalModel,
) where {T<:Real}
    location, logscale = θ
    return neglogpost(location, logscale, data, weights, μ_prior, logσ_prior)
end
"""
    fit!(mod::NormalModel, data, weights)

M-step update called by Baum–Welch: compute the MAP estimate of `(μ, logσ)`
under the weighted data by minimizing `neglogpost` with BFGS and
forward-mode AD gradients, then write the optimum back into `mod`.
"""
function StatsAPI.fit!(
    mod::NormalModel, data::AbstractVector{<:Real}, weights::AbstractVector{<:Real}
)
    # Pack the current parameters as the starting point, promoting to a
    # common scalar type.
    P = promote_type(typeof(mod.μ), typeof(mod.logσ))
    initial = P[mod.μ, mod.logσ]
    loss = θ -> neglogpost(θ, data, weights, μ_prior, logσ_prior)
    solution = Optim.optimize(loss, initial, BFGS(); autodiff=AutoForwardDiff())
    θ̂ = Optim.minimizer(solution)
    mod.μ = θ̂[1]
    mod.logσ = θ̂[2]
    return mod
end
#=
Now that we have fully defined our observation model, we can create an HMM using it.
=#
# Ground-truth 3-state HMM used to generate data.
init_dist = [0.2, 0.7, 0.1]
# Sticky transition matrix (each row sums to 1; strong self-transitions).
init_trans = [
0.9 0.05 0.05;
0.075 0.9 0.025;
0.1 0.1 0.8
]
# Well-separated Gaussian emissions: means (-3, 0, 3), σ = (0.25, 0.5, 0.75).
obs_dists = [
NormalModel(-3.0, log(0.25)), NormalModel(0.0, log(0.5)), NormalModel(3.0, log(0.75))
]
hmm_true = HMM(init_dist, init_trans, obs_dists)
#=
We can now generate data from this HMM.
Note: `rand(rng, hmm, T)` returns `(state_seq, obs_seq)`.
=#
# Simulate 10,000 time steps of hidden states and observations.
state_seq, obs_seq = rand(rng, hmm_true, 10_000)
#=
Next we fit a new HMM to this data. Baum–Welch will perform EM updates for the
HMM parameters; during the M-step, our observation model parameters are fit via
gradient-based optimization (BFGS).
=#
# Deliberately poor starting point: uniform initial distribution,
# near-identity transitions, and overlapping unit-variance emissions.
init_dist_guess = fill(1.0 / 3, 3)
init_trans_guess = [
0.98 0.01 0.01;
0.01 0.98 0.01;
0.01 0.01 0.98
]
obs_dist_guess = [
NormalModel(-2.0, log(1.0)), NormalModel(2.0, log(1.0)), NormalModel(0.0, log(1.0))
]
hmm_guess = HMM(init_dist_guess, init_trans_guess, obs_dist_guess)
# Baum–Welch EM: closed-form updates for init/transition probabilities,
# while each emission's M-step calls our BFGS-based `fit!` above.
# `lls` is the log-likelihood trace across EM iterations.
hmm_est, lls = baum_welch(hmm_guess, obs_seq)
#=
Great! We were able to fit the model using gradient descent inside EM.
Now we will fit the entire HMM using gradient-based optimization by leveraging
automatic differentiation. The key idea is that the forward algorithm marginalizes
out the latent states, providing the likelihood of the observations directly as a
function of all model parameters.
We can therefore optimize the negative log-likelihood returned by `forward`.
Each objective evaluation runs the forward algorithm, which can be expensive for
large datasets, but this approach allows end-to-end gradient-based fitting for
arbitrary parameterized HMMs.
To respect HMM constraints, we optimize unconstrained parameters and map them to
valid probability distributions via softmax:
- `π = softmax(ηπ)`
- each row of `A` = `softmax(row_logits)`
=#
# Numerically stable softmax: shifting by the maximum makes the largest
# exponent exactly 0, so `exp` cannot overflow; the shift cancels in the
# normalization and leaves the result unchanged.
function softmax(v::AbstractVector)
    shifted = v .- maximum(v)
    w = exp.(shifted)
    return w ./ sum(w)
end
"""
    rowsoftmax(M::AbstractMatrix)

Apply `softmax` to each row of `M`, returning a row-stochastic matrix of
the same size.
"""
function rowsoftmax(M::AbstractMatrix)
    # Allocate with a floating-point eltype: softmax outputs are fractional,
    # so `similar(M)` would throw an InexactError for integer-valued logit
    # matrices. `float(one(eltype(M)))` preserves AD dual types as well.
    A = similar(M, typeof(float(one(eltype(M)))))
    for i in axes(M, 1)
        # `view` avoids copying the row before normalizing it.
        A[i, :] .= softmax(view(M, i, :))
    end
    return A
end
# Map unconstrained parameters θ = (ηπ, ηA, μ, logσ) to a valid HMM:
# softmax turns the logits into a probability simplex point (initial
# distribution) and row-stochastic transition matrix; the per-state
# (μ, logσ) pairs become Normal observation models.
function unpack_to_hmm(θ::ComponentVector)
    n_states = length(θ.ηπ)
    init = softmax(θ.ηπ)
    trans = rowsoftmax(θ.ηA)
    emissions = [NormalModel(θ.μ[k], θ.logσ[k]) for k in 1:n_states]
    return HMM(init, trans, emissions)
end
# Inverse of `unpack_to_hmm` up to softmax's shift-invariance: taking logs
# of the probabilities yields logits that softmax maps back to the same
# distributions. `eps(T)` guards `log(0)` for zero-probability entries.
function hmm_to_θ0(hmm::HMM)
    T = promote_type(
        eltype(hmm.init),
        eltype(hmm.trans),
        eltype(hmm.dists[1].μ),
        eltype(hmm.dists[1].logσ),
    )
    ηπ = log.(hmm.init .+ eps(T))
    ηA = log.(hmm.trans .+ eps(T))
    μ = [dist.μ for dist in hmm.dists]
    logσ = [dist.logσ for dist in hmm.dists]
    return ComponentVector(; ηπ=ηπ, ηA=ηA, μ=μ, logσ=logσ)
end
# Objective for end-to-end gradient fitting: rebuild the HMM from the
# unconstrained parameter vector and run the forward algorithm, which
# marginalizes out the hidden states and returns the observation
# log-likelihood as a differentiable function of all parameters.
function negloglik_from_θ(θ::ComponentVector, obs_seq)
hmm = unpack_to_hmm(θ)
# NOTE(review): `error_if_not_finite=false` presumably lets non-finite
# log-likelihoods pass through during line searches instead of erroring —
# confirm this keyword against the installed HiddenMarkovModels.jl version.
_, loglik = forward(hmm, obs_seq; error_if_not_finite=false)
# `[1]` extracts the scalar; indexing a Julia number with [1] is a no-op,
# so this works whether `loglik` is a scalar or a 1-element collection.
return -loglik[1]
end
# Flatten the initial guess into a plain Vector for Optim, keeping the
# ComponentArray axes so the structured view can be rebuilt cheaply.
θ0 = hmm_to_θ0(hmm_guess)
ax = getaxes(θ0)
obj(x) = negloglik_from_θ(ComponentVector(x, ax), obs_seq)
# BFGS with ForwardDiff gradients over *all* HMM parameters at once.
result = Optim.optimize(obj, Vector(θ0), BFGS(); autodiff=AutoForwardDiff())
# Map the unconstrained optimum back to a valid HMM.
hmm_est2 = unpack_to_hmm(ComponentVector(result.minimizer, ax))
#=
We have now trained an HMM using gradient-based optimization over *all* parameters!
=#

# Sample output (values reproduce only with the same RNG seed):
#
# Hidden Markov Model with:
# - initialization: [9.604250487467529e-20, 8.456720664851089e-20, 1.0]
# - transition matrix: [0.9009541776266187 0.05234521658888211 0.04670060578449929; 0.09220927064080585 0.8101851152960515 0.0976056140631426; 0.07177945535627225 0.02264871266148299 0.9055718319822449]
# - observation distributions: [Main.NormalModel{Float64}(-2.998740427485348, -1.3737153743541302), Main.NormalModel{Float64}(3.032018209433466, -0.31622032685239504), Main.NormalModel{Float64}(0.0056605689909786094, -0.7017072704048756)]
#
# This page was generated using Literate.jl.