```julia
f(x) = √x  # computes sqrt

function g(x)  # computes approximate sqrt
    y = x
    for i in 1:3
        y = 0.5 * (y + x/y)
    end
    return y
end
```
A tale of two languages
2025-03-18
What is differentiation?
Finding a linear approximation of a function around a point.
Why do we care?
Derivatives of computer programs are essential in optimization and machine learning.
What do we need to do?
Not much: automatic differentiation (AD) computes derivatives for us!
Derivative of \(f: \mathbb{R}^n \to \mathbb{R}^m\) at point \(x\): linear map \(\partial f(x)\) such that \[f(x + \varepsilon) = f(x) + \partial f(x)[\varepsilon] + o(\varepsilon)\]
\[\partial f(x) = \left(\frac{\partial f_i}{\partial x_j} (x)\right)_{1 \leq i \leq m, \; 1 \leq j \leq n}\]
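For instance, for the scalar square root \(f(x) = \sqrt{x}\) from the opening example, the derivative at a point \(x > 0\) is the linear map \[\partial f(x)[\varepsilon] = \frac{\varepsilon}{2\sqrt{x}},\] i.e. \(\sqrt{x + \varepsilon} \approx \sqrt{x} + \varepsilon / (2\sqrt{x})\) for small \(\varepsilon\).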
Write down formulas like you’re in high school.
Drawback
Labor-intensive, error-prone.
Ask Mathematica / Wolfram Alpha to work out formulas for you.
\[0.5 \left( 0.5 \left( 0.5 + \frac{1}{0.5 \left( 1 + x \right)} + \frac{ - 0.5 x}{0.25 \left( 1 + x \right)^{2}} \right) + \frac{1}{0.5 \left( 0.5 \left( 1 + x \right) + \frac{x}{0.5 \left( 1 + x \right)} \right)} - 0.5 \left( 0.5 + \frac{1}{0.5 \left( 1 + x \right)} + \frac{ - 0.5 x}{0.25 \left( 1 + x \right)^{2}} \right) \frac{x}{0.25 \left( 0.5 \left( 1 + x \right) + \frac{x}{0.5 \left( 1 + x \right)} \right)^{2}} \right)\]
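In Julia, a comparable symbolic derivative can be obtained with Symbolics.jl; the following is only a rough sketch (assuming the package is installed) that unrolls the three Newton iterations of `g` by hand:

```julia
using Symbolics  # symbolic math in Julia

@variables x
# unroll the three Newton iterations of g
y1 = 0.5 * (x + x / x)
y2 = 0.5 * (y1 + x / y1)
y3 = 0.5 * (y2 + x / y2)
# symbolic derivative with respect to x: the expression blows up quickly
dy = Symbolics.derivative(y3, x)
```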
Drawback
Does not scale to more complex functions.
Rely on finite differences: perturb the input by a small step \(\varepsilon > 0\) along a direction \(v\).
\[\partial f(x)[v] \approx \frac{f(x + \varepsilon v) - f(x)}{\varepsilon}\]
Drawback
Truncation or floating point errors depending on \(\varepsilon\).
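Both error regimes are easy to observe on the square root example; the step sizes below are purely illustrative:

```julia
# Forward differences for f(x) = √x at x = 2, compared with the exact derivative.
# A large step suffers from truncation error, a tiny step from floating-point cancellation.
f(x) = √x
exact = 1 / (2 * √2)
for ε in (1e-1, 1e-5, 1e-8, 1e-12)
    fd = (f(2 + ε) - f(2)) / ε
    println("ε = $ε   error ≈ $(abs(fd - exact))")
end
```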
Reinterpret the program computing \(f\) to obtain \(\partial f(x)\) instead.
Drawback
Hard to reinterpret arbitrary code efficiently.
\[ f = g \circ h \qquad \implies \qquad \partial f(x) = \partial g(h(x)) \circ \partial h(x)\]
Main implementation paradigms:
Operator overloading
Define new types that augment runtime operations (see the sketch below).
Source transformation
Preprocess the source code at compile time.
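As an illustration of the operator-overloading paradigm, here is a toy forward-mode sketch (a made-up `Dual` type, not any real backend) that reinterprets the approximate square root program from the title slide:

```julia
# Toy forward-mode AD: each value carries its derivative along
struct Dual
    val::Float64   # primal value
    der::Float64   # derivative with respect to the input
end

# Overload only the operations that g actually uses
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Real, b::Dual) = Dual(a * b.val, a * b.der)
Base.:/(a::Dual, b::Dual) = Dual(a.val / b.val, (a.der * b.val - a.val * b.der) / b.val^2)

function g(x)  # same approximate sqrt as before
    y = x
    for i in 1:3
        y = 0.5 * (y + x / y)
    end
    return y
end

g(Dual(2.0, 1.0))   # .val ≈ √2, .der ≈ 1 / (2√2)
```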
Consider \(f : x \in \mathbb{R}^n \longmapsto y \in \mathbb{R}^m\). Time \(T(f)\) = one evaluation of \(f\).
Forward mode
At cost \(\propto T(f)\), get all \(m\) partial derivatives with respect to a single input \(x_i\).
Propagate an input perturbation onto the outputs.
Reverse mode
At cost \(\propto T(f)\), get all \(n\) partial sensitivities of a single output \(y_j\).
Backpropagate an output sensitivity onto the inputs.
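In matrix terms, writing \(J = \partial f(x)\) for the Jacobian, a forward pass computes a Jacobian-vector product and a reverse pass computes a vector-Jacobian product: \[\text{forward mode: } v \longmapsto J v \quad (\text{column } i \text{ of } J \text{ if } v = e_i), \qquad \text{reverse mode: } w \longmapsto w^\top J \quad (\text{row } j \text{ of } J \text{ if } w = e_j).\] Recovering the full Jacobian therefore takes \(n\) forward passes or \(m\) reverse passes.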
Why is deep learning possible?
Because gradients in reverse mode are fast.
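A minimal sketch (assuming DifferentiationInterface.jl and ReverseDiff.jl, both of which appear later in this talk, are installed): one reverse-mode call returns every partial derivative of a scalar loss, regardless of the number of inputs.

```julia
using DifferentiationInterface
import ReverseDiff

loss(x) = sum(abs2, x)   # stand-in for a neural network loss
x = randn(1_000)
# all 1 000 partial derivatives from a single reverse-mode sweep
grad = gradient(loss, AutoReverseDiff(), x)
```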
In Python, three main AD frameworks: JAX, PyTorch, and TensorFlow.
In Julia, a dozen or so AD backends: ForwardDiff.jl, ReverseDiff.jl, Zygote.jl, Enzyme.jl, and more.
Each backend has its use cases, especially for scientific ML.
Images courtesy of Adrian Hill.
| | Python | Julia |
|---|---|---|
| Math & tensors | Framework-specific | Part of the core language |
| AD development | Centralized (x3) | Decentralized |
| Limits of AD | ✅ Well-defined | ❌ Fuzzy |
| Scientific libraries | ❌ Split effort | ✅ Shared effort |
Does it have to be this way?
AD could be a language feature instead of a post-hoc addition.
DifferentiationInterface.jl talks to all Julia AD backends.
It is used by packages such as Optimization.jl, Turing.jl, and NonlinearSolve.jl.
Only one repo containing AD bindings: easier to maintain and improve.
Switching backends is now instantaneous (sketched below).
Preliminary work is abstracted away into a preparation step.
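For instance, here is a small sketch (assuming ForwardDiff.jl and Zygote.jl are installed) where switching from forward to reverse mode only changes one argument:

```julia
using DifferentiationInterface
import ForwardDiff, Zygote

f(x) = sum(abs2, x)
x = float.(1:4)

gradient(f, AutoForwardDiff(), x)   # forward mode
gradient(f, AutoZygote(), x)        # reverse mode, same call otherwise
```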
Having a common interface lets us do things we couldn’t do before:
```julia
using DifferentiationInterface, SparseConnectivityTracer, SparseMatrixColorings
import ForwardDiff, ReverseDiff

# Sparse second-order backend: forward over reverse, with automatic
# sparsity detection and matrix coloring
backend = AutoSparse(
    SecondOrder(AutoForwardDiff(), AutoReverseDiff());
    sparsity_detector=TracerSparsityDetector(),
    coloring_algorithm=GreedyColoringAlgorithm(),
)

f(x) = sum(abs2, x)
x = float.(1:4)
p = prepare_hessian(f, backend, similar(x))  # one-time preparation
hessian(f, p, backend, x)
```

```
4×4 SparseArrays.SparseMatrixCSC{Float64, Int64} with 4 stored entries:
 2.0   ⋅    ⋅    ⋅
  ⋅   2.0   ⋅    ⋅
  ⋅    ⋅   2.0   ⋅
  ⋅    ⋅    ⋅   2.0
```
Computing derivatives is automatic and efficient.
Each AD system comes with limitations, learn to recognize them.
Julia is a great ecosystem to play around with AD.
Do you have a tricky AD problem?
Reach out to me, let’s figure it out! My website: gdalle.github.io