This is where JAX hopes to fill those gaps. JAX is a library from Google Research designed to overcome these hurdles, effectively bringing NumPy functionality into the world of modern hardware accelerators and gradient-based optimisation.
Please note that the JAX library is a research project, not an official Google product. As their GitHub repo indicates, expect “sharp edges” and potential bugs.
What is NumPy?
If you’re involved in scientific Python, you are likely familiar with NumPy. It provides the core ndarray object for efficient storage and manipulation of dense numerical arrays and a vast library of mathematical functions optimised (often via C, C++, or Fortran code) to operate on these arrays much faster than pure Python loops.
It’s the bedrock upon which much of the scientific Python ecosystem (SciPy, Pandas, Scikit-learn, Matplotlib) is built. However, its primary design targets CPU execution rather than GPUs and doesn’t inherently support automatic gradient calculation.
That last point is vital because gradients measure how a function’s output changes with its inputs, indicating the direction of steepest increase or decrease. This information is crucial for optimisation algorithms, particularly in machine learning, where gradients guide the adjustment of model parameters to minimise error during training. Calculating gradients efficiently is what enables models, including large language models, to learn complex patterns from data.
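To make that concrete, here is a single gradient-descent step in plain Python (the function f(w) = w² and the learning rate are purely illustrative):
def f(w):
    return w**2  # Toy function to minimise; its derivative is 2w
w = 2.0
learning_rate = 0.1
grad_f = 2 * w  # Gradient at the current point: 4.0
w = w - learning_rate * grad_f  # Step downhill: 2.0 -> 1.6
Repeat this step enough times and w converges to the minimum at 0. JAX automates the "compute the gradient" part for arbitrary numerical functions, as we'll see shortly.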
What is JAX?
JAX is a high-performance numerical computing library for Python developed by Google. It combines a NumPy-like API, automatic differentiation (autodiff), and accelerated execution on Graphics Processing Units or Tensor Processing Units (GPUs/TPUs).
Its main features include:
- jax.numpy — a drop-in replacement for NumPy (same API, but works on GPU/TPU)
- jax.grad — automatic differentiation of functions (like TensorFlow or PyTorch)
- jax.jit — just-in-time compilation via the Accelerated Linear Algebra (XLA) compiler for blazing-fast execution
- jax.vmap / jax.pmap — automatic vectorisation and parallelisation
- GPU/TPU support — runs seamlessly on accelerators without changing your code
Why Use JAX?
You should consider JAX when you need to …
- Run NumPy-like computations significantly faster by leveraging GPUs or TPUs.
- Automatically calculate gradients of your numerical Python functions for optimisation (machine learning, physics simulations, etc.).
- Achieve further speedups by JIT-compiling critical Python code sections into optimised XLA executables.
- Easily vectorise functions to handle batches of data or parallelise computations across multiple accelerator devices.
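Before we get to the setup, here's a minimal taste of how jit and grad combine (the loss function and values are purely illustrative):
import jax
import jax.numpy as jnp
# A toy scalar "loss" function, compiled with XLA
@jax.jit
def loss(w):
    return jnp.sum((w - 3.0) ** 2)
# Differentiate the compiled function
grad_loss = jax.grad(loss)
print(grad_loss(jnp.array([1.0, 2.0]))) # [-4. -2.], i.e. 2 * (w - 3)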
Before deciding to use JAX over NumPy, you should also be aware of some key differences between the two libraries. While jax.numpy mimics the NumPy API, note the following:
1. Execution Backend & Compilation.
NumPy executes eagerly on the CPU, typically using pre-compiled C, C++ or Fortran extensions and optimised linear algebra libraries like OpenBLAS.
JAX uses the XLA compiler to translate JAX code into optimised machine code for CPU, GPU, or TPU. Execution can be Just-In-Time compiled using jax.jit and is often dispatched asynchronously.
2. Execution Model.
NumPy operations generally execute synchronously — the Python interpreter waits for the operation to complete before moving on.
JAX operations are dispatched asynchronously to the accelerator. Python code may continue running while the computation is happening. You often need result.block_until_ready() for accurate timing or to ensure a result is available before using it elsewhere (e.g., printing). jax.jit adds a compilation step on the first call.
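A minimal sketch of what that looks like in practice:
import jax.numpy as jnp
x = jnp.ones((1000, 1000))
y = jnp.dot(x, x)      # Returns almost immediately; the work is queued on the device
y.block_until_ready()  # Blocks until the result has actually been computed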
3. Mutability.
NumPy arrays (ndarray) are mutable. You can change elements in place (e.g., a[0] = 100).
JAX arrays are immutable, and in-place updates are not allowed. This functional approach is crucial for JAX’s transformations to work reliably without side effects. Updates require creating new arrays using indexed update syntax.
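For example, where NumPy would allow a[0] = 100, JAX uses the .at indexed-update syntax, which returns a new array:
import jax.numpy as jnp
a = jnp.arange(5)
# a[0] = 100          # TypeError: JAX arrays are immutable
b = a.at[0].set(100)  # Returns a NEW array; a is unchanged
print(a)  # [0 1 2 3 4]
print(b)  # [100 1 2 3 4]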
4. Random Number Generation.
NumPy uses a global random number generator state (np.random.seed(), np.random.rand()). This can be problematic for reproducibility in parallel or transformed code.
JAX requires explicit handling of random keys. You must manually manage and split keys to ensure reproducible randomness.
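A minimal sketch of the key-splitting pattern:
import jax
key = jax.random.PRNGKey(42)         # An explicit seed produces a key
key, subkey = jax.random.split(key)  # Split off a fresh subkey for each draw
x = jax.random.normal(subkey, (3,))  # The same subkey always yields the same numbers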
5. API Coverage.
NumPy has a comprehensive API that covers many areas of numerical computing.
JAX covers a large and growing subset of the most common NumPy API, but is not a 100% drop-in replacement. Some less common functions, certain data types (like object arrays), or specific behaviours might differ or be missing.
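One concrete behavioural difference is out-of-bounds indexing. NumPy raises an IndexError, whereas JAX (which cannot raise data-dependent errors from compiled device code) clamps the index:
import numpy as np
import jax.numpy as jnp
a_np = np.arange(3)
a_jax = jnp.arange(3)
# a_np[5]        # Raises IndexError in NumPy
print(a_jax[5])  # JAX clamps the out-of-bounds index and returns the last element: 2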
Pre-requisites
OK, let’s get started with some coding examples. I’m using Ubuntu under WSL2 on Windows for development.
I’m lucky to have an NVIDIA GPU on my system, so I’ll target that in my code. If you don’t have a GPU, you can still enjoy accelerated code over NumPy on your CPU. If that’s you, or you have a GPU from a different vendor, check out the official documentation (link at the end of this article) for instructions on installing JAX for your system.
The first step is to create our environment. I use conda for this, but feel free to use whatever method you’re most comfortable with.
(base) $ conda create -n jax_test python=3.13 -y
Note that since I’m installing for an NVIDIA GPU, the appropriate NVIDIA drivers and a CUDA 12 environment (matching the jax[cuda12] package we install below) must also be in place. I won’t go into detail here, but a comprehensive explanation is available on the NVIDIA website.
Next, activate your new environment and install the required external libraries.
(base) $ conda activate jax_test
(jax_test) $ pip install jupyter numpy "jax[cuda12]" matplotlib pillow
Installing JAX took a while on my system, but it does finish eventually, after which you can start Jupyter:
(jax_test) $ jupyter notebook
You should see a notebook open in your browser. If that doesn’t happen automatically, you’ll see a screenful of information after the jupyter notebook command; near the bottom is a URL to copy and paste into your browser to launch the Jupyter Notebook.
Your URL will be different to mine, but it should look something like this:
http://127.0.0.1:8888/tree?token=3b9f7bd07b6966b41b68e2350721b2d0b6f388d248cc69d
Code example 1 — A Familiar API & JIT Compilation
Our first example shows how we can enhance performance over NumPy through just-in-time compilation. This example illustrates the SELU (Scaled Exponential Linear Unit) function applied to a 10,000 x 10,000 array. SELU is a popular activation function for self-normalising neural networks, defined as:
selu(x) = scale * x                      if x > 0
selu(x) = scale * alpha * (exp(x) - 1)   if x <= 0
This is implemented using np.where (NumPy) or jnp.where (JAX), which chooses different formulas for positive and negative values of x.
The code showcases three implementations:
- selu_numpy(x) — regular NumPy version
- selu_jax(x) — JAX version (same code, but using JAX arrays)
- selu_jax_jit(x) — same as above, but wrapped in @jax.jit to compile and speed up the function
The code generates a large dataset: a 10,000 x 10,000 array of random numbers. Then it times how long each implementation takes to run:
- selu_numpy: plain NumPy — runs directly on the CPU
- selu_jax: JAX without JIT — each operation is dispatched to the accelerator separately
- selu_jax_jit: JAX with JIT — compiles the function on the first call, then runs very fast on subsequent calls
Note:
- block_until_ready() is used in JAX to wait for the operation to finish, since JAX dispatches work asynchronously.
- During the first JIT run, there is extra time due to compilation.
- On the second JIT run, the compiled function is reused, making it significantly faster.
import numpy as np
import jax
import jax.numpy as jnp
from timeit import default_timer as timer
# Define constants for SELU
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
# --- NumPy version ---
def selu_numpy(x):
    return scale * np.where(x > 0, x, alpha * np.exp(x) - alpha)
# --- JAX version ---
def selu_jax(x):
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
# --- JIT-compiled JAX version ---
# Apply the @jax.jit decorator
@jax.jit
def selu_jax_jit(x):
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
# Generate some data using JAX's explicit random number generation (requires a key)
key = jax.random.PRNGKey(0)
x_jax = jax.random.normal(key, (10000, 10000), dtype=jnp.float32)
x_np = np.array(x_jax)  # Same data for the NumPy benchmark, for a fair comparison
print("Running benchmarks...")
# --- Benchmarking ---
# NumPy
start = timer()
result_np = selu_numpy(x_np)
# No need to block for NumPy as it's synchronous on CPU
print(f"NumPy time: {timer()-start:.6f} seconds")
# JAX (without JIT) - first run might have slight overhead
start = timer()
result_jax = selu_jax(x_jax)
result_jax.block_until_ready() # IMPORTANT: Wait for JAX computation to finish
print(f"JAX (no jit) time: {timer()-start:.6f} seconds")
# JAX (with JIT) - First run (includes compilation time)
start = timer()
result_jax_jit = selu_jax_jit(x_jax)
result_jax_jit.block_until_ready()
print(f"JAX (jit) first run time (incl. compile): {timer()-start:.6f} seconds")
# JAX (with JIT) - Second run (uses cached compiled code)
start = timer()
result_jax_jit_2 = selu_jax_jit(x_jax)
result_jax_jit_2.block_until_ready()
print(f"JAX (jit) second run time: {timer()-start:.6f} seconds")
# Verify results are close
print(np.allclose(result_np, result_jax_jit_2, atol=1e-6)) # Should be True
And the output.
Running benchmarks...
NumPy time: 0.357104 seconds
JAX (no jit) time: 0.108734 seconds
JAX (jit) first run time (incl. compile): 0.026956 seconds
JAX (jit) second run time: 0.002400 seconds
True
If my maths is good, I make that second JIT run roughly 150x faster than the NumPy run (0.357 s vs 0.0024 s)! Even the no-JIT run was over 3x faster than NumPy, though bear in mind that JAX is running on the GPU here while NumPy is confined to the CPU.
Example 2: Automatic Differentiation (jax.grad)
This example shows how easily JAX computes gradients of ordinary numerical Python functions.
import jax
import jax.numpy as jnp
# Define the function using jax.numpy
def cubic_sum(x):
    return jnp.sum(x**3)
# Get the gradient function using jax.grad
grad_cubic_sum = jax.grad(cubic_sum)
# Create some input data
x_input = jnp.arange(1.0, 5.0) # JAX array
# Calculate the gradient
gradient = grad_cubic_sum(x_input)
print(f"\n--- Autodiff Example ---")
print(f"Original function input: {x_input}")
print(f"Function output f(x): {cubic_sum(x_input)}")
print(f"Gradient df/dx: {gradient}") # Should be [ 3. 12. 27. 48.]
--- Autodiff Example ---
Original function input: [1. 2. 3. 4.]
Function output f(x): 100.0
Gradient df/dx: [ 3. 12. 27. 48.]
Expected gradient: [ 3. 12. 27. 48.]
The function being differentiated is the sum of x³ over the input elements, so the partial derivative with respect to each element is 3x². For each input value, i.e., 1, 2, 3, 4, we evaluate that derivative:
3 x 1² = 3
3 x 2² = 12
3 x 3² = 27
3 x 4² = 48
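One related point: jax.grad requires the differentiated function to return a scalar, which is why cubic_sum wraps the cubes in jnp.sum. If you want the function’s value and its gradient together, jax.value_and_grad computes both in a single call (a minimal sketch reusing cubic_sum and x_input from above):
value, gradient = jax.value_and_grad(cubic_sum)(x_input)
print(value)    # 100.0
print(gradient) # [ 3. 12. 27. 48.]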
Example 3: Matrix multiplication using jax.vmap
This highlights the performance difference in vectorised operations as we multiply a 10,000 x 10,000 matrix by a batch of 128 vectors, each with 10,000 elements.
import numpy as np
import jax
import jax.numpy as jnp
from timeit import default_timer as timer
# --- Function for a single data point ---
# Multiply a matrix (M) by a vector (v)
def mat_vec_product(matrix, vector):
    return jnp.dot(matrix, vector)
# --- Create batched version using vmap ---
# We want to apply mat_vec_product to each vector in a batch.
# The matrix stays the same for all vectors in the batch.
# in_axes=(None, 0) means:
# None: Don't map over the first argument (matrix), broadcast it.
# 0: Map over the first axis (axis 0) of the second argument (the batch of vectors).
batched_mat_vec = jax.vmap(mat_vec_product, in_axes=(None, 0))
# --- JIT compile the vmapped function for performance ---
@jax.jit
def batched_mat_vec_jit(matrix, vectors):
    # Note: vmap is often combined with jit
    return jax.vmap(mat_vec_product, in_axes=(None, 0))(matrix, vectors)
# --- Setup Data ---
matrix_size = 10000
vector_size = 10000
batch_size = 128
dtype = jnp.float32
key = jax.random.PRNGKey(0)
key, subkey1, subkey2 = jax.random.split(key, 3)
# JAX data
matrix_jax = jax.random.normal(subkey1, (matrix_size, vector_size), dtype=dtype)
vectors_jax = jax.random.normal(subkey2, (batch_size, vector_size), dtype=dtype) # Batch first
# NumPy data
matrix_np = np.array(matrix_jax)
vectors_np = np.array(vectors_jax)
print(f"\n--- vmap Benchmark (Matrix: {matrix_size}x{vector_size}, Batch Size: {batch_size}) ---")
print(f"JAX devices available: {jax.devices()}")
# --- Benchmarking ---
# NumPy Approach 1: Python Loop (Illustrative, usually slow)
start_np_loop = timer()
output_np_loop = np.array([np.dot(matrix_np, v) for v in vectors_np])
end_np_loop = timer()
print(f"NumPy (Python loop) time: {end_np_loop - start_np_loop:.6f} seconds")
# NumPy Approach 2: Matmul with transpose (the efficient, vectorised way)
start_np_matmul = timer()
# Need vectors_np to be (vector_size, batch_size) for matmul
output_np_matmul = (matrix_np @ vectors_np.T).T
end_np_matmul = timer()
print(f"NumPy (matmul @) time: {end_np_matmul - start_np_matmul:.6f} seconds")
# JAX vmap (no jit)
start_jax_vmap = timer()
output_jax_vmap = batched_mat_vec(matrix_jax, vectors_jax)
output_jax_vmap.block_until_ready()
end_jax_vmap = timer()
print(f"JAX (vmap, no jit) time: {end_jax_vmap - start_jax_vmap:.6f} seconds")
# JAX vmap (jit) - First run (compilation)
start_jax_vmap_jit_compile = timer()
output_jax_vmap_jit_compile = batched_mat_vec_jit(matrix_jax, vectors_jax)
output_jax_vmap_jit_compile.block_until_ready()
end_jax_vmap_jit_compile = timer()
print(f"JAX (vmap+jit) first run (incl. compile): {end_jax_vmap_jit_compile - start_jax_vmap_jit_compile:.6f} seconds")
# JAX vmap (jit) - Second run
start_jax_vmap_jit = timer()
output_jax_vmap_jit = batched_mat_vec_jit(matrix_jax, vectors_jax)
output_jax_vmap_jit.block_until_ready()
end_jax_vmap_jit = timer()
print(f"JAX (vmap+jit) second run time: {end_jax_vmap_jit - start_jax_vmap_jit:.6f} seconds")
Here is the output.
--- vmap Benchmark (Matrix: 10000x10000, Batch Size: 128) ---
JAX devices available: [CudaDevice(id=0)]
NumPy (Python loop) time: 1.129315 seconds
NumPy (matmul @) time: 0.029319 seconds
JAX (vmap, no jit) time: 0.901569 seconds
JAX (vmap+jit) first run (incl. compile): 0.539354 seconds
JAX (vmap+jit) second run time: 0.001776 seconds
As you can see, the un-jitted vmap call is actually slower than NumPy’s vectorised matmul here, and the first JIT run pays the one-off compilation cost, but the speedup after that is impressive: the second JIT run is roughly 16x faster than the NumPy matmul.
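Incidentally, for this particular computation you don’t strictly need vmap at all: a single matmul expresses the whole batch, and vmap really earns its keep when the per-example function is too awkward to batch by hand. As a quick sanity check (reusing the arrays from the benchmark above), the two approaches should agree up to float32 rounding:
output_matmul = jnp.matmul(vectors_jax, matrix_jax.T)  # Whole batch as one matmul, shape (128, 10000)
print(jnp.allclose(output_matmul, output_jax_vmap_jit, rtol=1e-3, atol=1e-3))  # True (tolerances loosened for float32 accumulation)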
Example 4: Image Convolution (Gaussian Blur)
For a more practical example, let’s examine convolution. Convolution is a fundamental operation in image processing, used for tasks such as blurring, sharpening, and edge detection. It involves sliding a small matrix (kernel) over the image and computing a weighted sum of the pixels under the kernel at each position. We’ll implement a basic version of Gaussian blurring using array slicing and element-wise operations to see how jax.jit can optimize this sequence.
Here is my input image.

In simple terms, the code does the following:
- Converts the colour input image to greyscale, loads it into a NumPy array, and normalises it.
- Defines a 2D Gaussian kernel of configurable size and standard deviation.
- Implements a manual 2D convolution function in three ways:
  - Pure NumPy on the CPU
  - JAX array operations (without JIT)
  - JAX with @jax.jit for optimised compilation (with one warm-up run and one timed run)
- Benchmarks each variant, printing timings and checking that the blurred outputs match within float tolerances.
- Displays the original (greyscale) and blurred images side by side using Matplotlib, allowing you to confirm the blur visually.
The manual convolution uses a trick worth noting: rather than looping over every output pixel, it loops over the (much smaller) kernel, adding one shifted, weighted copy of the entire image per kernel element, i.e. 81 vectorised operations for a 9x9 kernel instead of a per-pixel Python loop.
import numpy as np
import jax
import jax.numpy as jnp
from timeit import default_timer as timer
from PIL import Image # Import Pillow
import matplotlib.pyplot as plt
import os # To check if file exists
# --- Configuration ---
image_path = "/mnt/d/images/taj_mahal.png"
kernel_size = 9 # Increased kernel size slightly for more visible blur
sigma = 2.5
dtype = jnp.float32
# --- Check if image file exists ---
if not os.path.exists(image_path):
    print(f"ERROR: Image file not found at '{image_path}'")
    print("Please update the 'image_path' variable in the script.")
    exit() # Stop execution if image is not found
# --- Load and Prepare Image ---
print(f"Loading image from: {image_path}")
try:
    # Open image, convert to grayscale ('L'), then to NumPy array
    with Image.open(image_path) as img:
        image_np_uint8 = np.array(img.convert('L')) # Convert to grayscale uint8
    # Normalize to float32 between 0.0 and 1.0
    image_np = image_np_uint8.astype(np.float32) / 255.0
    image_jax = jnp.array(image_np) # Convert to JAX array
    image_size_h, image_size_w = image_np.shape
    print(f"Image loaded successfully ({image_size_h}x{image_size_w})")
except Exception as e:
    print(f"ERROR: Failed to load or process image '{image_path}'. Error: {e}")
    exit()
# --- Define a simple Gaussian kernel ---
def gaussian_kernel(size, sigma=1.0):
    """Creates a 2D Gaussian kernel using JAX."""
    ax = jnp.arange(-size // 2 + 1., size // 2 + 1.)
    xx, yy = jnp.meshgrid(ax, ax)
    kernel = jnp.exp(-(xx**2 + yy**2) / (2. * sigma**2))
    return (kernel / jnp.sum(kernel)).astype(dtype) # Normalize and set dtype
# --- Convolution implementation using basic array ops ---
def convolve_2d_manual(image, kernel):
    im_h, im_w = image.shape
    ker_h, ker_w = kernel.shape
    pad_h, pad_w = ker_h // 2, ker_w // 2
    padded_image = jnp.pad(image, ((pad_h, pad_h), (pad_w, pad_w)), mode='edge')
    output = jnp.zeros_like(image)
    for i in range(ker_h):
        for j in range(ker_w):
            image_slice = jax.lax.dynamic_slice(padded_image, (i, j), (im_h, im_w)) # Use dynamic_slice for JIT
            output += kernel[i, j] * image_slice
    return output
# --- JIT-compiled version ---
@jax.jit
def convolve_2d_manual_jit(image, kernel):
    # Identical logic, but JIT will trace and optimise it
    im_h, im_w = image.shape
    ker_h, ker_w = kernel.shape
    pad_h, pad_w = ker_h // 2, ker_w // 2
    padded_image = jnp.pad(image, ((pad_h, pad_h), (pad_w, pad_w)), mode='edge')
    output = jnp.zeros_like(image)
    # The Python loops are unrolled at trace time into one fused XLA computation
    for i in range(ker_h):
        for j in range(ker_w):
            # jax.lax.dynamic_slice keeps the slicing JIT-friendly
            image_slice = jax.lax.dynamic_slice(padded_image, (i, j), (im_h, im_w))
            output += kernel[i, j] * image_slice
    return output
# --- NumPy equivalent for comparison ---
def convolve_2d_manual_np(image, kernel):
    im_h, im_w = image.shape
    ker_h, ker_w = kernel.shape
    pad_h, pad_w = ker_h // 2, ker_w // 2
    padded_image = np.pad(image, ((pad_h, pad_h), (pad_w, pad_w)), mode='edge')
    output = np.zeros_like(image)
    for i in range(ker_h):
        for j in range(ker_w):
            image_slice = padded_image[i:i + im_h, j:j + im_w]
            output += kernel[i, j] * image_slice
    return output
# --- Setup Kernel ---
kernel_jax = gaussian_kernel(kernel_size, sigma=sigma)
kernel_np = np.array(kernel_jax) # Copy for NumPy
print(f"\n--- Convolution Benchmark (Image: {image_size_h}x{image_size_w}, Kernel: {kernel_size}x{kernel_size}) ---")
print(f"JAX devices available: {jax.devices()}")
# --- Benchmarking ---
# NumPy (CPU)
start_np = timer()
output_np = convolve_2d_manual_np(image_np, kernel_np)
end_np = timer()
print(f"NumPy (manual conv) time: {end_np - start_np:.6f} seconds")
# JAX (no jit) - Run once to warm up if needed
start_jax = timer()
output_jax = convolve_2d_manual(image_jax, kernel_jax)
output_jax.block_until_ready()
end_jax = timer()
print(f"JAX (no jit, manual conv) time: {end_jax - start_jax:.6f} seconds")
# JAX (jit) - First run (compilation)
start_jax_compile = timer()
output_jax_jit_compile = convolve_2d_manual_jit(image_jax, kernel_jax)
output_jax_jit_compile.block_until_ready()
end_jax_compile = timer()
print(f"JAX (jit, manual conv) first run (incl. compile): {end_jax_compile - start_jax_compile:.6f} seconds")
# JAX (jit) - Second run
start_jax_jit = timer()
output_jax_jit = convolve_2d_manual_jit(image_jax, kernel_jax)
output_jax_jit.block_until_ready()
end_jax_jit = timer()
print(f"JAX (jit, manual conv) second run time: {end_jax_jit - start_jax_jit:.6f} seconds")
# Verify results (expect some difference due to float32 accumulation)
max_diff_conv = np.max(np.abs(output_np - output_jax_jit))
print(f"Convolution max absolute difference: {max_diff_conv:.6f}")
# Use appropriate tolerances for float32 convolution
print(f"Convolution results close (atol=1e-3, rtol=1e-3): {np.allclose(output_np, output_jax_jit, atol=1e-3, rtol=1e-3)}")
# --- Visualization ---
print("\n--- Visualizing Input and Output ---")
fig, axes = plt.subplots(1, 2, figsize=(12, 6)) # Adjusted figure size
# Display the original grayscale image
axes[0].imshow(image_np, cmap='gray', vmin=0, vmax=1) # Use the loaded numpy image
axes[0].set_title('Original Grayscale Image (Input)')
axes[0].axis('off')
# Display the blurred output image (using the NumPy result for consistency in visualization)
axes[1].imshow(output_np, cmap='gray', vmin=0, vmax=1)
axes[1].set_title(f'Blurred Image (Output, Kernel Size={kernel_size})')
axes[1].axis('off')
plt.tight_layout()
plt.show()
Let’s see the outputs. Note that the input image, which was in colour, is converted to greyscale by the code (the img.convert('L') call), so the convolution only has to deal with a single channel.
Loading image from: /mnt/d/images/taj_mahal.png
Image loaded successfully (473x716)
--- Convolution Benchmark (Image: 473x716, Kernel: 9x9) ---
JAX devices available: [CudaDevice(id=0)]
NumPy (manual conv) time: 0.025815 seconds
JAX (no jit, manual conv) time: 0.234791 seconds
JAX (jit, manual conv) first run (incl. compile): 0.366345 seconds
JAX (jit, manual conv) second run time: 0.000238 seconds
Convolution max absolute difference: 0.000000
Convolution results close (atol=1e-3, rtol=1e-3): True
Again, we see a 100x speedup on the second JIT run. Amazing.
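As an aside, the manual slice-and-add loop was written to showcase jit on ordinary array code; in practice you would reach for a library routine. JAX provides one in jax.scipy.signal (a sketch, and note that its border handling differs from the 'edge' padding used above):
from jax.scipy.signal import convolve2d
# mode='same' keeps the output the same size as the input
blurred = convolve2d(image_jax, kernel_jax, mode='same')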
And here are the output images.

Summary
JAX represents a significant evolution for high-performance numerical computing in Python. By providing a NumPy-like interface combined with powerful function transformations (grad, jit, vmap and others) and efficient execution on accelerators via accelerated linear algebra (XLA), it unlocks capabilities essential for modern machine learning and large-scale scientific computation.
Its dramatic performance gains and built-in automatic differentiation could make it an indispensable tool for researchers and engineers pushing the boundaries of computational science. If your NumPy code encounters performance bottlenecks or requires gradients, JAX may offer a compelling path forward.
As I said at the beginning, this is an experimental project from Google. Will they crack on and make it an official Google product? Who knows? But if they do, NumPy might start having to look over its shoulder at this new kid on the block. However, there is already a significant amount of NumPy code in existence, so it may take JAX many years to supplant it, even if it does gain widespread adoption. And you can be sure that, in the meantime, the NumPy development team will not be sitting on their hands.
Here is a link to the official JAX documentation.