Automatic Differentiation with JAX
Install JAX
terminal
pip install jax
Then import what we need in Python:
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev, random
import time  # just for timing how long calculations take
A Basic Function
Let’s take a basic function, \[ f(x) = \left[ \sin\left(\frac{x_0}{x_1}\right) + x_1^2, \quad x_0^2 - x_1^2, \quad x_0^3 \right] \]
We can define this in Python, using JAX for the mathematical operations:
f = lambda x: jnp.array([jnp.sin(x[0] / x[1]) + x[1]**2, x[0]**2 - x[1]**2, x[0]**3])
Forward Mode
Then it is easy to compute the Jacobian using JAX’s jacfwd function:
jacfwd(f)(jnp.array([1, 0.5]))
Array([[-0.8322937, 2.6645875],
[ 2. , -1. ],
[ 3. , 0. ]], dtype=float32)
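Under the hood, forward mode builds the Jacobian one column at a time from Jacobian-vector products. Here is a minimal sketch using jax.jvp (the choice of basis vector is ours, for illustration):
from jax import jvp

x = jnp.array([1.0, 0.5])
# Pushing the basis vector e_0 = [1, 0] through f recovers df/dx_0,
# i.e. the first column of the Jacobian above.
primal_out, tangent_out = jvp(f, (x,), (jnp.array([1.0, 0.0]),))
print(tangent_out)  # [-0.8322937  2.         3.       ]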
Reverse Mode
And reverse mode is also easy with JAX’s jacrev function:
jacrev(f)(jnp.array([1, 0.5]))
Array([[-0.8322937, 2.6645875],
[ 2. , -1. ],
[ 3. , 0. ]], dtype=float32)
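Reverse mode instead builds the Jacobian one row at a time from vector-Jacobian products. A minimal sketch using jax.vjp (again, the basis vector is our own choice):
from jax import vjp

x = jnp.array([1.0, 0.5])
# Pulling the basis vector e_0 = [1, 0, 0] back through f recovers the
# gradient of the first output component, i.e. the first row of the Jacobian.
primal_out, vjp_fn = vjp(f, x)
row, = vjp_fn(jnp.array([1.0, 0.0, 0.0]))
print(row)  # [-0.8322937  2.6645875]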
We could have computed the Jacobian of this function symbolically and then evaluated it at the point \((1, 0.5)\), as the check below shows. But the advantage of AD is for more complicated functions.
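For a function this small, the symbolic Jacobian is easy to derive by hand and compare against the AD output. A minimal sketch (the derivative entries below are our own derivation):
import numpy as np

def jacobian_by_hand(x0, x1):
    # Row i is the gradient of output component i of f, derived by hand.
    return np.array([
        [np.cos(x0 / x1) / x1, -x0 * np.cos(x0 / x1) / x1**2 + 2 * x1],
        [2 * x0,               -2 * x1],
        [3 * x0**2,             0.0],
    ])

print(jacobian_by_hand(1.0, 0.5))  # matches the AD results above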
Complicated Functions
Comparing AD Modes
Let’s compare the speed of forward and reverse mode AD.
We will define a function that goes from \(\mathbb{R}^{10000}\) to \(\mathbb{R}\).
\[ \mathrm{loss}(x) = \sin \left( \exp \left( \frac{1}{\sqrt{n}} \, x^* \cdot x \right) \right) \] where \(x^*\) and \(x\) are vectors of length \(n = 10000\).
x_star = random.normal(random.PRNGKey(42), (10000,))
loss = lambda x: jnp.sin(jnp.exp(jnp.dot(x_star, x) / jnp.sqrt(x.shape[0])))
We can then compute the Jacobian of this function using either forward or reverse mode AD.
x = random.normal(random.PRNGKey(1234), (10000,))
J = jacfwd(loss)(x)
J = jacrev(loss)(x)
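Since loss maps to a scalar, its Jacobian is a single row: the gradient. So jax.grad, which uses reverse mode, returns the same vector; a quick check:
g = grad(loss)(x)
print(jnp.allclose(g, jacrev(loss)(x)))  # True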
Timing the two modes
Timing the forward AD, we get
x = random.normal(random.PRNGKey(1234), (10000,))
start = time.time()
J = jacfwd(loss)(x)
end = time.time()
fwd_time = end - start
print("Forward mode time: ", fwd_time)
print(J)
Forward mode time: 0.22470307350158691
[-0.00112947 0.01864054 0.0117998 ... -0.01512874 -0.02109648
-0.01740325]
And timing the backward AD, we get
x = random.normal(random.PRNGKey(1234), (10000,))
start = time.time()
J = jacrev(loss)(x)
end = time.time()
rev_time = end - start
print("Backward mode time: ", rev_time)
print(J)
Backward mode time: 0.002707958221435547
[-0.00112947 0.01864053 0.0117998 ... -0.01512874 -0.02109648
-0.01740325]
We can see that backward mode is much faster than forward mode for this function.¹
¹ JAX uses just-in-time compilation, so the first time you run the forward and backward diff functions you won’t see a time difference, but each time after that you will.
print("About ", round(fwd_time / rev_time, 1), " times faster!")
About 83.0 times faster!
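As the footnote notes, one-shot wall-clock timings mix in tracing and dispatch overhead. A more careful timing sketch using jax.jit and block_until_ready (both standard JAX tools; the structure here is our own):
from jax import jit

rev_jac = jit(jacrev(loss))
rev_jac(x).block_until_ready()  # warm-up call triggers compilation

start = time.time()
rev_jac(x).block_until_ready()  # time only the compiled execution
end = time.time()
print("Compiled reverse mode time: ", end - start)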
Control Loop Functions
One thing that we cannot solve symbolically is a function that depends upon a control loop (if-else, for, while).
This is, however, easy for AD to handle.
We will define a simple function that takes an input vector, applies \(\sin(x^2)\) elementwise for \(k\) iterations, where \(k\) is the integer part of the input vector’s maximum value, and then sums the result.
This function can clearly be described algorithmically and written in code, but cannot be solved symbolically.
x = random.normal(random.PRNGKey(1234), (100,))

def f(z):
    for k in range(int(max(z))):
        z = jnp.sin(z ** 2)
    return jnp.sum(z)

f(x)
Array(33.790833, dtype=float32)
J = jacrev(f)(x)
J
Array([ 9.14219797e-01, 1.54198611e+00, -1.54022560e-01, 1.28125024e+00,
-1.06963313e+00, 1.41801819e-01, 1.32393825e+00, -1.08029068e-01,
-2.01322168e-01, 7.39066958e-01, 1.55481815e+00, 2.13262364e-01,
7.93227702e-02, 1.38910860e-02, -1.21067815e-01, -2.39411616e+00,
-5.80047350e-03, 7.46430131e-03, 1.09903216e-02, 1.23930001e+00,
3.33788633e+00, 1.23865539e-04, 2.95916528e-01, 4.19792563e-01,
2.75377621e-04, -3.64146411e-01, 1.39376473e+00, -1.05361438e+00,
-3.64593752e-02, 6.60361871e-02, 8.11212718e-01, -2.80359697e+00,
4.31727380e-01, 7.41075158e-01, 1.13879979e+00, -2.50400219e-04,
-2.73637772e+00, 2.62407839e-01, 2.75387573e+00, 1.89503059e-01,
4.25044820e-02, 1.70111430e+00, -7.05717206e-01, -1.05721414e+00,
-8.64632353e-02, -9.74199712e-01, 2.29958575e-02, 2.32612395e+00,
1.27304599e-01, -9.37167764e-01, -3.26535940e+00, -6.87842488e-01,
-1.80836034e+00, 1.53212297e+00, 3.10942113e-01, 3.40831243e-02,
2.66577810e-01, -2.80669713e+00, 3.78264618e+00, 1.39550459e+00,
8.64577770e-01, -1.54589021e+00, 4.69316356e-03, -4.76004314e+00,
1.27595150e+00, 7.95055568e-01, -5.48990011e-01, 2.21338272e+00,
-2.89355803e+00, -8.49291608e-02, 1.24590242e+00, -1.40686893e+00,
2.06369385e-02, 4.42322403e-01, 3.01030427e-02, -8.12745478e-04,
-2.86575755e-06, -1.53197777e+00, -1.26122987e+00, -3.50974655e+00,
2.20662117e-01, 2.29297191e-01, -1.32497787e+00, 1.31583583e+00,
1.28999364e+00, -2.41961493e-03, 5.28108299e-01, 9.57041860e-01,
-8.39974359e-02, -1.60423791e+00, -1.59235686e-01, 1.62865722e+00,
1.21260905e+00, -6.51759235e-03, -9.24389601e-01, -5.34576876e-03,
1.41070215e-02, 7.33746052e-01, 2.06736708e+00, 2.36598169e-03], dtype=float32)
This is just one example of a function whose Jacobian/gradient AD makes it possible to compute.
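As a sanity check, we can compare a few entries of the AD gradient against central finite differences. A minimal sketch (the step size is our own choice, and the comparison assumes the perturbation does not change the loop count int(max(z))):
def finite_diff(g, x, i, eps=1e-3):
    # Central-difference approximation to the i-th partial derivative of g.
    e_i = jnp.zeros_like(x).at[i].set(1.0)
    return (g(x + eps * e_i) - g(x - eps * e_i)) / (2 * eps)

for i in range(3):
    print(J[i], finite_diff(f, x, i))  # AD entry vs. finite-difference estimate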