Scientific Computing with JAX: A Case Study (Transcript)

article
RSE
JAX
HPC
AI-transcribed companion to the Durham HPC Days 2025 talk: porting a gravitational lensing likelihood from Numba to JAX for JWST analysis.
Author

Kolen Cheung

Published

June 4th, 2025

Slides for the talk

Introduction and Motivation

Hello everyone. My name is Kolen Cheung, and I’m a Research Software Engineer at the University of Exeter. Today, I’m going to talk about a project I’ve worked on to adopt JAX for scientific computing, specifically a case study in evaluating gravitational lensing likelihood.

The project’s motivation is to probe the nature of dark matter with the James Webb Space Telescope (JWST). The science case is to investigate the low-mass end of dark matter substructures in our universe, where the well-established Lambda-CDM model and alternative dark matter models make statistically different predictions. Because these low-mass substructures are generally invisible, we need to probe them via strong gravitational lensing. This approach was first demonstrated in a paper two years ago using Hubble Space Telescope (HST) data, so the natural next step is to take advantage of JWST’s high-quality observations across multiple wavebands.

Methodology and Computational Challenges

To do this, we use an existing package called PyAutoLens. It performs mass modeling using parametric, nonlinear, multi-phase models to predict what kind of image can be observed. Another part of the code performs source reconstruction, which reconstructs the original, un-lensed image from the observed lensed image. A log-likelihood function then takes the output of the modeling and computes the likelihood. The key goal of this package is to automate the whole process and apply it to large datasets.

In cosmology, we cannot perform experiments; we have only one observable universe. All we can do is make observations and then use simulations and statistical methods to infer the parameters of physical models. PyAutoLens uses a nested sampler such as dynesty to explore the parameter space and find the best-fit solution. For example, with a 25-parameter mass model, the goal is to find the mass profile that reproduces the observed image. As the sampler iterates, the chi-squared value decreases and the predicted lensed image matches the observation more closely, until the mass model is successfully estimated.
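The chi-squared fit described above can be made concrete with a short JAX sketch. It assumes independent Gaussian noise on each pixel and uses the hypothetical names `log_likelihood`, `model_image`, `data`, and `noise_map`; the actual PyAutoLens likelihood may include further terms (for example from source reconstruction) that are omitted here.

```python
import jax.numpy as jnp


def log_likelihood(model_image, data, noise_map):
    """Gaussian log-likelihood of a model image, assuming independent pixel noise."""
    # Chi-squared: sum of squared, noise-normalized residuals.
    chi_squared = jnp.sum(((data - model_image) / noise_map) ** 2)
    # Normalization term of the Gaussian noise model (constant for fixed noise).
    noise_normalization = jnp.sum(jnp.log(2.0 * jnp.pi * noise_map**2))
    return -0.5 * (chi_squared + noise_normalization)


# Toy usage: a perfect model scores higher than an imperfect one.
data = jnp.ones((3, 3))
noise_map = jnp.ones((3, 3))
perfect_fit = log_likelihood(data, data, noise_map)
worse_fit = log_likelihood(0.5 * data, data, noise_map)
```

A sampler only ever compares these values, so lowering chi-squared is exactly what raises the log-likelihood when the noise map is fixed.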

However, there are computational challenges. PyAutoLens, originally implemented in Numba, takes around 48 hours on 76 CPU cores to process single-band HST data. This involves about 100,000 iterations, with each iteration taking about three seconds. The next step is scaling this to JWST data, which spans multiple frequency bands. These bands provide color information that helps guide the model toward better solutions, but preliminary benchmarks show that the runtime is 20 times longer.

Why JAX?

The question is, why port PyAutoLens to JAX for calculating the log-likelihood function? Two other strong gravitational lensing packages have successfully adopted JAX in their modeling and saw large speedups. Performance gains can come from several sources: JAX code can be inherently faster, it can run on accelerators such as GPUs, and the gradient information JAX provides through automatic differentiation can reduce the number of iterations needed to find the best-fit solution. This project involved porting a subset of the likelihood function from Numba to JAX.
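As a minimal illustration of the gradient point, this sketch applies `jax.grad` to a toy two-parameter chi-squared and runs plain gradient descent. The straight-line model, step size, and iteration count are arbitrary assumptions chosen for the example; a real lens model has far more parameters, and a gradient-based sampler or optimizer would replace the simple loop.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a lens model: a straight line with two parameters.
x = jnp.linspace(0.0, 1.0, 50)
true_params = jnp.array([2.0, -1.0])
data = true_params[0] * x + true_params[1]  # noiseless synthetic "observation"


def chi_squared(params):
    model = params[0] * x + params[1]
    return jnp.sum((data - model) ** 2)  # unit noise assumed


# jax.grad returns a function that computes the exact gradient of chi_squared.
grad_chi_squared = jax.jit(jax.grad(chi_squared))

params = jnp.zeros(2)
step_size = 0.01  # small enough for this toy problem to converge stably
for _ in range(500):
    params = params - step_size * grad_chi_squared(params)
```

Each step uses the exact derivative rather than a finite-difference estimate, which is what lets gradient-based methods move through parameter space in fewer likelihood evaluations.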

Porting and Programming Experience

My methodology for this porting process was:

1. Translate the existing functions back to their mathematical representation.
2. Translate the math into Numba using vectorized (array) programming, anticipating the paradigm JAX requires.
3. Translate the Numba function to JAX; in simple cases this works directly.
4. Further optimize the code based on the situation.
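The first three steps can be sketched on a hypothetical toy function, `out[i] = sum_j exp(-(x[i] - y[j])**2)`. Step 2 is written as plain NumPy so that the example runs without Numba installed; in the workflow described here it would be compiled with Numba. The function names are illustrative, not taken from PyAutoLens.

```python
import numpy as np
import jax
import jax.numpy as jnp


# Step 1 (reference): explicit loops that mirror the math
#   out[i] = sum_j exp(-(x[i] - y[j])**2)
def kernel_sum_loops(x, y):
    out = np.zeros(len(x))
    for i in range(len(x)):
        for j in range(len(y)):
            out[i] += np.exp(-((x[i] - y[j]) ** 2))
    return out


# Step 2: vectorized (array) form. In the workflow above this stage would be
# compiled with Numba; it is left as plain NumPy so the sketch runs anywhere.
def kernel_sum_vectorized(x, y):
    diff = x[:, None] - y[None, :]
    return np.exp(-(diff**2)).sum(axis=1)


# Step 3: near line-for-line translation to jax.numpy, then JIT-compiled.
@jax.jit
def kernel_sum_jax(x, y):
    diff = x[:, None] - y[None, :]
    return jnp.exp(-(diff**2)).sum(axis=1)


rng = np.random.default_rng(0)
x, y = rng.normal(size=100), rng.normal(size=80)
reference = kernel_sum_loops(x, y)
```

The vectorized step 2 carries over to `jax.numpy` almost unchanged, which is why writing it in array style first makes step 3 straightforward.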

I organized the project into three modules: one for the original functions, one for the Numba functions from step two, and one for the JAX functions from step three. I used metaprogramming to generate unit tests with pytest that check every implementation for correctness, and pytest-benchmark to compare their performance; the benchmark results then fed back into the optimization step.

Numba vs. JAX

Numba is a just-in-time (JIT) compiler for a subset of Python and NumPy, powered by LLVM. It can target GPUs via CUDA, but this requires rewriting functions against a different API and is limited to NVIDIA hardware.

JAX is a tracing JIT compiler from Google, powered by the XLA compiler. It’s designed primarily for machine learning but is also suitable for scientific computing. Because it is a tracing compiler, it requires a functional programming style: Python side effects run only while a function is being traced, not on every call, so functions should be pure. It automatically targets multiple hardware architectures (CPU, GPU, TPU) without code rewrites, which promises to solve the so-called “two-language problem” (or, in this case, the “three-implementation problem” of Python API, CPU, and GPU code).
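The effect of tracing on side effects can be seen in a small sketch: a Python list append inside a `jax.jit` function runs only when JAX traces the function, not on every call. `scaled_norm` and `trace_log` are hypothetical names for illustration.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side record of when the function body actually runs


@jax.jit
def scaled_norm(x, scale):
    trace_log.append("traced")  # side effect: happens only during tracing
    return scale * jnp.sqrt(jnp.sum(x**2))


x = jnp.arange(4.0)
first = scaled_norm(x, 2.0)
second = scaled_norm(x + 1.0, 3.0)  # same shapes and dtypes: cached code reused
```

Because both calls use inputs of the same shape and dtype, the second call reuses the compiled code and `trace_log` ends up with a single entry. The same code runs unchanged on whatever backend `jax.devices()` reports.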

Both JAX and Numba are effectively domain-specific languages. Numba is more of a C-like mini-language, while JAX is a smaller language with more restrictions on control flow, mutation, and dynamic shapes. Numba implements a subset of NumPy and SciPy, but documentation of how its implementations behave can be minimal. In contrast, jax.numpy and jax.scipy have their own comprehensive documentation, which makes clear where they deliberately deviate from the NumPy and SciPy reference implementations.
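The restrictions on mutation and control flow look like this in practice. This minimal sketch contrasts NumPy’s in-place assignment with JAX’s functional `.at[...].set(...)` update, and uses `jnp.where` in place of a Python `if` on a traced value; the variable and function names are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy allows in-place mutation.
a_np = np.zeros(5)
a_np[2] = 1.0

# JAX arrays are immutable: item assignment raises a TypeError.
a_jax = jnp.zeros(5)
try:
    a_jax[2] = 1.0
    mutation_allowed = True
except TypeError:
    mutation_allowed = False

# The functional update returns a new array; a_jax itself is unchanged.
b_jax = a_jax.at[2].set(1.0)


# Data-dependent control flow must also be expressed functionally under jit:
# a Python `if x > 0:` on a traced array would raise an error; jnp.where does not.
@jax.jit
def clip_negative(x):
    return jnp.where(x > 0, x, 0.0)
```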

Numba functions recompile whenever an input type changes. JAX does this as well but also recompiles whenever an input shape changes, so code that is called with many different array sizes can spend much of its time recompiling. JAX also provides automatic accelerator offloading and automatic differentiation, but as a consequence, calling external (non-JAX) code from a JIT-compiled function incurs extra costs: data must be transferred to and from the device, and automatic differentiation no longer works through that call.
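Shape-driven recompilation can be observed by counting how often a jitted function body is traced. In this sketch, `total` and `compile_count` are hypothetical names, and the counter increment is a Python side effect that therefore runs only during tracing.

```python
import jax
import jax.numpy as jnp

compile_count = 0


@jax.jit
def total(x):
    global compile_count
    compile_count += 1  # runs once per trace; each new shape is a new trace
    return jnp.sum(x)


for n in (10, 10, 20, 30, 30):
    total(jnp.ones(n))
```

Five calls with three distinct shapes trigger three traces and compilations, while the repeated shapes hit the cache.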

JAX Characteristics and Benchmark Analysis

Because JAX is a tracing compiler that recompiles per shape, any argument that determines an array’s shape must be marked with static_argnums, and every change to that argument triggers a recompile. Framing your problem using JAX’s idiomatic expressions can result in great speedups, sometimes more than you could achieve in Numba. Because optimization is delegated to XLA, compiler updates can also bring performance improvements without any code changes. It’s easy to port functions to a GPU without having one set up locally: you can develop on a CPU and then deploy on a GPU system, and the XLA compiler handles device-specific optimizations automatically.
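A minimal sketch of `static_argnums`, assuming a hypothetical function whose output length is given by an integer argument `n`. Without marking `n` as static, `jnp.arange(n)` would fail under `jax.jit` because `n` would be an abstract tracer rather than a concrete integer.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `n` determines the output shape, so it must be a compile-time constant.
@partial(jax.jit, static_argnums=1)
def first_n_squares(scale, n):
    return scale * jnp.arange(n) ** 2


a = first_n_squares(2.0, 4)  # compiles for n=4
b = first_n_squares(3.0, 4)  # reuses the n=4 compilation
c = first_n_squares(2.0, 6)  # new static value: recompiles for n=6
```

The traced argument `scale` can change freely between calls, whereas each distinct value of `n` costs one compilation.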

Let’s look at some benchmarks. In one example, evaluating a function called W-tilde, the initial vectorized implementation was computationally infeasible, as it required creating an intermediate array of about 700 petabytes. The final result only requires about 40 gigabytes, so the problem is manageable as long as the largest dimension is never expanded in memory. The final Numba solution uses prange for a parallel reduction, similar to OpenMP. JAX lacks prange and in-place mutation, so the idiomatic solution is to use jax.lax.scan to build up the sum iteratively.
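The W-tilde formula itself is not given here, so this sketch uses a hypothetical stand-in with the same structure: a sum over a large index `k` of `(N, M)` terms. The vectorized version builds the full `(K, N, M)` intermediate before reducing, while the `jax.lax.scan` version carries only an `(N, M)` accumulator.

```python
import jax
import jax.numpy as jnp


def outer_sum_vectorized(u, v):
    # Evaluated eagerly, this materializes a (K, N, M) array before reducing.
    return jnp.sum(u[:, :, None] * v[:, None, :], axis=0)


@jax.jit
def outer_sum_scan(u, v):
    # Only the (N, M) accumulator is kept; one (N, M) term is added per step.
    def step(acc, uv):
        u_k, v_k = uv
        return acc + jnp.outer(u_k, v_k), None

    init = jnp.zeros((u.shape[1], v.shape[1]), dtype=u.dtype)
    result, _ = jax.lax.scan(step, init, (u, v))
    return result


key_u, key_v = jax.random.split(jax.random.PRNGKey(0))
u = jax.random.normal(key_u, (64, 8))
v = jax.random.normal(key_v, (64, 5))
```

For this toy summand the result equals `u.T @ v`, which XLA could compute directly as a matrix product; the scan pattern matters when the summand has no such closed form.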

On a single CPU core, the JAX implementation is slightly faster than the Numba version. However, changing the underlying algorithm yields a much larger speedup (a factor of four) even within Numba. This highlights that algorithmic improvements can be more impactful than framework changes.

When comparing Numba on 128 CPU cores to JAX on an A100 GPU, JAX is significantly faster for the original algorithm, as expected. However, for an alternative algorithm, JAX was slower, likely because the problem size was not large enough to saturate the GPU. In another example calculating a curvature matrix, the fastest algorithm on a single CPU core used sparse matrix operations in Numba. On the GPU, however, the dense JAX implementation was the fastest by a large margin.
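To illustrate the dense-versus-sparse trade-off, here is a sketch assuming the common linear-inversion form of the curvature matrix, F = M^T diag(1/sigma^2) M, where M maps image pixels to source pixels and has only a few non-zero entries per row. The layout (`src_index`, `src_weight`) and all function names are hypothetical, and the sparse variant is written in JAX with a scatter-add rather than reproducing the Numba code.

```python
import jax.numpy as jnp
import numpy as np


def dense_from_sparse(src_index, src_weight, n_source):
    """Build the dense (n_image, n_source) mapping matrix M from the sparse form."""
    rows = jnp.arange(src_index.shape[0])[:, None]  # broadcasts against (n_image, k)
    return jnp.zeros((src_index.shape[0], n_source)).at[rows, src_index].add(src_weight)


def curvature_dense(mapping, noise):
    """F = M^T diag(1/sigma^2) M using one dense matrix product."""
    return (mapping / noise[:, None] ** 2).T @ mapping


def curvature_sparse(src_index, src_weight, noise, n_source):
    """Same F, accumulated only over the few non-zero entries of each row of M."""
    w = src_weight / noise[:, None]  # fold 1/sigma into each weight
    contrib = w[:, :, None] * w[:, None, :]  # (n_image, k, k) pairwise products
    rows, cols = jnp.broadcast_arrays(src_index[:, :, None], src_index[:, None, :])
    curvature = jnp.zeros((n_source, n_source), dtype=contrib.dtype)
    # Scatter-add: duplicate (row, col) pairs are summed, as the math requires.
    return curvature.at[rows.ravel(), cols.ravel()].add(contrib.ravel())


rng = np.random.default_rng(0)
n_image, n_source, k = 200, 30, 3
src_index = jnp.asarray(rng.integers(0, n_source, size=(n_image, k)))
src_weight = jnp.asarray(rng.uniform(size=(n_image, k)), dtype=jnp.float32)
noise = jnp.asarray(rng.uniform(0.5, 1.5, size=n_image), dtype=jnp.float32)

F_dense = curvature_dense(dense_from_sparse(src_index, src_weight, n_source), noise)
F_sparse = curvature_sparse(src_index, src_weight, noise, n_source)
```

The dense form maps onto a single large matrix product, which suits GPUs; the sparse form does far fewer floating-point operations but relies on scattered memory updates, which is one plausible reason the fastest variant changes with the hardware.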

Lessons Learned and Limitations

The key takeaway is that improving the algorithm within Numba can often provide significant speedups without needing to switch frameworks. The only like-for-like comparison between Numba and JAX is on a single CPU core, where both run on identical hardware. The best algorithm often depends on the input size and the hardware, so it can be advantageous to maintain multiple implementations and profile them for each specific science case and system.
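A minimal sketch of profiling several implementations on the current system and picking the fastest. The helper names and the two toy candidates are hypothetical; the timing calls `jax.block_until_ready` because JAX dispatches work asynchronously, and the first call is excluded so that JIT compilation is not counted.

```python
import time

import jax
import jax.numpy as jnp


def time_call(func, *args, repeats=5):
    """Best wall-clock time of func(*args), excluding the first (compiling) call."""
    jax.block_until_ready(func(*args))  # warm-up: triggers JIT compilation
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(func(*args))  # wait for asynchronous dispatch
        best = min(best, time.perf_counter() - start)
    return best


def pick_fastest(implementations, *args):
    """Return the name of the fastest implementation and all measured timings."""
    timings = {name: time_call(f, *args) for name, f in implementations.items()}
    return min(timings, key=timings.get), timings


# Toy candidates computing the same result in two different ways.
a = jnp.ones((256, 256))
implementations = {
    "matmul": jax.jit(lambda p, q: p @ q),
    "einsum": jax.jit(lambda p, q: jnp.einsum("ij,jk->ik", p, q)),
}
best_name, timings = pick_fastest(implementations, a, a)
```

Running the selection once per system and input size, and caching the winner, keeps the choice tied to the hardware it was measured on.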

A major limitation of JAX is its multi-threading support on CPUs, which is “survivable” at best. This likely reflects limited interest from the primarily machine-learning-focused community. In an HPC setting where you might use multiple levels of parallelism (e.g., SIMD, multithreading with OpenMP, multiprocessing with MPI), controlling thread counts in JAX is obscure. There are workarounds, but they have downsides. For example, forcing JAX to treat each CPU core as a separate device requires a distributed-memory model with data sharding, and preliminary testing has not been promising for achieving good scaling.
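The workaround of treating each CPU core as a separate device can be sketched with XLA’s `--xla_force_host_platform_device_count` flag, which must be set before JAX initializes its backends, followed by sharding an array across the resulting devices. The device count of 4 is an arbitrary assumption, and this sketch only shows the mechanics; it says nothing about whether the approach scales well.

```python
import os

# Must be set before JAX initializes its backends (i.e. before importing jax).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

cpu_devices = jax.devices("cpu")
mesh = Mesh(np.array(cpu_devices), axis_names=("cores",))
sharding = NamedSharding(mesh, PartitionSpec("cores"))

# Split a 1-D array across the CPU "devices" along its only axis.
x = jax.device_put(jnp.arange(16.0), sharding)


@jax.jit
def total(v):
    return jnp.sum(v**2)  # XLA inserts the cross-device reduction


result = total(x)
```

Every computation now has to be phrased in terms of how data is split across devices, which is the distributed-memory programming model the workaround imposes.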

So, does this solve the three-implementation problem? No. We are left with a Numba implementation that is better for multi-core CPUs and a JAX implementation that is better for accelerators. Even on a single CPU, Numba can sometimes be faster because JAX is a more restrictive language. The best approach appears to be keeping both implementations and targeting different use cases.

Conclusion

In summary, cosmology is amazing, and JAX is a fun, powerful language with a pure functional paradigm that makes it easy to achieve great speedups and deploy on GPUs. However, it is not a silver bullet for all HPC use cases, particularly those relying heavily on multi-core CPU parallelism. Thank you.