Advanced use cases

We dive into more advanced applications of implicit differentiation:

  • constrained optimization problems

using ForwardDiff
using ImplicitDifferentiation
using LinearAlgebra
using Optim
using Random
using Zygote

Random.seed!(63);

Constrained optimization

First, we show how to differentiate through the solution of a constrained optimization problem:

\[y(x) = \underset{y \in \mathbb{R}^m}{\mathrm{argmin}} ~ f(x, y) \quad \text{subject to} \quad g(x, y) \leq 0\]

The optimality conditions are a bit trickier than in the previous cases. We can project onto the feasible set $\mathcal{C}(x) = \{y: g(x, y) \leq 0 \}$ and exploit the fact that projected gradient descent with step size $\eta$ converges to a fixed point:

\[y = \mathrm{proj}_{\mathcal{C}(x)} (y - \eta \nabla_2 f(x, y))\]
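
For intuition, here is a minimal sketch of this fixed-point iteration, written for a generic projection proj and partial gradient grad; the helper name projected_gradient_descent and its defaults are ours, not part of ImplicitDifferentiation.jl:

function projected_gradient_descent(proj, grad, y0; η=0.1, iterations=1_000)
    y = y0
    for _ in 1:iterations
        # one projected gradient step; at convergence, y no longer moves
        y = proj(y .- η .* grad(y))
    end
    return y
end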

To make verification easy, we minimize the following objective:

\[f(x, y) = \lVert y \odot y - x \rVert^2\]

on the hypercube $\mathcal{C}(x) = [0, 1]^m$. In this case, the optimization problem boils down to a componentwise square root, thresholded to remain within $[0, 1]$, but we implement it with a black-box solver from Optim.jl.
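
Assuming every component of $x$ is nonnegative, this closed-form solution can be written down directly; the helper closed_form_cstr_optim below is our own addition for verification, and should agree with the black-box solver up to its tolerance:

# componentwise square root, thresholded to stay inside [0, 1]
closed_form_cstr_optim(x) = clamp.(sqrt.(x), 0, 1)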

function forward_cstr_optim(x)
    f(y) = sum(abs2, y .^ 2 - x)  # objective: ‖y ⊙ y - x‖²
    lower = zeros(size(x))  # box constraints defining the hypercube [0, 1]^m
    upper = ones(size(x))
    y0 = ones(eltype(x), size(x)) ./ 2  # start from the center of the box
    res = optimize(f, lower, upper, y0, Fminbox(GradientDescent()))
    y = Optim.minimizer(res)
    return y
end
forward_cstr_optim (generic function with 1 method)

The conditions express that $y$ is a fixed point of one projected gradient step:

function proj_hypercube(p)
    return max.(0, min.(1, p))  # componentwise projection onto [0, 1]^m
end

function conditions_cstr_optim(x, y)
    ∇₂f = @. 4 * (y^2 - x) * y  # gradient of f(x, ⋅) at y
    η = 0.1
    return y .- proj_hypercube(y .- η .* ∇₂f)  # fixed-point residual
end
conditions_cstr_optim (generic function with 1 method)
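
As a quick sanity check (ours, not part of the original page), the conditions should vanish at the solver's output:

x_test = rand(2)
# the fixed-point residual is approximately zero at the optimum
maximum(abs, conditions_cstr_optim(x_test, forward_cstr_optim(x_test)))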

We now have all the ingredients to construct our implicit function.

implicit_cstr_optim = ImplicitFunction(forward_cstr_optim, conditions_cstr_optim)
ImplicitFunction(forward_cstr_optim, conditions_cstr_optim, IterativeLinearSolver(true, false), nothing)

And indeed, it behaves as it should when we call it:

x = rand(2) .+ [0, 1]
2-element Vector{Float64}:
 0.22442135286865494
 1.3267275094228514

The second component of $x$ is $> 1$, so its square root will be thresholded to one, and the corresponding derivative will be $0$.

implicit_cstr_optim(x) .^ 2
2-element Vector{Float64}:
 0.22442135286146742
 0.9999999995782778

The theoretical Jacobian is therefore diagonal, with derivative $1 / (2\sqrt{x_1})$ in the first component and $0$ in the thresholded one:

J_thres = Diagonal([0.5 / sqrt(x[1]), 0])
2×2 LinearAlgebra.Diagonal{Float64, Vector{Float64}}:
 1.05545   ⋅ 
  ⋅       0.0

Forward mode autodiff

ForwardDiff.jacobian(implicit_cstr_optim, x)
2×2 Matrix{Float64}:
 1.05545  0.0
 0.0      0.0

This matches $J_\text{thres}$. Differentiating through the solver itself, on the other hand, gives an inaccurate result:

ForwardDiff.jacobian(forward_cstr_optim, x)
2×2 Matrix{Float64}:
 1.07645     1.04174
 4.07705e-7  2.02198e-5
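
We can also verify the agreement with the theoretical Jacobian programmatically; this check is our addition:

# true up to the tolerance of the forward solver
ForwardDiff.jacobian(implicit_cstr_optim, x) ≈ J_thres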

Reverse mode autodiff

Zygote.jacobian(implicit_cstr_optim, x)[1]
2×2 Matrix{Float64}:
  1.05545  -0.0
 -0.0       0.0

Meanwhile, Zygote.jl cannot differentiate through the black-box solver at all:

try
    Zygote.jacobian(forward_cstr_optim, x)[1]
catch e
    e
end
Zygote.CompileError(Tuple{typeof(Optim.optimize), NLSolversBase.OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}}, Vector{Float64}, Vector{Float64}, Vector{Float64}, Optim.Fminbox{Optim.GradientDescent{LineSearches.InitialPrevious{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Optim.var"#13#15"}, Float64, Optim.var"#49#51"}, Optim.Options{Float64, Nothing}}, ErrorException("try/catch is not supported.\nRefer to the Zygote documentation for fixes.\nhttps://fluxml.ai/Zygote.jl/latest/limitations\n"))
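
Both autodiff modes agree on the implicit function; as a final consistency check (again our addition):

# forward and reverse mode should yield the same implicit Jacobian
Zygote.jacobian(implicit_cstr_optim, x)[1] ≈ ForwardDiff.jacobian(implicit_cstr_optim, x)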

This page was generated using Literate.jl.