Automatic Differentiation with JAX

Authors

Matthew DeHaven

Blake Bordelon

Install JAX

JAX Documentation

In a terminal, run:

pip install jax

Then import the functions used below:

import jax.numpy as jnp
from jax import grad, jacfwd, jacrev, random
import time  # just for timing how long calculations take

A Basic Function

Let’s take a basic function, \[ f(x) = \left[ \sin\left(\frac{x_0}{x_1}\right) + x_1^2, \quad x_0^2 - x_1^2, \quad x_0^3 \right] \]

We can define this in Python, using JAX for the mathematical operations:

f = lambda x: jnp.array([ jnp.sin(x[0] / x[1]) + x[1]**2,  x[0]**2 - x[1]**2 , x[0]**3 ])

Forward Mode

It is then easy to compute the Jacobian at the point \((1, 0.5)\) using JAX’s jacfwd function:

jacfwd(f)(jnp.array([1, 0.5]))
Array([[-0.8322937,  2.6645875],
       [ 2.       , -1.       ],
       [ 3.       ,  0.       ]], dtype=float32)
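Forward mode is built from Jacobian-vector products (JVPs): pushing a tangent vector \(v\) through \(f\) yields \(J v\), so pushing the basis vector \(e_0\) recovers the first column of the Jacobian. jacfwd does this for every input direction, batching the JVPs together. A minimal sketch of a single JVP using jax.jvp (the names x0, e0 and col0 are our own):

```python
import jax
import jax.numpy as jnp
from jax import jacfwd

# Same function as above, from R^2 to R^3
f = lambda x: jnp.array([jnp.sin(x[0] / x[1]) + x[1]**2, x[0]**2 - x[1]**2, x[0]**3])

x0 = jnp.array([1.0, 0.5])
e0 = jnp.array([1.0, 0.0])  # tangent direction: the first input coordinate

# jax.jvp returns (f(x0), J(x0) @ e0); the second value is the first Jacobian column
y, col0 = jax.jvp(f, (x0,), (e0,))

print(col0)
print(jnp.allclose(col0, jacfwd(f)(x0)[:, 0]))
```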

Reverse Mode

And reverse mode is also easy with JAX’s jacrev function:

jacrev(f)(jnp.array([1, 0.5]))
Array([[-0.8322937,  2.6645875],
       [ 2.       , -1.       ],
       [ 3.       ,  0.       ]], dtype=float32)
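Reverse mode works the other way around, with vector-Jacobian products (VJPs): pulling a cotangent vector \(u\) back through \(f\) yields \(u^\top J\), so the cotangent \(e_0\) over the outputs recovers the first row. jacrev repeats this for each output. A minimal sketch of a single VJP with jax.vjp (the names u0 and row0 are our own):

```python
import jax
import jax.numpy as jnp
from jax import jacrev

# Same function as above, from R^2 to R^3
f = lambda x: jnp.array([jnp.sin(x[0] / x[1]) + x[1]**2, x[0]**2 - x[1]**2, x[0]**3])

x0 = jnp.array([1.0, 0.5])

# jax.vjp returns f(x0) and a function mapping a cotangent u to u^T J(x0)
y, f_vjp = jax.vjp(f, x0)

u0 = jnp.array([1.0, 0.0, 0.0])  # cotangent selecting the first output
(row0,) = f_vjp(u0)  # one result per primal input; f has a single input

print(row0)
print(jnp.allclose(row0, jacrev(f)(x0)[0]))
```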

We could have computed the Jacobian of this function symbolically and then evaluated it at the point \((1, 0.5)\). The advantage of AD shows up for more complicated functions, where deriving the Jacobian by hand becomes tedious or impossible.
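For this small function the symbolic route is still manageable. Differentiating each component by hand gives

\[ J(x) = \begin{bmatrix} \frac{1}{x_1}\cos\left(\frac{x_0}{x_1}\right) & -\frac{x_0}{x_1^2}\cos\left(\frac{x_0}{x_1}\right) + 2x_1 \\ 2x_0 & -2x_1 \\ 3x_0^2 & 0 \end{bmatrix} \]

so at \((1, 0.5)\) the first row is \((2\cos 2,\ 1 - 4\cos 2) \approx (-0.832, 2.665)\). A sketch that codes this hand-derived Jacobian (jac_by_hand is our own name) and checks it against both AD modes:

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

f = lambda x: jnp.array([jnp.sin(x[0] / x[1]) + x[1]**2, x[0]**2 - x[1]**2, x[0]**3])

def jac_by_hand(x):
    # Hand-derived Jacobian of f: one row per output, one column per input
    c = jnp.cos(x[0] / x[1])
    return jnp.array([
        [c / x[1], -x[0] * c / x[1]**2 + 2 * x[1]],
        [2 * x[0], -2 * x[1]],
        [3 * x[0]**2, 0.0],
    ])

x0 = jnp.array([1.0, 0.5])
print(jac_by_hand(x0))
print(jnp.allclose(jac_by_hand(x0), jacfwd(f)(x0)), jnp.allclose(jac_by_hand(x0), jacrev(f)(x0)))
```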

Complicated Functions

Comparing AD Modes

Let’s compare the speed of forward and reverse mode AD.

We will define a function that goes from \(\mathbb{R}^{10000}\) to \(\mathbb{R}\).

\[ \operatorname{loss}(x) = \sin \left( \exp \left( \frac{1}{\sqrt{n}} x^* \cdot x \right) \right) \] where \(x^*\) is a fixed random vector and \(x\) is the input, both of length \(n = 10000\).

x_star = random.normal(random.PRNGKey(42), (10000,))  # fixed random vector x^*
loss = lambda x: jnp.sin(jnp.exp(jnp.dot(x_star, x) / jnp.sqrt(x.shape[0])))

We can then compute the Jacobian of this function using either forward or reverse mode AD.

x = random.normal(random.PRNGKey(1234), (10000,))
J = jacfwd(loss)(x)
J = jacrev(loss)(x)
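Because loss returns a single number, its Jacobian is a \(1 \times n\) row, i.e. the gradient, so JAX’s grad returns the same vector. By the chain rule, with \(u = x^* \cdot x / \sqrt{n}\), the gradient is \(\nabla \operatorname{loss}(x) = \cos(e^{u})\, e^{u}\, x^* / \sqrt{n}\). A sketch checking grad against jacrev and against this formula (analytic_grad is our own name):

```python
import jax.numpy as jnp
from jax import grad, jacrev, random

n = 10000
x_star = random.normal(random.PRNGKey(42), (n,))
loss = lambda x: jnp.sin(jnp.exp(jnp.dot(x_star, x) / jnp.sqrt(x.shape[0])))

def analytic_grad(x):
    # Chain rule: d/dx sin(exp(u)) = cos(exp(u)) * exp(u) * du/dx, with du/dx = x_star / sqrt(n)
    u = jnp.dot(x_star, x) / jnp.sqrt(x.shape[0])
    return jnp.cos(jnp.exp(u)) * jnp.exp(u) * x_star / jnp.sqrt(x.shape[0])

x = random.normal(random.PRNGKey(1234), (n,))
g = grad(loss)(x)

print(g.shape)
print(jnp.allclose(g, jacrev(loss)(x)), jnp.allclose(g, analytic_grad(x), rtol=1e-4, atol=1e-6))
```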
Timing the two modes

Timing forward-mode AD, we get:

x = random.normal(random.PRNGKey(1234), (10000,))
start = time.time()
J = jacfwd(loss)(x)
end = time.time()
fwd_time = end - start
print("Forward mode time: ", fwd_time)
print(J)
Forward mode time:  0.22470307350158691
[-0.00112947  0.01864054  0.0117998  ... -0.01512874 -0.02109648
 -0.01740325]

And timing reverse-mode (backward) AD, we get:

x = random.normal(random.PRNGKey(1234), (10000,))
start = time.time()
J = jacrev(loss)(x)
end = time.time()
rev_time = end - start
print("Backward mode time: ", rev_time)
print(J)
Backward mode time:  0.002707958221435547
[-0.00112947  0.01864053  0.0117998  ... -0.01512874 -0.02109648
 -0.01740325]

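We can see that reverse (backward) mode is much faster than forward mode for this function (see note 1). That is what we should expect here: forward mode builds the Jacobian one input direction at a time, which amounts to \(n = 10000\) Jacobian-vector products (batched together by jacfwd), whereas reverse mode builds it one output at a time, and loss has only a single output.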

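1 JAX compiles operations just in time and caches the compiled code, so the first call to the forward- and reverse-mode functions includes compilation overhead and may not show the time difference; later calls will. JAX also dispatches work asynchronously, so calling .block_until_ready() on a result before stopping the clock gives more reliable timings.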

print("About ", round(fwd_time / rev_time, 1), " times faster!")
About  83.0  times faster!
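A single wall-clock measurement can include one-time compilation overhead, and because JAX dispatches work asynchronously the clock may stop before the result is actually computed. A somewhat fairer comparison, sketched here under the assumption that steady-state timings are what you want, wraps each mode in jax.jit, calls it once to warm up, and waits on block_until_ready() inside the timed loop. The helper time_it and its repeat count are our own choices, and the numbers you get will depend on your hardware:

```python
import time
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev, random

n = 10000
x_star = random.normal(random.PRNGKey(42), (n,))
loss = lambda x: jnp.sin(jnp.exp(jnp.dot(x_star, x) / jnp.sqrt(x.shape[0])))
x = random.normal(random.PRNGKey(1234), (n,))

# Compile each AD mode once with jax.jit
fwd = jax.jit(jacfwd(loss))
rev = jax.jit(jacrev(loss))

def time_it(fn, x, repeats=5):
    """Average seconds per call, excluding the first (compiling) call."""
    fn(x).block_until_ready()  # warm-up: trace and compile
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x).block_until_ready()  # wait for the asynchronous result
    return (time.perf_counter() - start) / repeats

fwd_time = time_it(fwd, x)
rev_time = time_it(rev, x)
print(f"forward: {fwd_time:.6f} s, reverse: {rev_time:.6f} s, ratio: {fwd_time / rev_time:.1f}")
```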

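Control Flow Functions

Functions that depend on control flow (if-else, for, while) generally cannot be differentiated symbolically, because there is no single closed-form expression to differentiate.

This is, however, easy for AD to handle: AD differentiates the sequence of operations that was actually executed for a given input.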

We will define a simple function that takes an input vector, applies \(\sin(x^2)\) elementwise \(k\) times, where \(k\) is the largest entry of the input vector truncated to an integer, and then sums the result.

This function is easy to describe algorithmically and write in code, but because the number of iterations depends on the input, it has no single symbolic formula to differentiate.

x = random.normal(random.PRNGKey(1234), (100,))

def f(z):
    # The number of iterations, int(max(z)), depends on the input itself
    for _ in range(int(jnp.max(z))):
        z = jnp.sin(z ** 2)  # applied elementwise
    return jnp.sum(z)


f(x)
Array(33.790833, dtype=float32)
J = jacrev(f)(x)
J
Array([ 9.14219797e-01,  1.54198611e+00, -1.54022560e-01,  1.28125024e+00,
       -1.06963313e+00,  1.41801819e-01,  1.32393825e+00, -1.08029068e-01,
       -2.01322168e-01,  7.39066958e-01,  1.55481815e+00,  2.13262364e-01,
        7.93227702e-02,  1.38910860e-02, -1.21067815e-01, -2.39411616e+00,
       -5.80047350e-03,  7.46430131e-03,  1.09903216e-02,  1.23930001e+00,
        3.33788633e+00,  1.23865539e-04,  2.95916528e-01,  4.19792563e-01,
        2.75377621e-04, -3.64146411e-01,  1.39376473e+00, -1.05361438e+00,
       -3.64593752e-02,  6.60361871e-02,  8.11212718e-01, -2.80359697e+00,
        4.31727380e-01,  7.41075158e-01,  1.13879979e+00, -2.50400219e-04,
       -2.73637772e+00,  2.62407839e-01,  2.75387573e+00,  1.89503059e-01,
        4.25044820e-02,  1.70111430e+00, -7.05717206e-01, -1.05721414e+00,
       -8.64632353e-02, -9.74199712e-01,  2.29958575e-02,  2.32612395e+00,
        1.27304599e-01, -9.37167764e-01, -3.26535940e+00, -6.87842488e-01,
       -1.80836034e+00,  1.53212297e+00,  3.10942113e-01,  3.40831243e-02,
        2.66577810e-01, -2.80669713e+00,  3.78264618e+00,  1.39550459e+00,
        8.64577770e-01, -1.54589021e+00,  4.69316356e-03, -4.76004314e+00,
        1.27595150e+00,  7.95055568e-01, -5.48990011e-01,  2.21338272e+00,
       -2.89355803e+00, -8.49291608e-02,  1.24590242e+00, -1.40686893e+00,
        2.06369385e-02,  4.42322403e-01,  3.01030427e-02, -8.12745478e-04,
       -2.86575755e-06, -1.53197777e+00, -1.26122987e+00, -3.50974655e+00,
        2.20662117e-01,  2.29297191e-01, -1.32497787e+00,  1.31583583e+00,
        1.28999364e+00, -2.41961493e-03,  5.28108299e-01,  9.57041860e-01,
       -8.39974359e-02, -1.60423791e+00, -1.59235686e-01,  1.62865722e+00,
        1.21260905e+00, -6.51759235e-03, -9.24389601e-01, -5.34576876e-03,
        1.41070215e-02,  7.33746052e-01,  2.06736708e+00,  2.36598169e-03],      dtype=float32)

This is just one example of a function whose Jacobian (here, its gradient) is hard to derive symbolically but easy to compute with AD.
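The example above uses a for loop. As a further sketch, the hypothetical function g below combines an if-else branch with a while loop whose number of iterations depends on the input; grad differentiates whichever path was actually executed. Python control flow like this works with grad because, outside jax.jit, the values being branched on are concrete; under jax.jit the same code raises an error when it branches on a traced value, and structured primitives such as jax.lax.cond are needed instead.

```python
from jax import grad

def g(x):
    # if-else: which formula is used depends on the sign of x
    if x > 0:
        y = x ** 2
    else:
        y = -(x ** 3)
    # while: keep doubling until y reaches 10, so the iteration count depends on x
    # (assumes x != 0; otherwise y stays 0 and the loop never ends)
    while y < 10.0:
        y = 2.0 * y
    return y

# x = 1.5:  y = 2.25 -> 4.5 -> 9.0 -> 18.0, so locally g(x) = 8 x^2 and g'(1.5) = 16 * 1.5
print(grad(g)(1.5))   # 24.0
# x = -1.0: y = 1 -> 2 -> 4 -> 8 -> 16, so locally g(x) = -16 x^3 and g'(-1) = -48 x^2
print(grad(g)(-1.0))  # -48.0
```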