JIT compilers for scientific computing in Python: Numba vs. JAX
Python is widely used in the scientific community, but pure Python is often too slow for High-Performance Computing (HPC). For researchers processing immense amounts of data, such as cosmological inference on the Cosmic Microwave Background (CMB) using up to a petabyte of data on supercomputers at centers like NERSC, performance is critical.
To bridge the gap between Python’s usability and the raw speed required for massive calculations, Just-In-Time (JIT) compilers are essential. This article explores the necessity of Python in HPC, the power of JIT compilers, and a real-world case study comparing Numba and JAX for a gravitational lensing likelihood calculation.
Why Python for High-Performance Computing?
At first glance, Python might seem like a counterintuitive choice for state-of-the-art supercomputing. However, it thrives in HPC environments for several key reasons:
The Community Factor: The scientific community heavily relies on the Python ecosystem, including libraries like NumPy, SciPy, and Astropy.
The Ultimate “Glue” Language: Python excels at wrapping and orchestrating code written in fundamentally faster languages.
Language Agnosticism: Researchers often do not care what language an algorithm is written in (whether C, C++, Fortran, or Julia) as long as it can be accessed easily via a Python interface.
The Power of JIT: Solving the Two-Language Problem
Traditionally, accelerating Python required rewriting critical numerical code in a lower-level language like C or C++ and exposing it via APIs (e.g., using PyBind11). This creates a “two-language problem”. Translating a simple, readable NumPy function into C++ can demand hundreds of lines of boilerplate code for only modest performance gains—for example, achieving a mere 30% speedup after adding over 200 lines of code.
JIT compilers like Numba and JAX solve this by allowing researchers to compile a subset of numerical Python code “just-in-time”. By adding a simple @jit decorator, code can run at near-native speed, sometimes achieving an immediate 3x speedup without leaving Python.
Case Study: The Memory Explosion Problem
To understand how Numba and JAX differ, consider the implementation of the following equation, which is used to calculate weights for a likelihood function:
\[\tilde{w}_{ij} = \sum_{k=1}^K \frac{1}{n_k^2} \cos \left( 2 \pi \left[ \left(\vec{g}_i - \vec{g}_j \right) \cdot \vec{u}_k \right] \right) ,\quad 1 \leq i, j \leq M\]
In this scenario, the number of image pixels (\(M\)) can be up to 70,000, and the number of visibilities (\(K\)) can reach \(10^7\).
If implemented using standard, fully vectorized NumPy, the intermediate arrays generated across the \(K\) dimension would require hundreds of petabytes of memory, far exceeding the capacity of even the largest supercomputers.
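The estimate can be checked with a few lines of arithmetic. The sketch below assumes 8-byte float64 values and a single intermediate array of shape (M, M, K), which is what a fully vectorized evaluation of the sum would materialize. The variable names are illustrative.

```python
# Upper bounds quoted in the text.
M = 70_000        # image pixels
K = 10**7         # visibilities
BYTES_PER_FLOAT64 = 8  # assumption: double precision throughout

# Fully vectorized: one (M, M, K) intermediate array before summing over k.
vectorized_bytes = M * M * K * BYTES_PER_FLOAT64
vectorized_petabytes = vectorized_bytes / 1e15

# Looping over k instead: only the (M, M) output has to be held.
output_bytes = M * M * BYTES_PER_FLOAT64
output_gigabytes = output_bytes / 1e9

print(f"vectorized intermediate: {vectorized_petabytes:.0f} PB")
print(f"(M, M) output only:      {output_gigabytes:.1f} GB")
```

Under these assumptions the intermediate array comes to roughly 392 PB, while the output matrix alone needs about 39 GB, which is why accumulating over \(K\) is the natural way to bound memory.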
The Numba Solution
To solve this, the memory footprint must be reduced by looping over the \(K\) dimension rather than expanding it. Numba excels here. Because it supports a C-like programming paradigm, you can write an explicit loop and use numba.prange to parallelize the in-place reduction sum efficiently.
The JAX Solution
Migrating this same loop to JAX requires a conceptual shift. JAX enforces a pure functional programming paradigm and prohibits mutating state or updating arrays in place. To achieve the same low-memory iterative evaluation in JAX, the logic must be refactored using idiomatic functional constructs, such as jax.lax.scan, which iterates over the \(K\) dimension and carries the running sum from step to step without mutating existing arrays.
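A possible functional version of the accumulation is sketched below, again under the assumptions that the vectors are two-dimensional and that the names (`weights_jax`, `step`) are illustrative rather than taken from the project. `jax.lax.scan` threads the running (M, M) sum through as the carry and consumes one visibility per step.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def weights_jax(g, u, n):
    # Pairwise pixel separations, shape (M, M, 2); computed once.
    dg = g[:, None, :] - g[None, :, :]

    def step(w, visibility):
        uk, nk = visibility  # one row of u, shape (2,), and one noise value
        # Return a new carry instead of mutating w in place.
        w = w + jnp.cos(2.0 * jnp.pi * (dg @ uk)) / nk**2
        return w, None       # no per-step output is stacked

    w0 = jnp.zeros((g.shape[0], g.shape[0]), dtype=g.dtype)
    w, _ = jax.lax.scan(step, w0, (u, n))
    return w


# Tiny hypothetical inputs, only to show the call.
rng = np.random.default_rng(0)
g_demo = rng.uniform(-1.0, 1.0, (5, 2))
u_demo = rng.uniform(-1.0, 1.0, (7, 2))
n_demo = rng.uniform(1.0, 2.0, 7)
w_demo = weights_jax(g_demo, u_demo, n_demo)
```

Only the carry and one step's worth of work are live at a time, so the (M, M, K) intermediate is never built. Note that JAX defaults to 32-bit floats unless 64-bit mode is enabled.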
Numba vs. JAX: A High-Level Comparison
While both tools compile Python to machine code, their architectures and design philosophies differ significantly.
| Feature | Numba | JAX |
|---|---|---|
| Compiler Backend | Powered by LLVM. | Powered by XLA (Accelerated Linear Algebra). |
| Hardware Targets | Primarily CPU-focused (CUDA interface is separate). | Targets CPU, GPU, and TPU from the same code. |
| Programming Paradigm | C-like mini-language allowing loops and mutation. | Functional paradigm enforcing no side effects or mutated states. |
| Recompilation Triggers | Recompiles when input types change. | Recompiles when input types or shapes change. |
| Secret Weapons | Acts as a drop-in replacement for a subset of NumPy operations. | Features Automatic Differentiation (jax.grad) for complex optimization. |
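The shape-dependent recompilation listed for JAX in the table can be observed with a small experiment. The sketch below relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so a counter increments once per new compilation. The names `trace_count` and `scaled_sum` are illustrative.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # executed only during tracing, not on cached calls
    return jnp.sum(2.0 * x)


scaled_sum(jnp.ones(3))   # first call: traced and compiled
scaled_sum(jnp.zeros(3))  # same shape and dtype: cached version is reused
scaled_sum(jnp.ones(4))   # new shape: traced and compiled again

print(trace_count)
```

In practice this means workloads with many distinct array shapes may spend noticeable time compiling, and padding inputs to a fixed shape is one common way to avoid it.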
Real-World Impact: PyAutoLens
The differences between Numba and JAX are highlighted in PyAutoLens, a complex astrophysics pipeline used for analyzing strong gravitational lensing data.
The goal of this pipeline is Maximum Likelihood Estimation (MLE): finding the exact peak of a likelihood function across 25 complex, free parameters (encompassing lens light, lens mass, and source light) to accurately model observed data.
JAX provides a massive advantage for this specific workflow through Automatic Differentiation (autodiff). Because the JAX compiler transforms the code into a computational graph, it can also automatically transform the function to compute its exact gradient. These gradients point optimization algorithms directly toward the peak of the likelihood function, vastly accelerating the fitting process.
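To illustrate the idea on a much smaller problem, the sketch below fits a two-parameter straight line by following the gradient that `jax.grad` derives from a Gaussian log-likelihood. The model, data, step size, and iteration count are all assumptions chosen for the example. They stand in for the 25-parameter lens model and are not the optimizer PyAutoLens uses.

```python
import jax
import jax.numpy as jnp


def log_likelihood(params, x, y, sigma):
    # Toy model: a straight line with Gaussian noise of width sigma.
    model = params[0] * x + params[1]
    return -0.5 * jnp.sum(((y - model) / sigma) ** 2)


# jax.grad differentiates with respect to the first argument (params) by default.
grad_fn = jax.jit(jax.grad(log_likelihood))

# Hypothetical noiseless data generated from slope 2 and intercept 1.
x = jnp.linspace(0.0, 1.0, 50)
y = 2.0 * x + 1.0

params = jnp.array([0.0, 0.0])
for _ in range(500):
    # Plain gradient ascent: step uphill on the likelihood surface.
    params = params + 0.01 * grad_fn(params, x, y, 1.0)
```

After the loop, `params` should sit close to the true slope and intercept. No derivative was written by hand. The same pattern extends to higher-dimensional likelihoods.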
Ultimately, migrating the PyAutoLens project from Numba to JAX was highly successful, resulting in a 50x speedup in execution. While both JIT compilers are valuable tools for the scientific Python community, JAX’s ability to target GPUs from the same code and automatically compute gradients makes it an especially strong engine for modern scientific machine learning and cosmological research.