Autodiff
Here we show how to compute gradients of the observation sequence loglikelihood with respect to various inputs.
using ComponentArrays
using DensityInterface
using Distributions
using Enzyme: Enzyme
using ForwardDiff: ForwardDiff
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using LinearAlgebra
using Random: Random, AbstractRNG
using StableRNGs
using StatsAPI
using Zygote: Zygote
rng = StableRNG(63);
Diffusion HMM
To play around with automatic differentiation, we define a simple controlled HMM.
struct DiffusionHMM{V1<:AbstractVector,M2<:AbstractMatrix,V3<:AbstractVector} <: AbstractHMM
init::V1
trans::M2
means::V3
end
Both its transition matrix and its vector of observation means result from a convex combination between the corresponding field and a base value (hence the name diffusion). The base value is the uniform transition matrix for `trans` and $0$ for `means`. The coefficient $\lambda$ of this convex combination is given as a control.
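In symbols, write $P$ for `trans`, $\mu$ for `means` and $N$ for the number of states. The control $\lambda$ then produces the transition matrix and observation distributions below. The unit standard deviation comes from the default of `Normal(μ)` in Distributions.jl.

$$P(\lambda) = (1 - \lambda)\, P + \frac{\lambda}{N}\, \mathbf{1}\mathbf{1}^\top, \qquad \mu_i(\lambda) = (1 - \lambda)\, \mu_i + \lambda \cdot 0, \qquad Y \mid X = i \sim \mathcal{N}\big(\mu_i(\lambda), 1\big)$$

Both $P$ and $\frac{1}{N}\mathbf{1}\mathbf{1}^\top$ have rows summing to $1$, so $P(\lambda)$ does too, for every $\lambda$. At $\lambda = 1$ the transitions become uniform and every observation mean collapses to $0$.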
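HMMs.initialization(hmm::DiffusionHMM) = hmm.init
# Transition matrix: convex combination of `trans` and the uniform matrix, weighted by λ
function HMMs.transition_matrix(hmm::DiffusionHMM, λ::Number)
N = length(hmm)
return (1 - λ) * hmm.trans + λ * ones(N, N) / N
end
# Observations: unit-variance Gaussians whose means shrink towards 0 as λ grows
function HMMs.obs_distributions(hmm::DiffusionHMM, λ::Number)
return [Normal((1 - λ) * hmm.means[i] + λ * 0) for i in 1:length(hmm)]
end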
We now construct an instance of this object and draw samples from it.
init = [0.6, 0.4]
trans = [0.7 0.3; 0.3 0.7]
means = [-1.0, 1.0]
hmm = DiffusionHMM(init, trans, means);
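It is essential that the controls lie between $0$ and $1$. This guarantees that the transition matrix remains stochastic (nonnegative entries, rows summing to $1$) whatever `trans` is. Here they are drawn uniformly from $[0, 1)$ with `rand`.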
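The check below shows why the range matters. It uses a hypothetical standalone helper, `diffused_transition`, that mirrors the rule in `transition_matrix(hmm::DiffusionHMM, λ)`, together with an `is_stochastic` test that is also an illustration. Row sums stay equal to $1$ for any $\lambda$. Nonnegativity, however, is only guaranteed for $\lambda \in [0, 1]$.

```julia
# Hypothetical standalone copy of the rule in `transition_matrix(hmm::DiffusionHMM, λ)`
function diffused_transition(P::AbstractMatrix, λ::Real)
    N = size(P, 1)
    return (1 - λ) * P + λ * ones(N, N) / N
end

# Row-stochastic: nonnegative entries and rows summing to 1
is_stochastic(P) = all(>=(0), P) && all(≈(1), sum(P; dims=2))

toy_trans = [0.7 0.3; 0.3 0.7]
# Every λ in [0, 1] gives a valid transition matrix
@assert all(λ -> is_stochastic(diffused_transition(toy_trans, λ)), 0:0.1:1)

# Rows still sum to 1 outside [0, 1], but entries can become negative,
# for instance with a base matrix containing zeros and λ < 0
P_bad = diffused_transition([1.0 0.0; 0.0 1.0], -0.5)   # [1.25 -0.25; -0.25 1.25]
@assert !is_stochastic(P_bad)
```

For the particular `trans` used in this tutorial, slightly out-of-range controls would still give valid probabilities. The $[0, 1]$ condition is what makes the construction safe for any base matrix.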
control_seqs = [rand(rng, 3), rand(rng, 5)];
obs_seqs = [rand(rng, hmm, control_seqs[k]).obs_seq for k in 1:2];
control_seq = reduce(vcat, control_seqs)
obs_seq = reduce(vcat, obs_seqs)
seq_ends = cumsum(length.(obs_seqs));
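`seq_ends` stores the cumulative end index of each sequence inside the concatenated vectors. Here the two sequences have lengths 3 and 5, so `seq_ends == [3, 8]`. A small hypothetical helper, `seq_ranges` (not part of the package API), recovers the index range of each sequence:

```julia
# Hypothetical helper: index range of each sequence inside the concatenated vectors
function seq_ranges(seq_ends::AbstractVector{<:Integer})
    starts = [1; seq_ends[1:end-1] .+ 1]   # each sequence starts right after the previous one ends
    return [starts[k]:seq_ends[k] for k in eachindex(seq_ends)]
end

example_ends = cumsum([3, 5])        # lengths 3 and 5, as above: [3, 8]
ranges = seq_ranges(example_ends)    # [1:3, 4:8]
x = collect(1.0:8.0)                 # stand-in for a concatenated sequence
parts = [x[r] for r in ranges]       # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0, 8.0]]
```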
What to differentiate?
The key function we are interested in is the loglikelihood of the observation sequence. We can differentiate it with respect to
- the model itself (`hmm`), or more precisely its parameters
- the observation sequence (`obs_seq`)
- the control sequence (`control_seq`)
- but not with respect to the sequence limits (`seq_ends`), which are discrete.
logdensityof(hmm, obs_seq, control_seq; seq_ends)
-11.140885085795315
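This assumes, as the use of `seq_ends` suggests, that the concatenated sequences are independent and each starts from the initial distribution. Under that assumption, the quantity above decomposes over sequences. Let $s_k$ = `seq_ends[k]`, $s_0 = 0$ and $t_k = s_{k-1} + 1$. Then:

$$\log p(y_{1:T} \mid \lambda_{1:T}) = \sum_{k=1}^{K} \log p\big(y_{t_k:s_k} \mid \lambda_{t_k:s_k}\big)$$

Each term sums over hidden state paths and is differentiable in the parameters, the observations and the controls. `seq_ends` only decides where the sum is split, which is why it cannot be differentiated.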
To ensure compatibility with backends that only accept a single input, we wrap all parameters inside a `ComponentVector` from ComponentArrays.jl, and define a new function to differentiate.
parameters = ComponentVector(; init, trans, means)
function f(parameters::ComponentVector, obs_seq, control_seq; seq_ends)
new_hmm = DiffusionHMM(parameters.init, parameters.trans, parameters.means)
return logdensityof(new_hmm, obs_seq, control_seq; seq_ends)
end;
f(parameters, obs_seq, control_seq; seq_ends)
-11.140885085795315
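It may help to spell out by hand what a single-input backend is handed. A `ComponentVector` keeps all parameters in one flat array while allowing named access such as `parameters.trans`. The hypothetical `pack_params` and `unpack_params` helpers below mimic that idea with plain arrays. The layout chosen here is an assumption, not necessarily the one ComponentArrays.jl uses.

```julia
# Hypothetical helpers: pack (init, trans, means) into one flat vector, and unpack it
pack_params(init, trans, means) = vcat(init, vec(trans), means)

function unpack_params(θ::AbstractVector, N::Integer)
    init = θ[1:N]
    trans = reshape(θ[(N + 1):(N + N^2)], N, N)    # column-major, the inverse of `vec`
    means = θ[(N + N^2 + 1):(2N + N^2)]
    return (; init, trans, means)
end

θ = pack_params([0.6, 0.4], [0.7 0.3; 0.3 0.7], [-1.0, 1.0])
# θ == [0.6, 0.4, 0.7, 0.3, 0.3, 0.7, -1.0, 1.0]
unpack_params(θ, 2).trans                          # [0.7 0.3; 0.3 0.7]
```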
Forward mode
Since all of our code is type-generic, it is amenable to forward-mode automatic differentiation with ForwardDiff.jl.
Because `ForwardDiff.gradient` only differentiates functions of a single array input, we compute the gradient with respect to one input at a time, keeping the other two fixed inside a closure.
∇parameters_forwarddiff = ForwardDiff.gradient(
_parameters -> f(_parameters, obs_seq, control_seq; seq_ends), parameters
)
ComponentVector{Float64}(init = [1.6015041005889752, 2.597743849116537], trans = [2.0231939522847693 2.6484057284482456; 1.846962916306125 1.729289186994111], means = [1.635827044922573, -1.1262789699236988])
∇obs_forwarddiff = ForwardDiff.gradient(
_obs_seq -> f(parameters, _obs_seq, control_seq; seq_ends), obs_seq
)
8-element Vector{Float64}:
-1.3502415694411747
0.37694083963154223
-0.3213232237304544
-0.1376825076711971
-0.4787984769521781
0.19099878535867817
-0.2910677639854282
0.7013121919078473
∇control_forwarddiff = ForwardDiff.gradient(
_control_seq -> f(parameters, obs_seq, _control_seq; seq_ends), control_seq
)
8-element Vector{Float64}:
0.09044056663003254
0.7965904932728191
0.1850595346314735
0.7868458783447377
0.6174974309242304
0.6625459512151967
0.6438308770964614
0.11579151687878841
These values will serve as ground truth when we compare with reverse mode.
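Forward mode can be pictured with dual numbers, which carry a value together with its derivative through every operation. The sketch below uses a deliberately minimal, hypothetical `Dual` type (ForwardDiff.jl's own type is far more general). It also shows why type-generic code matters: the same `gaussian_loglik` function accepts both `Float64` and `Dual` inputs. Seeding one input yields one directional derivative per pass. ForwardDiff.jl carries several partials at a time (its chunk mode), but the cost still grows with the input dimension.

```julia
# Minimal forward-mode dual number, for illustration only
struct Dual
    val::Float64  # primal value
    der::Float64  # derivative carried alongside the value
end

# Each rule propagates the derivative with the chain rule
Base.:-(a::Dual, b::Dual) = Dual(a.val - b.val, a.der - b.der)
Base.:-(c::Real, a::Dual) = Dual(c - a.val, -a.der)
Base.:-(a::Dual, c::Real) = Dual(a.val - c, a.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(c::Real, a::Dual) = Dual(c * a.val, c * a.der)
Base.zero(::Dual) = Dual(0.0, 0.0)

# Type-generic loglikelihood of i.i.d. N(μ, 1) observations: works for Float64 and Dual
function gaussian_loglik(μ, ys)
    total = zero(μ)
    for y in ys
        r = y - μ
        total = total - 0.5 * r * r - 0.5 * log(2π)
    end
    return total
end

ys = [0.5, -1.0, 2.0]
μ0 = 0.3
d = gaussian_loglik(Dual(μ0, 1.0), ys)     # seed dμ/dμ = 1
@assert d.val ≈ gaussian_loglik(μ0, ys)    # same value as the plain Float64 evaluation
@assert d.der ≈ sum(ys .- μ0)              # analytic derivative: Σ (y - μ)
```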
Reverse mode with Zygote.jl
In the presence of many parameters, reverse-mode automatic differentiation of the loglikelihood is much more efficient, since the cost of one backward pass does not grow with the number of inputs. The package includes a handwritten chain rule for `logdensityof`, which means backends like Zygote.jl can be used out of the box. Using it, we can compute all derivatives at once.
∇all_zygote = Zygote.gradient(
(_a, _b, _c) -> f(_a, _b, _c; seq_ends), parameters, obs_seq, control_seq
);
∇parameters_zygote, ∇obs_zygote, ∇control_zygote = ∇all_zygote;
We can check the results to validate our chain rule.
∇parameters_zygote ≈ ∇parameters_forwarddiff
true
∇obs_zygote ≈ ∇obs_forwarddiff
true
∇control_zygote ≈ ∇control_forwarddiff
true
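One reason a handwritten reverse rule can be efficient is a classical identity. The derivative of the loglikelihood with respect to the log of each emission likelihood $b_i(y_t)$ is the posterior state marginal $\gamma_t(i)$, which forward-backward computes anyway. The sketch below checks this identity with finite differences on a toy HMM without controls. The helper names are hypothetical, and this is not necessarily how the package writes its rule.

```julia
# Unscaled forward-backward for a plain HMM (no controls); fine for short sequences only,
# since real implementations rescale α and β to avoid underflow.
function toy_forward_backward(init, trans, B)
    N, T = size(B)                     # B[i, t] = likelihood of observation t in state i
    α = zeros(N, T)
    β = ones(N, T)
    α[:, 1] = init .* B[:, 1]
    for t in 2:T
        α[:, t] = (transpose(trans) * α[:, t-1]) .* B[:, t]
    end
    for t in (T-1):-1:1
        β[:, t] = trans * (B[:, t+1] .* β[:, t+1])
    end
    L = sum(α[:, T])                   # likelihood of the whole sequence
    γ = α .* β ./ L                    # posterior state marginals
    return log(L), γ
end

toy_loglik(init, trans, logB) = first(toy_forward_backward(init, trans, exp.(logB)))

toy_init = [0.6, 0.4]
toy_trans = [0.7 0.3; 0.3 0.7]
logB = log.([0.2 0.5 0.1; 0.4 0.1 0.3])
logL, γ = toy_forward_backward(toy_init, toy_trans, exp.(logB))

# Central finite differences of the loglikelihood w.r.t. each log-emission entry
h = 1e-6
fd = similar(logB)
for k in eachindex(logB)
    e = zeros(size(logB))
    e[k] = h
    fd[k] = (toy_loglik(toy_init, toy_trans, logB + e) - toy_loglik(toy_init, toy_trans, logB - e)) / (2h)
end
@assert isapprox(fd, γ; atol=1e-6)    # gradient w.r.t. log B equals the posteriors
```

The backward pass of the loglikelihood therefore reuses quantities that forward-backward has already produced, rather than differentiating through every arithmetic operation.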
Reverse mode with Enzyme.jl
Enzyme.jl, which is often more efficient, also works natively as long as there are no type instabilities. This is why we avoid the closure and the keyword arguments by defining `f_aux`:
function f_aux(parameters, obs_seq, control_seq, seq_ends)
return f(parameters, obs_seq, control_seq; seq_ends)
end
f_aux (generic function with 1 method)
Enzyme.jl accumulates gradients into preallocated storage, which must start at zero. `Enzyme.make_zero` creates this storage with the same structure as each input.
∇parameters_enzyme = Enzyme.make_zero(parameters)
∇obs_enzyme = Enzyme.make_zero(obs_seq)
∇control_enzyme = Enzyme.make_zero(control_seq);
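The syntax is a bit more complex; see the Enzyme.jl docs for details. In short, `Enzyme.Active` marks the scalar return value for differentiation, each `Enzyme.Duplicated` pairs an input with its gradient storage, and `Enzyme.Const` excludes `seq_ends`.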
try
Enzyme.autodiff(
Enzyme.Reverse,
f_aux,
Enzyme.Active,
Enzyme.Duplicated(parameters, ∇parameters_enzyme),
Enzyme.Duplicated(obs_seq, ∇obs_enzyme),
Enzyme.Duplicated(control_seq, ∇control_enzyme),
Enzyme.Const(seq_ends),
)
catch exception # in case a new Enzyme.jl release breaks this code
display(exception)
end
((nothing, nothing, nothing, nothing),)
Once again we can check the results.
∇parameters_enzyme ≈ ∇parameters_forwarddiff
true
∇obs_enzyme ≈ ∇obs_forwarddiff
true
∇control_enzyme ≈ ∇control_forwarddiff
true
For increased efficiency, we could provide temporary storage to Enzyme.jl in order to avoid allocations. This requires going one level deeper and leveraging the in-place `HiddenMarkovModels.forward!` function.
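The sketch below illustrates what avoiding allocations means here. It runs a scaled forward pass that writes into two caller-provided buffers instead of allocating new vectors at each time step. The function `toy_forward_loglik!` is hypothetical and handles a plain HMM without controls; its signature is not that of `HiddenMarkovModels.forward!`.

```julia
using LinearAlgebra: mul!

# Hypothetical in-place forward pass: the buffers `α` and `tmp` (length N) are
# overwritten at every time step. `B[i, t]` is the precomputed likelihood of
# observation t in state i.
function toy_forward_loglik!(α, tmp, init, trans, B)
    T = size(B, 2)
    α .= init .* view(B, :, 1)
    c = sum(α)                          # p(y_1)
    α ./= c                             # normalize to get p(x_1 | y_1)
    logL = log(c)
    for t in 2:T
        mul!(tmp, transpose(trans), α)  # predict: tmp[j] = Σᵢ α[i] trans[i, j]
        α .= tmp .* view(B, :, t)       # update with the new observation
        c = sum(α)                      # p(y_t | y_1:t-1)
        α ./= c
        logL += log(c)
    end
    return logL
end

toy_init = [0.6, 0.4]
toy_trans = [0.7 0.3; 0.3 0.7]
B = [0.2 0.5 0.1; 0.4 0.1 0.3]
α, tmp = zeros(2), zeros(2)             # allocated once, reusable across calls
toy_forward_loglik!(α, tmp, toy_init, toy_trans, B)
```

Normalizing `α` at each step keeps the numbers in a safe range, and the loglikelihood is recovered as the sum of the logs of the normalizing constants.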
Gradient methods
Once we have gradients of the loglikelihood, it is a natural idea to perform gradient descent in order to fit the parameters of a custom HMM. However, there are two caveats we must keep in mind.
First, computing a gradient essentially requires running the forward-backward algorithm, which is expensive. Given the output of forward-backward, if there is a way to perform a more accurate parameter update (like the closed-form maximization step of the Baum-Welch algorithm), it is probably worth it. That is what we show in the other tutorials with the reimplementation of the `fit!` method.
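Consider a plain HMM without controls and with unit-variance Gaussian emissions. For such a model, a closed-form update is short once forward-backward has produced two quantities: the posterior state marginals `γ[i, t]` and the expected transition counts `ξsum[i, j]` (summed over time). The hypothetical `toy_mstep` below sketches it. It is not the package's `fit!`, and the controls of the diffusion HMM would complicate the transition update.

```julia
# Hypothetical M-step for an uncontrolled HMM with N(μᵢ, 1) emissions.
# γ[i, t]: posterior probability of state i at time t (N × T)
# ξsum[i, j]: expected number of i → j transitions, summed over time (N × N)
function toy_mstep(γ::AbstractMatrix, ξsum::AbstractMatrix, obs::AbstractVector)
    init = γ[:, 1]                              # posterior of the first state
    trans = ξsum ./ sum(ξsum; dims=2)           # normalized expected transition counts
    means = (γ * obs) ./ vec(sum(γ; dims=2))    # posterior-weighted averages of observations
    return (; init, trans, means)
end

# Toy posteriors, consistent with each other: state 1, then state 2, then a 50/50 mix
γ = [1.0 0.0 0.5; 0.0 1.0 0.5]
ξsum = [0.0 1.0; 0.5 0.5]
obs = [1.0, 3.0, 2.0]
toy_mstep(γ, ξsum, obs)   # means ≈ [4/3, 8/3]
```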
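Second, HMM parameters live in a constrained space, which calls for a projected gradient descent. Most notably, the initial distribution and every row of the transition matrix must lie in the probability simplex. The set of such row-stochastic matrices is a product of simplices (not the Birkhoff polytope, which consists of doubly stochastic matrices). The orthogonal projection onto it therefore splits into one simplex projection per row. Each of these projections is cheap but requires a dedicated sort-based algorithm rather than simple clipping.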
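The sketch below shows the projection step, assuming Euclidean projections. A classical sort-based routine projects a vector onto the probability simplex. Because the set of row-stochastic matrices is a product of simplices, a transition matrix is projected row by row. The helpers `project_simplex` and `project_rows` are hypothetical. The step size and the rounded gradient values (taken from `∇parameters_forwarddiff.trans` above) are purely illustrative. Since we maximize the loglikelihood, the step follows the gradient.

```julia
# Euclidean projection of a vector onto the probability simplex (sort-based algorithm)
function project_simplex(v::AbstractVector{<:Real})
    u = sort(v; rev=true)
    css = cumsum(u)
    # Largest ρ with u[ρ] - (css[ρ] - 1) / ρ > 0; ρ = 1 always qualifies
    ρ = findlast(j -> u[j] - (css[j] - 1) / j > 0, 1:length(u))
    θ = (css[ρ] - 1) / ρ
    return max.(v .- θ, 0)
end

# Row-stochastic matrices form a product of simplices: project each row separately
function project_rows(P::AbstractMatrix)
    Q = similar(P, Float64)
    for i in axes(P, 1)
        Q[i, :] = project_simplex(P[i, :])
    end
    return Q
end

# One projected gradient ascent step on the transition matrix only
P = [0.7 0.3; 0.3 0.7]
G = [2.02 2.65; 1.85 1.73]     # rounded from ∇parameters_forwarddiff.trans
η = 0.1                        # arbitrary step size
P_new = project_rows(P .+ η .* G)
```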
Still, first-order optimization can be relevant when we lack explicit formulas for maximum likelihood.
This page was generated using Literate.jl.