## 1. Setup

### 1.1. About this Text

I created this as a slide-based presentation for the Helmholtz AI Food for Thought seminar in order to introduce researchers from various backgrounds to JAX. This text was produced from the same Org source, with some extra commentary text interleaved. That’s why the text may feel choppy at times and code snippets are more compressed instead of following good style. I tried my best to explain everything I mentioned in the presentation as well.

Also, this is the “extended edition”, including many more code snippets as well as some behind-the-scenes stuff. This way, everything should be reproducible and you should have a good base reference in case you decide to pick up JAX.

You can download the original slides here.

### 1.2. Environment Setup

This is the environment I used. I gave specific version numbers as comments in case you need better reproducibility.

Sadly, I was unable to compile JAX with GPU support; the version of CUDA the official NVIDIA drivers support for Ubuntu 20.04 is no longer supported by JAX.

That’s why all numbers I will show you need to be taken with a grain of salt – we aren’t able to use JAX’ killer feature!

```
conda create -p env python=3.8
conda activate ./env
# `pytorch` version: 1.10.0
conda install pytorch cudatoolkit=10.2 -c pytorch
python -m pip install --upgrade pip
# `jax` version: 0.2.25
# `jaxlib` version: 0.1.74
python -m pip install --upgrade jax jaxlib
# I only have the CPU version.
```

```
clang --version | sed 's/ *$//'
```

## 2. Overview

##### Topic for Today

- JAX is a **cool** new corporation-backed framework for differentiable programming/scientific computing.
- Faster than NumPy/SciPy due to GPU usage and…
- Compilability via Accelerated Linear Algebra compiler (XLA, reused from TensorFlow).
- More usability over NumPy and PyTorch.
- Due to “forced” functional style, get good code for free!

**Everything I will show you ran on the CPU!**

```
import functools
import os
import time
import timeit

import jax
import jax.numpy as jnp
import numpy as np
import torch
```

##### Behind-the-Scenes Setup

The following setup code can safely be ignored (skip ahead) but achieves the following:

- XLA IR output in directory `generated`.
- 2 simulated devices for XLA.
- Activate/disable GPUs in both JAX and PyTorch.
- Less print output.
- Initialize JAX and PyTorch so we don’t see warnings later.

##### Why is it Cool?

- NumPy/SciPy on the GPU.
- No mutable state means much joy: easier maintainability, easier scalability.
- No need for manual batching.
- Complex number differentiation *previously*^{1} not possible on PyTorch. Certain unsupported cases may still be out there.
- User-friendly parallelism.

##### JAX’ Libraries

- `jax`: Mostly important for function transformations/compositions.
- `jax.numpy`: NumPy replacement.
- `jax.scipy`: SciPy replacement.
- `jax.random`: Deterministic randomness.
- `jax.lax`: Lower-level operations.
- `jax.profiler`: Analyze performance.

Just a selection, there are more!

## 3. Mutating Immutable Arrays

##### Mutating Arrays in NumPy

In NumPy:

```
x = np.eye(2)
x[0, 0] = 2
x[:, -1] *= -3
x
```

2 | 0 |

0 | -3 |

Some things to note for those not familiar with Python and/or NumPy:

- In Python, indices start at 0, so the first element of an array is at index 0.
- NumPy and JAX call any \( n \)-dimensional tensor an *array*, so do not expect a 1-dimensional vector when you see the word “array”.
- The colon “:” we used for indexing selects all elements over that dimension. In our case, it selected every row, so together with the column index a whole column of the matrix was indexed.
- Negative values index from the end (but, confusingly, starting at 1 this time). So indexing with -1 yields the very last value of an array. Combining the colon with -1 in that order means we index the last column of the array.
- The *shape* of an array is the tuple of its sizes along each dimension (`x`’s shape is `(2, 2)`).
- The *dtype* of an array is the type of its elements.
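To make the indexing notes above concrete, here is a minimal NumPy-only sketch. The array `demo` and the variable names are made up for illustration; a non-square array is used so that rows and columns cannot be confused.

```python
import numpy as np

# A 2x3 array: [[0, 1, 2], [3, 4, 5]], float64 by default.
demo = np.arange(6.0).reshape(2, 3)

demo_shape = demo.shape      # (2, 3): two rows, three columns
demo_dtype = demo.dtype      # float64

first_element = demo[0, 0]   # indices start at 0
last_column = demo[:, -1]    # all rows, last column
last_row = demo[-1, :]       # last row, all columns
```

Swapping the order of the colon and `-1` is what decides between a column and a row.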

##### Mutating Arrays in JAX

In JAX:

```
x = jnp.eye(2)

try:
    x[0, 0] = 2
except TypeError:
    pass
else:
    raise TypeError('array is mutable')

x = x.at[0, 0].set(2)
x = x.at[:, -1].multiply(-3)
x
```

2 | 0 |

0 | -3 |

You will receive helpful errors in case you forget this!

The `try-except-else` block of the code has the following effect: if the statement in the `try` block `x[0, 0] = 2` raises a `TypeError`, we simply ignore it. That’s what the `except TypeError` block does. If we *did not get any error*, i.e. no exception was caught, the `else` block executes and raises a `TypeError` telling us that the array was mutable after all. Since the code executed and evaluated just fine, we know that JAX raised an error upon trying to mutate the array, meaning JAX arrays are, as expected, not mutable. (Whew!)

You may think performance tanks due to the extra allocations; however, these are optimized away to in-place mutations during JIT (just-in-time) compilation.
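As a rough sketch of what such a functional update looks like inside a JIT-compiled function (the function name `scale_last_column` is made up for this example), note that the input array stays untouched while the result carries the update. Whether XLA really performs the update in place is up to the compiler and is not visible from Python.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scale_last_column(mat):
    # Functional update: returns a new array. Under `jit`, XLA is free
    # to turn this into an in-place update where that is safe.
    return mat.at[:, -1].multiply(-3)

original = jnp.eye(2)
updated = scale_last_column(original)
# `original` is unchanged; only `updated` holds the scaled column.
```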

##### Closures (Functional Programming Interlude)

A closure is a function capturing something from an outer environment. (You’ve been using this whenever you refer to a non-local, e.g. global, variable inside a function.) Very useful for programming with JAX, but be careful about mutable state!

```
def create_counter():
    count = 0
    def inc():
        nonlocal count
        count += 1
        return count
    return inc

counter = create_counter()
', '.join(map(str, [counter(), counter(), counter()]))
```

1, 2, 3 |

If you don’t know Python: the `def` indicates a function definition.

In Python, closure capturing rules are the same as for function argument passing, also called “pass-by-sharing” – non-primitives are *referenced* (imagine a C pointer on that object). So outside modifications to non-primitives will be visible inside the closure as well!

This behavior varies between programming languages, so keep in mind that this is a Python implementation detail.
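The following plain-Python sketch illustrates the sharing behavior described above. The helper `make_reader` and all variable names are hypothetical; the point is only that mutating a captured object is visible inside the closure, while rebinding the outer name is not.

```python
def make_reader(items):
    def read():
        # `items` refers to the same list object the caller holds.
        return list(items)
    return read

shared = [1, 2]
read = make_reader(shared)
before = read()
shared.append(3)       # mutate the object from the outside
after = read()         # the closure sees the change

rebound = [1, 2]
read_rebound = make_reader(rebound)
rebound = [9, 9]       # rebinding the *name* does not affect the closure
unchanged = read_rebound()
```

This is exactly the kind of hidden mutable state that is worth avoiding in functions you later hand to JAX transformations.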

## 4. Randomness in JAX

##### Randomness in JAX

Stateless RNG makes reproducibility fun and easy as pie!^{2}

```
seed = 0
rng_key = jax.random.PRNGKey(seed)

def randn(rng_key, shape, dtype=float):
    rng_key, subkey = jax.random.split(rng_key)
    rands = jax.random.normal(subkey, shape, dtype=dtype)
    return rng_key, rands

rng_key, rands = randn(rng_key, (2, 3))
rands
```

-1.458 | -2.047 | -1.424 |

1.168 | -0.976 | -1.272 |

RNG keys are the way JAX represents RNG state. We called the key we use right away `subkey` even though it is the exact same kind of object as `rng_key`. This naming is simply a nice pattern to use because we know from the names which keys we are going to “consume” and which we will pass along.

You can also split off more than one key by passing an additional integer to `jax.random.split`.
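A small sketch of splitting off several keys at once, using a fresh key `demo_key` (a name chosen here so it does not interfere with `rng_key` from the text). It also shows that reusing a key reproduces the same numbers exactly.

```python
import jax
import jax.numpy as jnp

demo_key = jax.random.PRNGKey(0)
# Split into one key to carry forward and three keys to consume.
demo_key, *subkeys = jax.random.split(demo_key, 4)

samples = [jax.random.normal(k, (2,)) for k in subkeys]
# Reusing a key gives the same numbers again.
repeat = jax.random.normal(subkeys[0], (2,))
```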

Another thing I wanted to show off here is the pattern of using standard Python types with JAX. Since we only specified `float` instead of the more specific `jnp.float32` or `jnp.float64`, JAX will automatically pick the correct one based on its configured default.

##### RNG Keys

```
_, rands = randn(rng_key, (2, 3))
rands
```

0.932 | -0.004 | -0.74 |

-1.427 | 1.06 | 0.929 |

Notice we did not update our `rng_key`, instead discarding that part of the result. Can you guess what happens when we generate numbers again?

```
rng_key, rands = randn(rng_key, (2, 3))
rands
```

0.932 | -0.004 | -0.74 |

-1.427 | 1.06 | 0.929 |

This time, we updated our `rng_key`: the world of randomness is whole again! Notice this is a really easy way to produce random numbers that you (hopefully) want to stay the same.

```
rng_key, rands = randn(rng_key, (2, 3))
rands
```

1.169 | 0.312 | -0.571 |

0.137 | 0.742 | 0.038 |

## 5. `grad`: Advanced AutoDiff

### 5.1. `grad` How-To

##### How to Take Gradients

JAX’ `grad` function transformation takes a function we want to differentiate as input and returns another function calculating the original one’s gradient given arbitrary inputs. By default, it takes the derivative with regard to the first argument. Multiple `jax.grad` applications can be nested to take higher-order derivatives.

Another Python syntax thing here, the `**` is the exponentiation operator. If you’re curious, `^` is used for a bitwise exclusive or.

```
def expo_fn(x, y):
    return x**4 + 2**y + 3

x = 1.0
y = 2.0

grad_x_fn = jax.grad(expo_fn)      # 4x^3
grad2_x_fn = jax.grad(grad_x_fn)   # 12x^2
grad3_x_fn = jax.grad(grad2_x_fn)  # 24x

[
    ['function', r'\partial{}expo_fn/\partial{}\(x\)'],
    ['grad_x', grad_x_fn(x, y).item()],
    ['grad2_x', grad2_x_fn(x, y).item()],
    ['grad3_x', grad3_x_fn(x, y).item()],
]
```

function | ∂expo_fn/∂\(x\) |
---|---|
`grad_x` | 4.0 |
`grad2_x` | 12.0 |
`grad3_x` | 24.0 |

##### Differentiating Different Arguments

To differentiate with regard to other arguments of our functions, we pass `jax.grad` the indices of those arguments in the `argnums` argument. We can also specify multiple `argnums`.

```
x = 1.0
y = 2.0

grad_y_fn = jax.grad(expo_fn, argnums=1)  # ln(2) · 2^y
grad_xy_fn = jax.grad(expo_fn, argnums=(0, 1))

[
    ['function', 'result'],
    ['grad_y', grad_y_fn(x, y).item()],
    ['grad_xy', [g.item() for g in grad_xy_fn(x, y)]],
]
```

function | result |
---|---|
`grad_y` | 2.7725887298583984 |
`grad_xy` | (4.0 2.7725887298583984) |

##### Differentiation and Outputs

In machine learning, we really like to monitor our *loss values*, which are the non-differentiated results of the function we differentiate (i.e. the value we want to minimize). In order to not have to evaluate the function twice and lose precious performance, JAX offers the `jax.value_and_grad` function transformation which returns a new function that calculates the result of the function as well as its gradient. Now we can log our losses and sleep well again.

```
result_grad_fn = jax.value_and_grad(expo_fn)
result, grad = result_grad_fn(x, y)
```

Fun fact: `jax.value_and_grad` is actually what `jax.grad` calls as well, it just tosses `result`.
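To show where this pays off, here is a minimal sketch of a gradient-descent loop that logs the loss at every step. The quadratic `loss_fn`, the learning rate, and the step count are all made up for illustration and are not from the talk.

```python
import jax

def loss_fn(w):
    # Toy quadratic loss with its minimum at w = 3.
    return (w - 3.0) ** 2

loss_and_grad = jax.value_and_grad(loss_fn)

w = 0.0
learning_rate = 0.1
losses = []
for _ in range(50):
    loss, grad = loss_and_grad(w)  # one evaluation yields both values
    losses.append(float(loss))     # log the loss for monitoring
    w = w - learning_rate * grad
```

Each iteration evaluates the function only once, yet we get both the value to log and the gradient to step along.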

Let’s assume the function we want to differentiate has multiple outputs, for example maybe we need to return some new state!

Let’s assume we went through the trouble of collecting all our extra return values in a tuple. We then also changed our function to return a pair (i.e. another tuple) containing (in this order)

- the value we want to differentiate through, and
- the tuple of extra return values.

We can then supply the `has_aux=True` argument to `jax.grad` and happily differentiate again while keeping our state intact:

```
def poly_fn_and_aux(x, y):
    aux_output = ({'y': y}, 1337)
    return x**4 + (y - 1)**2 + 3, aux_output

grad_aux_fn = jax.grad(poly_fn_and_aux, has_aux=True)
grad, aux = grad_aux_fn(x, y)
```

Of course, the same works for `jax.value_and_grad` as well; however, its tree of return values needs some special care to deconstruct:

```
result_aux_grad_fn = jax.value_and_grad(poly_fn_and_aux, has_aux=True)
(result, aux), grad = result_aux_grad_fn(x, y)
```

By the way, if you ever want to disable gradient computation inside a `jax.grad` context, you can use `jax.lax.stop_gradient`. Its use is a bit unintuitive, so I’d recommend checking out the link.
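As a small sketch of what `jax.lax.stop_gradient` does (the two functions below are made up for this example): the wrapped value still takes part in the forward computation, but it is treated as a constant when differentiating.

```python
import jax

def plain(x):
    return x * x                         # d/dx = 2x

def partly_stopped(x):
    # The second factor is treated as a constant during differentiation.
    return x * jax.lax.stop_gradient(x)  # d/dx = x

g_plain = float(jax.grad(plain)(3.0))
g_stopped = float(jax.grad(partly_stopped)(3.0))
```

Both functions return the same value, but their gradients differ because one path was cut.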

### 5.2. Differentiating Spectral Radius

#### Default Precision Interlude

##### Simple Spectral Radius

```
def jax_spectral_radius(mat):
    eigvals = jnp.linalg.eigvals(mat)
    spectral_radius = jnp.max(jnp.abs(eigvals))
    return spectral_radius

def torch_spectral_radius(mat):
    eigvals = torch.linalg.eigvals(mat)
    spectral_radius = torch.max(torch.abs(eigvals))
    return spectral_radius

# Eigenvalues: 1 ± i
ceig_mat = np.array([[1.0, -1.0], [1.0, 1.0]])
jax_mat = jnp.array(ceig_mat)
torch_mat = torch.from_numpy(ceig_mat)

[
    ['function', 'result'],
    ['jax', jax_spectral_radius(jax_mat).item()],
    ['torch', torch_spectral_radius(torch_mat).item()],
    ['sqrt', np.sqrt(2)],
]
```

##### JAX’ Default Precision

function | result |
---|---|
`jax` | 1.4142135381698608 |
`torch` | 1.4142135623730951 |
`sqrt` | 1.4142135623730951 |

Wait! JAX’ precision is seriously behind PyTorch here! Is PyTorch just more precise or what’s going on?

While both JAX and PyTorch have single-precision (32-bit) floating point numbers as their default `dtype`, NumPy uses double-precision (64-bit) floats by default.

Now, when we converted the matrix from NumPy to the respective frameworks, JAX’ `jnp.array` created a new array from NumPy’s, thus converting the `dtype` to JAX’ default. This leaves us with `jax_mat.dtype` being `jnp.float32`. However, PyTorch’s `torch.from_numpy` adopted the `dtype` exactly, which is why PyTorch had double the precision to work with.

With that knowledge, let’s make the test more fair by converting `torch_mat` to single-precision as well:

```
torch_mat = torch_mat.float()

[
    ['jax', jax_spectral_radius(jax_mat).item()],
    ['torch', torch_spectral_radius(torch_mat).item()],
]
```

`jax` | 1.4142135381698608 |

`torch` | 1.4142135381698608 |

Ahh, all is well; JAX is not lacking in terms of precision after all (at least in this small example).

Let’s check out what the square root of 2 is in single-precision to see just how precise we are:

```
two_f32 = np.array(2, dtype=np.float32)
[['sqrt_f32', '{:.16f}'.format(np.sqrt(two_f32))]]
```

`sqrt_f32` | 1.4142135381698608 |
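The same precision gap can be reproduced with NumPy alone, which may help separate the effect of the `dtype` from the choice of framework. This is only a sketch: `np_spectral_radius` is a made-up NumPy counterpart of the functions above, not something from JAX or PyTorch.

```python
import numpy as np

def np_spectral_radius(mat):
    return np.max(np.abs(np.linalg.eigvals(mat)))

mat64 = np.array([[1.0, -1.0], [1.0, 1.0]])  # float64 (NumPy default)
mat32 = mat64.astype(np.float32)             # explicit single precision

rho64 = np_spectral_radius(mat64)
rho32 = np_spectral_radius(mat32)

# Distance to the exact answer sqrt(2) in each precision.
err64 = abs(float(rho64) - np.sqrt(2))
err32 = abs(float(rho32) - np.sqrt(2))
```

The single-precision error is many orders of magnitude larger, purely because of the input `dtype`.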

#### Gradients and Complex Numbers

##### Complex Differentiation

Let’s finally differentiate some complex numbers. You may not ever have seen this way to differentiate in PyTorch – it’s the functional way!

Just to clarify again, being able to take this gradient is a rather recent change in PyTorch: stable since PyTorch 1.9.0, June 2021. Originally, I wanted to show here that JAX is capable of differentiating something that PyTorch cannot. The competition has caught up, though!

```
jax_grad = jax.grad(jax_spectral_radius)(jax_mat)

torch_mat.requires_grad_(True)
torch_rho = torch_spectral_radius(torch_mat)
torch_grad = torch.autograd.grad(torch_rho, torch_mat)

decimals = 3
[
    ['function', 'grad'],
    ['jax', round_tree(jax_grad.tolist(), decimals)],
    ['torch', round_tree(torch_grad[0].tolist(), decimals)],
]
```

function | grad |
---|---|
`jax` | ((0.354 -0.354) (0.354 0.354)) |
`torch` | ((0.354 -0.354) (0.354 0.354)) |
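The `round_tree` helper used above comes from the behind-the-scenes setup and is not shown in this text. If you want to follow along, a hypothetical stand-in could look like the sketch below; this is an assumption about its behavior (round every number in nested lists or tuples), not the original implementation.

```python
def round_tree(tree, decimals):
    """Hypothetical stand-in: round every number in nested lists/tuples."""
    if isinstance(tree, (list, tuple)):
        return type(tree)(round_tree(item, decimals) for item in tree)
    if isinstance(tree, complex):
        # `round` does not accept complex numbers, so round both parts.
        return complex(round(tree.real, decimals), round(tree.imag, decimals))
    return round(tree, decimals)

rounded = round_tree([[0.35355, -0.35355], [0.35355, 0.35355]], 3)
rounded_complex = round_tree([1.23456 + 2.34567j], 2)
```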

##### Setting Up Complex Input

```
complex_mat = np.array([[1 + 1j, -1], [1, 1]])
jax_complex_mat = jnp.array(complex_mat)
torch_complex_mat = torch.from_numpy(complex_mat).to(torch.complex64)

[
    ['function', 'result'],
    None,
    ['jax', jax_spectral_radius(jax_complex_mat).item()],
    ['torch', torch_spectral_radius(torch_complex_mat).item()],
]
```

function | result |
---|---|
jax | 1.9021130800247192 |
torch | 1.9021130800247192 |

##### Differentiating Through Complex Gradients

Because JAX constantly works with tree-like nested containers of arrays (so-called “pytrees”), it includes a nice module with tree-related functions called `jax.tree_util`. We will use it to conjugate the tree of gradients we obtain from `jax.grad` (for no special reason at all).
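As a small illustration of what `jax.tree_util.tree_map` does on a nested container, here is a sketch with a made-up dictionary of complex “gradients”. The structure and names are hypothetical; only the leaves are transformed, the container layout is preserved.

```python
import jax
import jax.numpy as jnp

grads = {
    'layer1': {'w': jnp.array([1 + 2j, 3 - 4j])},
    'layer2': (jnp.array([5j]), jnp.array([6.0 + 0j])),
}
# Apply `jnp.conj` to every leaf array; dicts and tuples stay as they are.
conj_grads = jax.tree_util.tree_map(jnp.conj, grads)
```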

```
jax_grad = jax.grad(jax_spectral_radius)(jax_complex_mat)
jax_conj_grad = jax.tree_util.tree_map(jnp.conj, jax_grad)

torch_complex_mat.requires_grad_(True)
torch_rho = torch_spectral_radius(torch_complex_mat)
torch_grad = torch.autograd.grad(torch_rho, torch_complex_mat)

decimals = 3
[
    ['type', 'gradient'],
    None,
    ['jax', round_tree(jax_grad.tolist(), decimals)],
    ['jax conj', round_tree(jax_conj_grad.tolist(), decimals)],
    ['torch', round_tree(torch_grad[0].tolist(), decimals)],
]
```

type | gradient |
---|---|
jax | (((0.38-0.616j) (-0.38-0.235j)) ((0.38+0.235j) (0.145-0.235j))) |
jax conj | (((0.38+0.616j) (-0.38+0.235j)) ((0.38-0.235j) (0.145+0.235j))) |
torch | (((0.38+0.616j) (-0.38+0.235j)) ((0.38-0.235j) (0.145+0.235j))) |

Oops! It seems there is a discrepancy here…

Without going too much into the math, when optimizing a function with complex inputs and real outputs, steepest-descent algorithms need the conjugate of the complex gradient in order to walk in the correct direction. As PyTorch is a deep learning framework first-and-foremost, it conjugates its gradients by default so users can go plug-and-play when fooling around with complex numbers and optimization.
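To see the conjugation matter in practice, here is a minimal sketch of steepest descent on a real-valued function of a complex input in JAX. The loss (squared distance to a made-up `target`), the learning rate, and the step count are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

target = 1.0 - 2.0j

def loss_fn(z):
    d = z - target
    return jnp.real(d) ** 2 + jnp.imag(d) ** 2  # real-valued output

grad_fn = jax.grad(loss_fn)

z = jnp.array(3.0 + 4.0j)
learning_rate = 0.25
for _ in range(20):
    # Conjugate JAX's gradient before taking the descent step.
    z = z - learning_rate * jnp.conj(grad_fn(z))

final_distance = float(jnp.abs(z - target))
```

With the conjugate, the distance to `target` halves in every step; leaving out `jnp.conj` makes the imaginary part move in the wrong direction.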

##### Why Different Gradients?

- JAX chose mathematical/API consistency (also same behavior as in Autograd^{3}).
- However, **not** the gradients to use for steepest-descent optimization! (Conjugate before.)
- PyTorch has the more practical default here.

## 6. `jit` Compilation via XLA

### 6.1. Introducing `jit`

#### Side Effects in a `jit` Context

##### Just-in-Time-Compiling via XLA

```
def print_hello():
    print('Hello!')  # Side effect!

jax.jit(print_hello)()
```

Hello!

```
jit_print_hello = jax.jit(print_hello)
jit_print_hello()
jit_print_hello()
print('... Hello? :(')
```

… Hello? :(

Before we dive in here, a quick terminology heads-up: “to JIT something” means “to just-in-time-compile something”.

Multiple interesting things happened in these short snippets:

- We did not get any other “Hello!” after the first one in general.
- We did not get a second “Hello!” even though we applied `jax.jit` to `print_hello` a second time.
- We **did** get a “Hello!” for the first call of the JITted function.

Let’s go through these in order:

- The `print` call is called a *side effect* in computer science. JAX does not care for these during JIT compilation, it only cares about math – or, to be more exact, whatever comes in and what comes out of the function.
- The reason we do not get a second “Hello!” even though we apply `jax.jit` again (and we would expect the side effect to happen again before the function is compiled) is because JAX caches its compilation output in the background. So if we JIT the same function twice with the same arguments, the previous compilation output will be reused.
- JAX traces what happens inside the function on its first call, building a computational graph. This means the first call of the function executes just like a standard Python function (though even slower due to the computational graph building).

  This also means that JITting functions that result in a large computational graph (for example a Python loop that is executed very many times) can take forever to JIT only because the first tracing of it takes so long. When you encounter this issue, you can replace your loop with control flow substitutes from the `jax.lax` module.
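As a sketch of such a substitution, the two functions below compute the same sum, once with a Python loop that is unrolled during tracing and once with `jax.lax.fori_loop`. The function names and the trip count of 100 are made up for this example.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_unrolled(x):
    total = 0.0
    for i in range(100):  # unrolled into 100 graph nodes at trace time
        total = total + x * i
    return total

@jax.jit
def sum_fori(x):
    def body(i, total):
        return total + x * i
    init = jnp.zeros((), dtype=x.dtype)
    # A single loop primitive in the graph, regardless of the trip count.
    return jax.lax.fori_loop(0, 100, body, init)

x_in = jnp.float32(2.0)
unrolled_result = float(sum_unrolled(x_in))
fori_result = float(sum_fori(x_in))
```

For 100 iterations the difference in tracing time is negligible; it starts to matter when the trip count gets large.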

##### JITting State

Here, we’ll see that `jax.jit` can also be used as a decorator. However, because we need to supply another argument to `jax.jit`, we cannot use it as a decorator that simply. Instead, we need to combine it with the `functools.partial` operator. An explanation follows after the code block.

```
class Counter:
    count = 0

    @functools.partial(jax.jit, static_argnums=(0,))
    def inc(self):
        self.count += 1

a = Counter()
print(a.count)
a.inc()
print(a.count)
a.inc()
print(a.count)
```

0 |

1 |

1 |

Applying `functools.partial` like this results in the following (actually anonymous) function:

```
# Result of `functools.partial(jax.jit, static_argnums=(0,))`.
def partial_jit(*args, **kwargs):
    return jax.jit(*args, static_argnums=(0,), **kwargs)
```

This new `partial_jit` function wraps the `inc` method of `Counter`, resulting in the equivalent of the following:

```
class Counter:
    count = 0

    def inc(self):
        self.count += 1

Counter.inc = jax.jit(Counter.inc, static_argnums=(0,))
```

I hope that helped make sense of the code. `static_argnums` basically tells `jax.jit` to treat the argument at that position as a compile-time constant and to recompile the JITted function whenever a different value is passed there. In return, we get some freedoms (for example, we would not be able to JIT the function otherwise). We call the arguments at the positions designated by `static_argnums` *static* from now on. More on static and non-static arguments later.

##### JITting State Again

```
a = Counter()
print(a.count)
a.inc()
print(a.count)
a.inc()
print(a.count)
```

0 |

1 |

1 |

Due to `self` being a static argument as specified via `static_argnums`, the function is recompiled for a new, different `self`^{4}.

#### Benchmarking `jit` on `randn`

##### JITting our `randn`

We’ll now use Python’s `timeit` module to benchmark a JIT-compiled version of our old friend `randn` (remember we implemented this in the section on RNG in JAX). We implement a simple wrapper around it in order to initialize the JIT-compiled function before we benchmark it.

You will notice some `block_until_ready` calls. These are due to JAX’ asynchronous execution engine. Whenever JAX executes code on a non-host device (such as a GPU), it happens asynchronously. This means that the main thread of the program continues to run ahead with a “mock” result, also called a “future”, while the actual result is computed in the background. Only when we actually query the result will we wait until it’s available.

During benchmarking, we would get these futures immediately – that’s not much use to us. So we call the `block_until_ready` function in order to wait until the result of the computation is actually available. You achieve the same in PyTorch using a call to `torch.cuda.synchronize`.

```
jit_randn = jax.jit(randn, static_argnums=(1, 2))

def time_str(code, number=5000):
    # Initialize cache
    exec(code)
    return timeit.timeit(code, globals=globals(), number=number)

randn_time = time_str(
    'randn(rng_key, (100, 100))[1].block_until_ready()')
jit_randn_time = time_str(
    'jit_randn(rng_key, (100, 100))[1].block_until_ready()')

[
    ['function', 'time [s]'],
    ['randn', randn_time],
    ['jit_randn', jit_randn_time],
]
```

##### JIT Results

function | time [s] |
---|---|
`randn` | 1.0273832870007027 |
`jit_randn` | 0.7757850260022678 |

A 25 % reduction! Not bad at all, especially since we shouldn’t really have much room to optimize here.

Let’s see how NumPy and PyTorch do:

```
np_rng = np.random.default_rng()
np_randn_time = time_str('np_rng.normal(size=(100, 100))')
torch_randn_time = time_str(
    'torch.randn((100, 100)); '
    'torch.cuda.synchronize()'
)
[['np_randn', np_randn_time], ['torch_randn', torch_randn_time]]
```

`np_randn` | 0.6026334530001805 |

`torch_randn` | 0.25071304299990516 |

Apparently, PyTorch is super good at generating random numbers (3 times as fast as JIT-compiled JAX!). I did not analyze this and can’t say much more about this as there can be a myriad of reasons.

### 6.2. More about XLA

##### About XLA

- Works via completely known shapes.
- Can’t work dynamically with non-static values! That means (for `jax.jit`):
  - **No** `if x > 0`.
  - **No** `jnp.unique`.
  - **No** `y[y % 2 == 0]` or `y[:x]`.
  - However, we can mark `x` (or what it depends on) as static.
  - Alternatively, “disable” JIT in section via experimental `host_callback` module.
- Most important optimization: operator/kernel fusion.
- Best to apply at outer-most location only^{5}.
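The restriction on `if x > 0` and two of the ways around it can be sketched as follows. The function names are made up; the failing call is expected to raise one of JAX’ concretization errors, which (as far as I know) derive from `TypeError`, so that is what the sketch catches.

```python
import jax
import jax.numpy as jnp

def relu_if(x):
    if x > 0:  # needs a concrete Python bool
        return x
    return 0.0 * x

try:
    jax.jit(relu_if)(1.0)
    failed = False
except TypeError:
    failed = True  # the traced value has no concrete truth value

# Option 1: express the branch as data flow instead of control flow.
relu_where = jax.jit(lambda x: jnp.where(x > 0, x, 0.0))
# Option 2: mark the argument static, so its value is known when tracing.
relu_static = jax.jit(relu_if, static_argnums=(0,))

out_where = float(relu_where(-2.0))
out_static = float(relu_static(3.0))
```

Option 2 comes at the price of a recompilation for every new value of the static argument.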

##### When XLA Recompiles

Function is recompiled for…

- different static argument values,
- different argument *shapes*, and
- different argument *dtypes*.

When you hit performance issues, constant recompilation may be the problem!
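One way to watch retracing happen is to record a Python side effect inside the JITted function, since side effects only run while tracing. This is a sketch with a made-up function `double` and a list `trace_log` as the recorder; retracing is used here as a visible proxy for recompilation.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def double(x):
    trace_log.append((x.shape, x.dtype))  # runs only while tracing
    return 2 * x

double(jnp.zeros(3))                   # first trace: float32, shape (3,)
double(jnp.ones(3))                    # same shape and dtype: cache hit
double(jnp.zeros(4))                   # new shape: retrace
double(jnp.zeros(3, dtype=jnp.int32))  # new dtype: retrace
```

Four calls, but only three traces: different values with the same shape and dtype reuse the cached result.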

##### When XLA Recompiles (in Text)

Imagine a dictionary (hash map) `compcache` and a function with arguments `args`:

- For each argument `x` in `args`, collect the following in `cache_key`:
  - if it’s static, `x` (identity),
  - if it’s not static, `(x.shape, x.dtype)`.

`compcache` maps from key `cache_key` to JIT-output (value). Recompile and insert if `compcache` does not contain `cache_key`.

Any static `x` not hashable? Bzzzt, error!

##### When XLA Recompiles (in Code)

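```
import collections.abc

compcache = {}

def maybe_compile(compcache, func, args):
    cache_key = []
    for x in args:
        # `is_static_arg` is a placeholder, not a real JAX function.
        if is_static_arg(x):
            assert isinstance(x, collections.abc.Hashable)
            cache_key.append(x)  # Identity
        else:
            # Imagine this works on Python primitives.
            cache_key.append((x.shape, x.dtype))
    # Lists are not hashable, so convert before using it as a key.
    cache_key = tuple(cache_key)

    try:
        return compcache[cache_key]
    except KeyError:
        # `xla_compile` is a placeholder as well.
        jit_output = xla_compile(func)
        compcache[cache_key] = jit_output
        return jit_output
```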

Just an intuitive example! For example, fails with arbitrarily ordered keyword arguments; `cache_key` should be a `dict`.
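For experimenting with this mental model without JAX, here is a self-contained toy version. Everything in it is hypothetical: `FakeArray` stands in for an array, `fake_compile` for the compiler, and static arguments are identified by position via a `static_argnums` parameter. It models the idea only, not how JAX implements its cache.

```python
from collections import namedtuple
from collections.abc import Hashable

# Hypothetical stand-in for an array: only shape and dtype matter here.
FakeArray = namedtuple('FakeArray', ['shape', 'dtype'])

toy_cache = {}
compile_log = []

def fake_compile(func, cache_key):
    # Stand-in for the compiler: just records that it was invoked.
    compile_log.append(cache_key)
    return func

def toy_maybe_compile(func, args, static_argnums=()):
    parts = []
    for i, x in enumerate(args):
        if i in static_argnums:
            assert isinstance(x, Hashable)
            parts.append(('static', x))
        else:
            parts.append(('traced', x.shape, x.dtype))
    cache_key = (func, tuple(parts))  # the key itself must be hashable
    if cache_key not in toy_cache:
        toy_cache[cache_key] = fake_compile(func, cache_key)
    return toy_cache[cache_key]

def f(x, n):
    return None

arr = FakeArray((2, 3), 'float32')
toy_maybe_compile(f, (arr, 5), static_argnums=(1,))
toy_maybe_compile(f, (arr, 5), static_argnums=(1,))   # cache hit
toy_maybe_compile(f, (FakeArray((4, 3), 'float32'), 5),
                  static_argnums=(1,))                 # new shape
toy_maybe_compile(f, (arr, 6), static_argnums=(1,))   # new static value
```

Four calls lead to three “compilations”, matching the rules listed above.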

##### JIT and C++

- Possible to run JIT-compiled functions from C++ via XLA runtime^{6} and `jax_to_hlo` utility.
- Intermediate representation (IR) output from `jax.jit` will be JIT-compiled in C++ program.
- However, a bit involved.
- Example in JAX repository.

### 6.3. LLVM is Smart… And XLA?

In this section, I’d like to show you a clever optimization compilers do for you.

We’ll take a look at a simple sum implementation in C and the code generated from it. We will compare that with several implementations in Python, compiling the JAX version and seeing why (spoiler alert) XLA does not match up in terms of performance, even though it also uses LLVM for compilation.

#### Scalar Evolution with Clang

##### Giant’s Shoulders

```
#include <stdio.h>

int sum(int limit) {
    int i, total;
    total = 0;
    for (i = 0; i < limit; ++i) {
        total += i;
    }
    return total;
}

int main(int argc, char** argv) {
    printf("%d\n", sum(50000));
    return 0;
}
```

1249975000 |

If you’re following along, save the above to a file called `sum.c`.

This is the simplest sum function we can implement. Let’s see what a modern C compiler does to this code by outputting LLVM’s lower-level representation.

##### Outputting LLVM IR

Intermediate representation (IR) is a lower-level (in this case assembly-like) representation of the program. The compiler backend LLVM uses IR to achieve portability across assembly languages.

Don’t worry too much about the `vim` call below – we are simply filtering the LLVM IR output so it only shows the definition of the `sum` function. The `sed` call strips trailing spaces.

```
clang -S -emit-llvm sum.c -O1
cat sum.ll \
    | vim - +'/^define.*sum(.*{$/,/^}$/p' \
          -es --not-a-term \
    | sed 's/ *$//'
```

Clang is the official C frontend^{7} of the LLVM project.

##### LLVM Scalar Evolution

Warning: assembly-like language below! Don’t worry about reading this, I’ll give the gist of it below.

```
define dso_local i32 @sum(i32 %0) local_unnamed_addr #0 {
  %2 = icmp sgt i32 %0, 0
  br i1 %2, label %3, label %13

3:                                                ; preds = %1
  %4 = add i32 %0, -1
  %5 = zext i32 %4 to i33
  %6 = add i32 %0, -2
  %7 = zext i32 %6 to i33
  %8 = mul i33 %5, %7
  %9 = lshr i33 %8, 1
  %10 = trunc i33 %9 to i32
  %11 = add i32 %10, %0
  %12 = add i32 %11, -1
  br label %13

13:                                               ; preds = %3, %1
  %14 = phi i32 [ 0, %1 ], [ %12, %3 ]
  ret i32 %14
}
```

So, without focusing too much on the details and interpreting this intuitively: please believe me that LLVM converted our `sum` function that used a `for`-loop into… the closed-form `sum` formula \( n \cdot (n - 1) / 2 \) (minus instead of plus due to the limit being excluded)! Isn’t that amazing?
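If you would rather check than believe, the arithmetic in the IR can be mirrored in a few lines of Python. The sketch below ignores fixed-width integer wraparound, and the function names are made up. Algebraically, the IR computes \( (n - 1)(n - 2) / 2 + (n - 1) = n \cdot (n - 1) / 2 \).

```python
def sum_loop(limit):
    total = 0
    for i in range(limit):
        total += i
    return total

def sum_like_ir(limit):
    # Mirrors the IR: ((n - 1) * (n - 2) >> 1) + n - 1 for n > 0, else 0.
    if limit <= 0:
        return 0
    return (((limit - 1) * (limit - 2)) >> 1) + limit - 1

def sum_closed_form(limit):
    return limit * (limit - 1) // 2 if limit > 0 else 0

# Compare all three variants over a range of limits.
all_agree = all(
    sum_loop(n) == sum_like_ir(n) == sum_closed_form(n)
    for n in range(-3, 500)
)
```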

This form of optimization is called *scalar evolution* and is an induction-based technique for – as you can see – quite substantial performance improvements. If you became interested in this topic, the next section links to the source code, which includes references to the papers it implements.

If you really wanted to make sure the optimization happens on the machine code level, you can compile the code to an object file and disassemble it using for example the `radare2` program.

##### More on Scalar Evolution

LLVM scalar evolution analysis (SCEV) source code with links to papers.

Weirdly enough, sums with a step other than 1 are not optimized even though a closed-form solution exists…

To explain a bit more, an earlier version had an `int sum(int limit, int step)` function that allowed a varying step size. However, LLVM did not optimize this function to the closed-form solution, even though it really should be able to (from what I could see in the comments of the scalar evolution source code).

##### Benchmarking C Sum

This section is about obtaining a C timing for the `sum` function and can safely be skipped.

In order to get some timings for C, which does not include a nice `timeit` module, what follows is a benchmark program for the above `sum` function allowing various timing methods. The idea is to mimic `timeit` with this. I saved this to a file called `benchmark_sum.c`.

```
#include <stdio.h>
#include <time.h>

// The CPU cycle-based timings suffer from poor resolution compared to
// wall-time measurements.
#define MY_CLOCK 0
#define MY_CLOCK_GETTIME 1
#define MY_CLOCK_GETTIME_WALL 2
#define MY_TIME 3

#define MY_CLOCK_FUN MY_CLOCK_GETTIME_WALL

int sum(int limit) {
    int i, total;
    total = 0;
    for (i = 0; i < limit; ++i) {
        total += i;
    }
    return total;
}

int main(int argc, char** argv) {
    double duration;
    int res;

    // Note: the directive must be `#elif`. `#elseif` is not valid C and is
    // silently skipped inside an inactive `#if` group, which would leave
    // `duration` uninitialized.
#if MY_CLOCK_FUN == MY_CLOCK
    clock_t start_time, end_time;
    start_time = clock();
#elif MY_CLOCK_FUN == MY_CLOCK_GETTIME
    struct timespec start_time, end_time;
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &start_time);
#elif MY_CLOCK_FUN == MY_CLOCK_GETTIME_WALL
    struct timespec start_time, end_time;
    clock_gettime(CLOCK_MONOTONIC, &start_time);
#elif MY_CLOCK_FUN == MY_TIME
    time_t start_time, end_time;
    start_time = time(NULL);
#endif

    res = sum(50000);

#if MY_CLOCK_FUN == MY_CLOCK
    end_time = clock();
    duration = (double) (end_time - start_time) / CLOCKS_PER_SEC;
#elif MY_CLOCK_FUN == MY_CLOCK_GETTIME
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &end_time);
    duration = (
        end_time.tv_sec + 1e-9 * end_time.tv_nsec
        - start_time.tv_sec - 1e-9 * start_time.tv_nsec
    );
#elif MY_CLOCK_FUN == MY_CLOCK_GETTIME_WALL
    clock_gettime(CLOCK_MONOTONIC, &end_time);
    duration = (
        end_time.tv_sec + 1e-9 * end_time.tv_nsec
        - start_time.tv_sec - 1e-9 * start_time.tv_nsec
    );
#elif MY_CLOCK_FUN == MY_TIME
    end_time = time(NULL);
    duration = difftime(end_time, start_time);
#endif

    // Use `res` so the `sum` call is not optimized away.
    printf("sum %d\n", res);
    printf("dur %.17g\n", duration);
    return 0;
}
```

Check that we get the sum formula optimization we want. I also manually checked to make sure that `main` does not optimize away the `sum` call.

```
clang -S -emit-llvm benchmark_sum.c -O1
cat benchmark_sum.ll \
    | vim - +'/^define.*sum(.*{$/,/^}$/p' \
          -es --not-a-term \
    | sed 's/ *$//'
```

```
define dso_local i32 @sum(i32 %0) local_unnamed_addr #0 {
  %2 = icmp sgt i32 %0, 0
  br i1 %2, label %3, label %13

3:                                                ; preds = %1
  %4 = add i32 %0, -1
  %5 = zext i32 %4 to i33
  %6 = add i32 %0, -2
  %7 = zext i32 %6 to i33
  %8 = mul i33 %5, %7
  %9 = lshr i33 %8, 1
  %10 = trunc i33 %9 to i32
  %11 = add i32 %10, %0
  %12 = add i32 %11, -1
  br label %13

13:                                               ; preds = %3, %1
  %14 = phi i32 [ 0, %1 ], [ %12, %3 ]
  ret i32 %14
}
```

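Execute the same number of times as we do with `time_str` and add up the timings.

```
clang -o benchmark_sum.o benchmark_sum.c -O1
./benchmark_sum.o
seq 5000 \
    | xargs -n 1 ./benchmark_sum.o \
    | awk '/dur/ {total+=$2} END {print total}'
```

```
sum 1249975000
dur 9.532824124368238e-130
4.76641e-126
```

Note that a duration on the order of 1e-129 seconds is far below the resolution of any clock, so this value is not a physically meaningful timing and should not be read as one.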

I copy-pasted this number at a later location so that I was able to give a live surprise.

#### Clang/Math vs. XLA

##### How does XLA ~~stack~~ sum up?

Back to Python, let’s see whether XLA timings match Clang. Also, let’s acknowledge Python’s miserable performance.

```
def py_sum_up(limit):
    return sum(range(limit))

def jax_sum_up(limit):
    return jnp.sum(jnp.arange(limit))

limit = 50000
python_time = time_str('py_sum_up(limit)')
jax_time = time_str('jax_sum_up(limit).block_until_ready()')

jit_sum_up = jax.jit(jax_sum_up, static_argnums=(0,))
jax_jit_time = time_str('jit_sum_up(limit).block_until_ready()')

[
    ['function', 'time [s]', 'result'],
    ['py_sum_up', python_time, py_sum_up(limit)],
    ['jax_sum_up', jax_time, jax_sum_up(limit).item()],
    ['jit_sum_up', jax_jit_time, jit_sum_up(limit).item()],
]
```

##### Sum Timings

function | time [s] | result |
---|---|---|
`py_sum_up` | 3.0464433249981084 | 1249975000 |
`jax_sum_up` | 0.44727893300296273 | 1249975000 |
`jit_sum_up` | 0.1480330039994442 | 1249975000 |

##### JIT vs. Math

Here, we manually implement the sum formula optimization to get a better idea of how fast `jit_sum_up` *should* be in the best case.

```
def sum_up_const(limit):
    # Since we exclude the limit, subtract one instead of adding.
    return limit * (limit - 1) // 2

[
    ['function', 'time [s]', 'result'],
    [
        'jit_sum_up',
        jax_jit_time,
        jit_sum_up(limit).item(),
    ],
    [
        'sum_up_const',
        time_str('sum_up_const(limit)'),
        sum_up_const(limit),
    ],
]
```

##### JIT vs. Math Results

function | time [s] | result |
---|---|---|
`jit_sum_up` | 0.1480330039994442 | 1249975000 |
`sum_up_const` | 0.0008629900003143121 | 1249975000 |

Reason? The optimization is too slow to apply and was disabled (but only on the CPU)!^{8}

C with `-O1` takes around 5e-126 seconds.

#### Even More on JIT and Math

##### More JIT vs. Math

Just for fun, we try the same in PyTorch, using its JIT. We need to save the following code to a file `torch_sum.py` so PyTorch can parse the (TorchScript) source code.

    import torch

    __all__ = ['torch_sum_up', 'torch_jit_sum_up']

    def torch_sum_up(limit):
        return torch.sum(torch.arange(limit))

    torch_jit_sum_up = torch.jit.script(torch_sum_up)

    from torch_sum import *

    torch_limit = torch.tensor(limit)
    [
        ['function', 'time [s]', 'result'],
        [
            'torch_sum_up',
            time_str(
                'torch_sum_up(torch_limit); '
                'torch.cuda.synchronize()'
            ),
            torch_sum_up(torch_limit).item(),
        ],
        [
            'torch_jit_sum_up',
            time_str(
                'torch_jit_sum_up(torch_limit); '
                'torch.cuda.synchronize()'
            ),
            # Report the result of the JIT-compiled function itself.
            torch_jit_sum_up(torch_limit).item(),
        ],
    ]

function | time [s] | result |
---|---|---|
`torch_sum_up` | 0.15606207399832783 | 1249975000 |
`torch_jit_sum_up` | 0.1805207170000358 | 1249975000 |

##### XLA-Optimized LLVM IR

If you’ve been executing `jax.jit` code snippets, you can find XLA IR output in the `generated` directory (if you have set the `XLA_FLAGS` as in the behind-the-scenes setup code block).

## 7. `vmap`: No Batching Required

##### What is Batching

If you don’t know what *batching* is, here are some simple examples.
First three non-batched versions, then a batched version. Whether
you’re using R, Octave/MATLAB, NumPy, or PyTorch, you **always** want to
batch your calculations for optimum performance. Especially when you
are interested in taking gradients, batching greatly simplifies the
computational graph.

For the setting, assume we have a small set of 15 3-dimensional edges and we want to sum up their norms because we need it for some computer graphics algorithm.

Non-batched:

    rng_key, edges = randn(rng_key, (15, 3))
    norm_sum = 0
    for edge in edges:
        norm_sum += jnp.linalg.norm(edge)
    norm_sum

Result:

    27.265522

Now, we can write this in a more pythonic way by using the built-in `sum` function. However, we are still not batching.

sum(jnp.linalg.norm(edge) for edge in edges)

The following is a more NumPy-like way to write the non-batched version. It may even be faster than the pythonic version due to being able to use SIMD operations. Whether performance is gained or lost depends on the size of our dataset, though.

jnp.sum(jnp.array([jnp.linalg.norm(edge) for edge in edges]))

Finally, we arrive at the batched version. We avoid any and all Python loops, calculating our norm sum in a much more efficient manner.

Batched:

    norms = jnp.linalg.norm(edges, axis=-1)  # shape: (15,)
    norm_sum = jnp.sum(norms)
    norm_sum

Result:

    27.265522

##### `vmap`: No Batching Required

Now, what if I told you you no longer needed to do batching manually? Enter `jax.vmap`!

The following example calculates the spectral radius on a 128-size batch of 3×3 matrices. With the `assert` statement, we make sure our old version `jax_spectral_radius` is not batched already.

Notice also that you can combine `jax.jit` with `jax.vmap` – any of the function transformations in JAX are arbitrarily nestable; quite the magic! I’ll let the below timings speak for themselves.

    rng_key, batch = jit_randn(rng_key, (128, 3, 3))
    assert jax_spectral_radius(batch).shape != (128,)

    def looped_spectral_radius(batch):
        return list(map(jax_spectral_radius, batch))

    jit_looped_spectral_radius = jax.jit(looped_spectral_radius)
    batched_spectral_radius = jax.vmap(jax_spectral_radius)
    jit_batched_spectral_radius = jax.jit(batched_spectral_radius)

function | time [s] |
---|---|
`looped` | 10.506742246001522 |
`jit_looped` | 1.7194310759987275 |
`batched` | 3.65071794799951 |
`jit_batched` | 0.9865755659993738 |
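To tie `jax.vmap` back to the edge example from the beginning of this section, here is a minimal, self-contained sketch. The seed and the helper name `edge_norm` are illustrative choices of mine, and the data comes from `jax.random.normal` directly rather than the `randn` helper used elsewhere in this text.

```python
import jax
import jax.numpy as jnp

# Illustrative data: 15 random 3-dimensional edges (the seed is arbitrary).
edges = jax.random.normal(jax.random.PRNGKey(0), (15, 3))


def edge_norm(edge):
    # Written for a single edge of shape (3,); no batch axis in sight.
    return jnp.linalg.norm(edge)


# `jax.vmap` maps `edge_norm` over axis 0 of its argument.
vmapped_norms = jax.vmap(edge_norm)(edges)
# The manually batched version, for comparison.
manual_norms = jnp.linalg.norm(edges, axis=-1)

print(vmapped_norms.shape)  # (15,)
print(jnp.allclose(vmapped_norms, manual_norms))  # True
```

The function body never mentions the batch axis; the transformation adds it afterwards.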

##### More Batching

Because we can’t get enough of comparisons across frameworks, here are some more batched implementations.

Since we are going to JIT-compile PyTorch’s TorchScript again, you need to save the following code block to `torch_batching.py`.

    import torch

    __all__ = ['torch_batched_spectral_radius',
               'torch_jit_batched_spectral_radius']

    def torch_batched_spectral_radius(mat):
        eigvals = torch.linalg.eigvals(mat)
        # With a `dim`, `torch.max` returns `(values, indices)`;
        # keep only the values to match the NumPy and JAX versions.
        spectral_radius = torch.max(torch.abs(eigvals), -1)[0]
        return spectral_radius

    torch_jit_batched_spectral_radius = torch.jit.script(
        torch_batched_spectral_radius)

    from torch_batching import *

    def np_batched_spectral_radius(mat):
        eigvals = np.linalg.eigvals(mat)
        spectral_radius = np.max(np.abs(eigvals), -1)
        return spectral_radius

    def jax_manually_batched_spectral_radius(mat):
        eigvals = jnp.linalg.eigvals(mat)
        spectral_radius = jnp.max(jnp.abs(eigvals), -1)
        return spectral_radius

    jax_jit_manually_batched_spectral_radius = jax.jit(
        jax_manually_batched_spectral_radius)

    np_batch = np.array(batch)
    torch_batch = torch.from_numpy(np_batch)
    [
        ['function', 'time [s]'],
        [
            'np_batched',
            time_str('np_batched_spectral_radius(np_batch)'),
        ],
        [
            'jax_manually_batched',
            time_str(
                'jax_manually_batched_spectral_radius(batch).block_until_ready()'),
        ],
        [
            'jax_jit_manually_batched',
            time_str(
                'jax_jit_manually_batched_spectral_radius(batch)'
                '.block_until_ready()'),
        ],
        [
            'torch_batched',
            time_str(
                'torch_batched_spectral_radius(torch_batch); '
                'torch.cuda.synchronize()'
            ),
        ],
        [
            'torch_jit_batched',
            time_str(
                'torch_jit_batched_spectral_radius(torch_batch); '
                'torch.cuda.synchronize()'
            ),
        ],
    ]

##### More Batching Results

function | time [s] |
---|---|
`np_batched` | 0.8191060569988622 |
`jax_manually_batched` | 1.0442584550000902 |
`jax_jit_manually_batched` | 0.8894995779992314 |
`torch_batched` | 1.276841124999919 |
`torch_jit_batched` | 1.373109233998548 |

When comparing these timings, keep in mind that we are not running on the GPU.

## 8. `pmap`: Simple, Differentiable MPI

##### `pmap`

- Simple function transformation for splitting a computation across devices.
- Certain `jax.lax` primitives for reduction (`pmean`, `psum`, …).
- Like to stay old school? Differentiable `mpi4jax`^{9}.

To activate for local tests (adjust `num_devices` as desired):

    import multiprocessing
    import os

    num_devices = multiprocessing.cpu_count()
    os.environ['XLA_FLAGS'] = (
        # Default to '' so this also works when `XLA_FLAGS` is unset.
        os.getenv('XLA_FLAGS', '')
        + ' --xla_force_host_platform_device_count='
        + str(num_devices)
    )

`jax.pmap` is actually not just related to `jax.vmap` in name – the functions do the exact same thing, just differently: `jax.vmap` batches its function and can be imagined as a `for`-loop over the mapped-over axis. `jax.pmap` also batches its function but is instead a `for`-loop executed in parallel. That’s really all there is to it; when you know `jax.vmap`, you know `jax.pmap`. Of course, the parallelism offers some extra functionality, which is exactly what we are going to learn about in this section.

##### `pmap` and Axes

- Splits computation according to an axis; works over multiple devices and/or JAX processes.
- The size of this broadcasting axis must match the number of devices/processes, so reshape your data accordingly:
  - assuming 4 devices/processes and broadcasting axis 0, reshape dataset of shape `(128, 3, 3)` to `(4, 32, 3, 3)`. In code:

        world_size = jax.device_count()  # or `jax.process_count()`
        # `reshape` needs one flat shape, so concatenate the tuples.
        dataset = dataset.reshape((world_size, -1) + dataset.shape[1:])
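As a minimal sketch of the reshaping rule, the following snippet builds a toy dataset whose length is a multiple of the local device count, moves the device axis to the front, and lets `jax.pmap` reduce over it with `jax.lax.psum`. The dataset, the axis name `'devices'` and the function `total_sum` are assumptions made up for illustration; on a machine with a single device this degenerates to `world_size == 1` but still runs.

```python
import jax
import jax.numpy as jnp

world_size = jax.local_device_count()
# Illustrative dataset whose first axis is divisible by the device count.
dataset = jnp.arange(world_size * 32 * 3 * 3).reshape(world_size * 32, 3, 3)

# Device axis in front: (world_size * 32, 3, 3) -> (world_size, 32, 3, 3).
sharded = dataset.reshape((world_size, -1) + dataset.shape[1:])


def total_sum(shard):
    # Each device sums its own shard, then `psum` adds up the partial sums.
    return jax.lax.psum(jnp.sum(shard), 'devices')


totals = jax.pmap(total_sum, axis_name='devices')(sharded)
print(sharded.shape)
print(totals)  # one entry per device, all holding the global sum
```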


### 8.1. Write your own Horovod!

The following is a case study of using JAX to parallelize simple deep learning code. We are going to train a simple multilayer perceptron and teach it to calculate spectral radii.

I referenced the parallel deep learning framework Horovod because it seems to be the most-used tool at Jülich Supercomputing Centre for data-parallel training.

You can substitute “Horovod” with whatever you like to use, though; be it MPI, `tf.distribute.Strategy`, `torch.distributed`, `torch.distributed` with `torch.nn.parallel.DistributedDataParallel` (DDP), or whatever else you know and love. The principles are all the same, although our JAX implementation is more similar to Horovod than PyTorch DDP, for example (due to global instead of node-local splitting of the batch).

#### Non-Distributed Setup

##### Training Code

What follows is some boilerplate setup code for deep learning using JAX’ built-in example machine learning libraries `jax.example_libraries.stax` and `jax.example_libraries.optimizers`. Notice how all state is explicitly handled with these. If you don’t care for this, you can skip straight to the interesting part.

I also quickly want to mention why the code is a bit larger than it could be: I like my model to be able to work with dynamic batch sizes; however, if you follow JAX’ official example, it would seem that the batch size needs to be fixed. By implementing two little extras, we are able to handle arbitrary batch sizes. First, we implement our own small `flatten` function/layer that flattens the whole input (unlike `stax.Flatten`). Second, we apply `jax.vmap` to our `model` function. The `model` function is always passed the model’s parameters as well, that is why we do not want to `vmap` over the first argument (indicated by `in_axes=(None, [...])`). And that’s it!

The only slight inconvenience with this setup is that we need to add an extra size-1 batch dimension in order to handle single inputs. You’ll see this when we test our model later.

    from jax.example_libraries import stax
    from jax.example_libraries import optimizers

    input_shape = (2, 2)

    def flatten():
        def init_fun(rng_key, input_shape):
            flattened_size = jnp.prod(jnp.array(list(input_shape)))
            return (flattened_size,), ()

        def apply_fun(params, inputs, **kwargs):
            return inputs.ravel()

        return init_fun, apply_fun

    def build_model(rng_key):
        model_init, model = stax.serial(
            flatten(),
            stax.Dense(64),
            stax.Relu,
            stax.Dense(64),
            stax.Relu,
            stax.Dense(64),
            stax.Relu,
            stax.Dense(1),
        )
        # Handle varying batch sizes.
        model = jax.vmap(model, in_axes=(None, 0))

        rng_key, subkey = jax.random.split(rng_key)
        output_shape, params = model_init(subkey, input_shape)
        assert output_shape == (1,)
        return rng_key, params, model

    def build_opt(params):
        opt_init, update, get_params = optimizers.adam(3e-4)
        opt_state = opt_init(params)
        return opt_state, update, get_params

    rng_key, params, model = build_model(rng_key)
    orig_opt_state, opt_update, get_params = build_opt(params)
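The two "extras" (flattening a single sample and wrapping the model in `jax.vmap`) can be tried in isolation. The sketch below uses a made-up `toy_model` in place of the MLP, so it says nothing about `stax` itself; it only illustrates why a single input needs an explicit size-1 batch axis once the function is mapped with `in_axes=(None, 0)`.

```python
import jax
import jax.numpy as jnp


def toy_model(params, inputs):
    # Stand-in for the MLP: flattens one (2, 2) sample, returns shape (1,).
    return jnp.dot(params, inputs.ravel())[None]


# Map over the batch axis of `inputs` only; `params` is shared.
batched_toy_model = jax.vmap(toy_model, in_axes=(None, 0))

params = jnp.ones(4)
sample = jnp.arange(4.0).reshape(2, 2)

batch = jnp.stack([sample, 2 * sample, 3 * sample])
print(batched_toy_model(params, batch).shape)  # (3, 1)

# A single sample needs an explicit size-1 batch axis.
single = jnp.expand_dims(sample, 0)
print(batched_toy_model(params, single).shape)  # (1, 1)
```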

##### Interesting Part of the Training Code

Here we implement our update methods. Notice how we use `batched_spectral_radius` instead of `jit_batched_spectral_radius` in order to give XLA more optimization freedom. Also, here we see conjugating the possibly complex gradients in action.

    def batch_loss(params, batch):
        # The model returns shape `(batch, 1)`; drop the trailing axis so
        # it lines up with the `(batch,)` targets instead of broadcasting
        # to `(batch, batch)`.
        preds = model(params, batch).squeeze(-1)
        targets = batched_spectral_radius(batch)
        return jnp.mean(jnp.abs(preds - targets))

    @jax.jit
    def train_batch(step, opt_state, batch):
        params = get_params(opt_state)
        loss, grad = jax.value_and_grad(batch_loss)(params, batch)
        # Conjugate gradient for steepest-descent optimization.
        grad = jax.tree_util.tree_map(jnp.conj, grad)
        opt_state = opt_update(step, grad, opt_state)
        return opt_state, loss
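Why conjugate at all? For a real-valued function of a complex input, `jax.grad` returns the conjugate of the steepest-ascent direction. A minimal sketch with a function I picked for illustration, the squared modulus, shows this; `squared_modulus` and the step size `0.1` are assumptions and not part of the training code.

```python
import jax
import jax.numpy as jnp


def squared_modulus(z):
    # Real-valued function of a complex input: |z|^2.
    return jnp.real(z) ** 2 + jnp.imag(z) ** 2


z = jnp.asarray(3.0 + 4.0j)
grad = jax.grad(squared_modulus)(z)
print(grad)            # 6 - 8j, i.e. 2 * conj(z)
print(jnp.conj(grad))  # 6 + 8j, i.e. 2 * z

# Stepping against the conjugated gradient moves `z` towards the minimum at 0.
z_next = z - 0.1 * jnp.conj(grad)
print(squared_modulus(z), squared_modulus(z_next))
```

Without the conjugation, the step would change the imaginary part in the wrong direction.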

##### Training a Spectral Radius MLP

A really simple deep learning training loop. We generate our batches on-demand, taking care to update our `rng_key`, of course!

    opt_state = orig_opt_state
    batch_size = 64
    batch_shape = (batch_size,) + input_shape
    steps = 10000
    log_step_interval = 1000

    start_time = time.perf_counter()
    for step in range(steps):
        rng_key, batch = jit_randn(rng_key, batch_shape, dtype=complex)
        opt_state, loss = train_batch(step, opt_state, batch)
        if step % log_step_interval == 0:
            print('step ', step, '; loss ', loss, sep='')
    end_time = time.perf_counter()
    print('Training took', end_time - start_time, 'seconds.')

##### Training Results

    step 0; loss 1.6026156
    step 1000; loss 0.39599544
    step 2000; loss 0.38151193
    step 3000; loss 0.4386006
    step 4000; loss 0.3645811
    step 5000; loss 0.38383436
    step 6000; loss 0.4037715
    step 7000; loss 0.3104779
    step 8000; loss 0.32223767
    step 9000; loss 0.40970623
    Training took 7.869086120001157 seconds.

Okay, that’s some sound old MLP training. Let’s get into parallelization already.

#### Multi-Node Distribution

A quick interlude on some extra distributed use cases and GPU memory pre-allocation. These are only interesting if you plan to distribute code yourself; they are mostly code snippets I wanted to leave here as a reference. To skip these sections, jump ahead to “Distributed Training”.

The following is some setup code you would use on a Slurm-managed cluster, for example. But first, a word of caution…

##### Multi-Node Distributed Setup

- Caution: Experimental, undocumented and **not even used, much less tested,** anywhere inside JAX!
- Future versions will have a function `jax.distributed.initialize`, working much like PyTorch’s

      torch.distributed.init_process_group(
          [...],
          init_method="tcp://[...]",  # or "env://"
      )

##### Multi-Node Distributed Setup Code

Adapted from JAX source code (clickable version of link below):

    # Adapted and compressed from
    # https://github.com/google/jax/blob/4d6467727709de1c9ad220ac62783a18bcbf4990/jax/_src/distributed.py

    def jax_distributed_initialize(
            coordinator_address, num_processes, process_id):
        if process_id == 0:
            global _service
            _service = \
                jax.lib.xla_extension.get_distributed_runtime_service(
                    coordinator_address, num_processes)

        client = jax.lib.xla_extension.get_distributed_runtime_client(
            coordinator_address, process_id)
        client.connect()

        factory = functools.partial(
            jax.lib.xla_client.make_gpu_client, client, process_id)
        jax.lib.xla_bridge.register_backend_factory(
            'gpu', factory, priority=300)

#### Handling GPU Memory Pre-Allocation

If you ever had trouble with running out of GPU memory when using multi-process TensorFlow, you may have fixed it by enabling “GPU memory growing”. (The default is that TensorFlow pre-allocates a large block of memory in order to reduce memory fragmentation.)

JAX does the same, so in case you need it, what follows is the JAX equivalent to enabling GPU memory growing in TensorFlow.

##### Disable GPU Memory Pre-Allocation

Equivalent of the following TensorFlow:

    import tensorflow as tf

    gpus = tf.config.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

In JAX:

    import os

    # Before first JAX computation.
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

#### Distributed Training

##### Distributing our Training Code

With all of that out of the way, let’s finally write some distributed training code!

… What’s that? We only need to add a single line in our `train_batch` function?

Well, almost; we also need to apply the titular `jax.pmap` function transformation. I’ll explain what’s going on here after the code block.

    batch_axis = 'batch'

    def distributed_train_batch(step, opt_state, batch):
        params = get_params(opt_state)
        loss, grad = jax.value_and_grad(batch_loss)(params, batch)
        # This is the only line we had to add to `train_batch`.
        loss, grad = jax.lax.pmean((loss, grad), batch_axis)
        # Conjugate gradient for steepest-descent optimization.
        grad = jax.tree_util.tree_map(jnp.conj, grad)
        opt_state = opt_update(step, grad, opt_state)
        return opt_state, loss

    # `pmap` also `jit`s our function.
    pmap_train_batch = jax.pmap(
        distributed_train_batch,
        batch_axis,
        in_axes=(None, None, 0),
        out_axes=(None, None),
    )

This is already all the magic behind Horovod! (Ignoring the super-optimized communication.)

The first thing that should have caught your eye is the definition of `batch_axis` at the very top of the code block. This `batch_axis` is passed to both `jax.lax.pmean` and `jax.pmap` as the reduction axis and mapped-over axis, respectively. We need to do this because – and I hope you’ve started to notice a pattern here – `jax.pmap` is **also** nestable! By having to specify an axis for `jax.pmap`, the reduction operations in `jax.lax` will always have an axis to refer to. We use a string to name the axis here, but any hashable would do.

The call to `jax.lax.pmean` averages a tree of values over all processes. In our case, it averages the loss and gradient. We average the loss as well here because you usually want to log the globally averaged loss instead of the local loss to get a smoother overall picture (in Horovod, you need to enable this explicitly). The averaged gradient is then used to do the same update step on each process, so we don’t need any more synchronization afterwards.
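To see that averaging per-shard gradients really reproduces the full-batch gradient, here is a minimal sketch on a toy least-squares problem. It uses `jax.vmap` with an `axis_name` to stand in for four devices, so it runs on any machine; the loss, the data and all names are assumptions made up for this illustration, and equal shard sizes are required for the equality to hold.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    return jnp.mean((w * x - y) ** 2)


def averaged_grad(w, x_shard, y_shard):
    # Every "process" differentiates on its own shard only ...
    grad = jax.grad(loss)(w, x_shard, y_shard)
    # ... and `pmean` averages the results across the named axis.
    return jax.lax.pmean(grad, 'batch')


x = jnp.arange(8.0)
y = 3.0 * x + 1.0
w = jnp.asarray(0.5)

# Four equally sized shards; `vmap` stands in for `pmap` here.
per_shard = jax.vmap(
    averaged_grad, in_axes=(None, 0, 0), axis_name='batch',
)(w, x.reshape(4, 2), y.reshape(4, 2))
full_batch = jax.grad(loss)(w, x, y)

print(per_shard)   # four identical entries
print(full_batch)  # the same value, computed on the whole batch
```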

`jax.pmap` has some other arguments here we haven’t seen yet, namely `in_axes` and `out_axes`. `jax.vmap` accepts these too, and they are very important! With them, you control the axis of each argument that the function transformation maps over. If you don’t want to map over an argument (maybe you are passing a single constant that you don’t want to copy until it has the size of the mapped-over axis), you specify `in_axes` at the position of the argument as `None`. We do this for the training step (a single integer) and model parameters (a tree of matrices). However, we *do* want to map over the batch somehow. We specify the very first axis here.

`out_axes` is similar, but for the output. If we wanted to collect the different outputs of the function transformation for each mapped-over input, we would specify the return values that we want to collect and the axis we want to collect them over at the corresponding positions in `out_axes`. Since we already reduce over the mapped-over values with the `jax.lax.pmean` call, we will not have multiple different outputs and thus use `None` just like for `in_axes` to prevent the collection.
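Since `jax.vmap` accepts the same two arguments, their effect is easy to try without multiple devices. The following is a small sketch with a made-up `scale` function: the first argument is shared via `None`, the second is mapped over axis 0, and `out_axes` decides along which axis the per-row results are stacked.

```python
import jax
import jax.numpy as jnp


def scale(factor, row):
    return factor * row


matrix = jnp.arange(6.0).reshape(2, 3)

# `factor` is shared (`None`); `row` is taken from axis 0 of `matrix`.
rows_scaled = jax.vmap(scale, in_axes=(None, 0))(2.0, matrix)
print(rows_scaled.shape)  # (2, 3)

# Same computation, but the per-row results are stacked along axis 1.
stacked_last = jax.vmap(scale, in_axes=(None, 0), out_axes=1)(2.0, matrix)
print(stacked_last.shape)  # (3, 2)
```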

##### Training a Spectral Radius MLP Distributively

Here, we reshape the original batch so we can split it over its first axis as desired. In the code, we obtain the `world_size` as the number of devices known to JAX. Going back to the multi-node setup code from before, if you initialize your distributed training like that, i.e. by starting multiple Python processes, you may want to use `world_size = jax.process_count()` instead. Like so often, this depends on your use case.

    opt_state = orig_opt_state
    # Since we are only using a single JAX process with (possibly
    # emulated) multiple devices, we use `jax.device_count`.
    # Usually, this would be `jax.process_count`.
    world_size = jax.device_count()
    assert batch_size % world_size == 0
    local_batch_size = batch_size // world_size

    start_time = time.perf_counter()
    print('Training with', world_size, '(possibly simulated) devices.')
    for step in range(steps):
        rng_key, batch = jit_randn(rng_key, batch_shape, dtype=complex)
        batch = batch.reshape(
            (world_size, local_batch_size) + batch.shape[1:])
        opt_state, loss = pmap_train_batch(step, opt_state, batch)
        if step % log_step_interval == 0:
            print('step ', step, '; loss ', loss, sep='')
    end_time = time.perf_counter()
    print('Distributed training took', end_time - start_time, 'seconds.')

##### Distributed Training Results

    Training with 2 (possibly simulated) devices.
    step 0; loss 1.6025591
    step 1000; loss 0.3969112
    step 2000; loss 0.38199607
    step 3000; loss 0.4381593
    step 4000; loss 0.36498725
    step 5000; loss 0.38379186
    step 6000; loss 0.40361017
    step 7000; loss 0.3104655
    step 8000; loss 0.32245204
    step 9000; loss 0.40979913
    Distributed training took 7.309056613001303 seconds.

Even though I only used 2 simulated devices, training did speed up a bit already. Nice!

##### Did it Learn?

Let’s test whether our model actually learned anything useful. Maybe the logged losses don’t tell the whole story. Also, please notice that I JITted the `model` function for inference here; I wanted to show this off as it will really speed up your inference times. (Of course, JITting did not really make sense here since I only use the function once.)

    params = get_params(opt_state)
    jit_model = jax.jit(model)
    ceig_mat = jnp.array([[1.0, -1.0], [1.0, 1.0]])
    batched_ceig_mat = jnp.expand_dims(ceig_mat, 0)
    [
        ['function', 'spectral radius sample'],
        [
            'jax_spectral_radius',
            jax_spectral_radius(ceig_mat).item(),
        ],
        [
            'model',
            jit_model(params, batched_ceig_mat)[0].item(),
        ],
    ]

function | spectral radius sample |
---|---|
`jax_spectral_radius` | 1.4142135381698608 |
`model` | 1.4250930547714233 |
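The reference value can also be checked by hand: the test matrix has the characteristic polynomial (1 - λ)² + 1, so its eigenvalues are 1 ± i and the spectral radius is √2 ≈ 1.41421. A short NumPy sketch (independent of the JAX helpers used in this text) confirms this:

```python
import numpy as np

ceig_mat = np.array([[1.0, -1.0], [1.0, 1.0]])
eigvals = np.linalg.eigvals(ceig_mat)  # complex-conjugate pair 1 ± 1j
spectral_radius = np.max(np.abs(eigvals))

print(spectral_radius)
print(np.isclose(spectral_radius, np.sqrt(2.0)))  # True
```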

## 9. Summary

##### Advantages

- Educational documentation.
- Familiar API.
- Interoperate with TensorFlow using experimental `jax2tf` (included in JAX; despite the name, also supports TensorFlow to JAX).
- Faster code!
- More explicit, less magic, less trouble understanding.
- Functional style **will** avoid headaches.

Look out for stabilization of `pjit` (previously `sharded_jit`); even simpler Horovod, ability to JIT **huge** functions!

`pjit` is actually much more cool than you would expect. The old name is a bit more descriptive here, as `pjit` – aside from being a more abstracted `pmap` – allows us to even split super large functions that do not fit in the memory of a single device.

##### Disadvantages

Initial hurdles:

- Getting used to it.
- Sometimes not as quick to write; however, payoff in the long term.

Better with time:

- Sometimes unpredictable or unstable.
- Lacking ecosystem.

Will never change:

- Hidden functionalities/undocumented APIs; some useful code (intentionally?) not public.
- Mutating state during JITting **will** cause headaches.
- Backend in TensorFlow: code split and dependency.

The code split means that the TensorFlow repository also contains JAX-only code. So you have another code location to keep in mind if you need to dive deeper into the JAX JIT for some reason.

##### Neural Network Libraries

- `jax.example_libraries.stax`: Included in JAX, bare-bones and requires more manual work. However, simplicity is an advantage.
- `flax`: Most features, most user-friendly in my opinion.
- `haiku`: Goes *against* JAX’ explicit/immutable state but converts models back to stateless transformations. Thus, maybe better user experience.
- `objax`: Similar API to `torch.nn.Module`.
- `trax`: Focus on sequence data and large-scale training.

These are all made by the same company!

Just to clarify, while the above statement can be interpreted as a cheap jab at Alphabet and the TensorFlow situation, I do believe that designing sensible APIs for JAX’ paradigms is hard. Most of these libraries can most likely be viewed as experiments with regard to this design.

##### `functorch`

- JAX-like function transformations (`vmap`, `grad`, …)
- Stateless models

… for PyTorch! Experimental, but best of both worlds.

Whenever you have trouble batching something in PyTorch, it may help. It will probably be a while until it’s included in PyTorch.

By the way, PyTorch also has a JIT!

Since this is the extended version, you did see some results with PyTorch’s JIT. All of those did not seem… promising. The non-JITted PyTorch code was consistently as fast or even faster than the JITted version.

However, I believe that PyTorch’s JIT compiler’s use case is a bit different, leaning more towards optimizing deep learning models for inference. I don’t understand why the super simple linear algebra we tried to JIT did not get optimized (un-optimized, instead!) at all, but I did not dive into PyTorch’s JIT and thus can’t say too much about this.

##### Thanks for Reading!

Thank you for your attention.

I hope you had as much fun as I had preparing.

## 10. Appendix

##### References

- JAX source repository (accessed from 2021-11-18 to 2021-11-26)

##### Extra Recommendations

I recommend this article and corresponding paper on limitations of XLA and TorchScript compilation.

Maybe of interest: there are already JAX libraries for differentiable rigid-body physics simulation and molecular dynamics. You will find other fields covered as well.

## Footnotes:

^{1} Stable since PyTorch 1.9.0, released in June 2021.

^{3} Some of the Autograd developers took part in building JAX upon its ideas.

^{4} Why is the second `a` considered different from the first? The default `__hash__` implementation is based on the object’s `id`, its location in memory. (Both of these are CPython implementation details.)

^{5} That way, the compiler gets the most options for optimization. Also, an already `jit`ted inner function cannot be optimized further.

^{6} There are multiple XLA runtimes.

^{7} C family of languages, to be exact.

^{8} Not quite sure since they only disable `--loop-unroll`.

^{9} The Message Passing Interface (MPI) is a standard in parallel computing.