Differentiable Quantum Simulation at TPU Scale: Inside the JAX Quantum Research Suite
Classical simulation of quantum systems is the bedrock of quantum algorithm design. However, as researchers push the limits of the Noisy Intermediate-Scale Quantum (NISQ) era, they face two major roadblocks: memory scaling limits for large qubit counts, and the gradient bottleneck in training variational quantum algorithms.
To address these challenges, we introduce the JAX Quantum Research Suite, a high-performance, differentiable quantum circuit simulator designed to run on everything from entry-level consumer GPUs to distributed Cloud TPU VM clusters.
This work and its large-scale benchmarks are supported by the Google TPU Research Cloud (TRC) Program.
The Core Philosophy: Why JAX & Cloud TPUs?
Traditional simulators like Google's qsim or NVIDIA's cuQuantum are highly optimized in C++/CUDA but lack native, end-to-end differentiability. On the other hand, frameworks like PennyLane often rely on the Parameter-Shift Rule (PSR) to compute gradients. For a circuit with $P$ parameters, PSR requires $2P$ distinct circuit evaluations, creating an $O(P)$ bottleneck that makes training deep variational circuits extremely slow.
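To make the $2P$ cost concrete, here is a minimal pure-JAX sketch that computes the same gradient both ways. It does not use the suite's API. The circuit is a hypothetical 3-qubit example (RY on every qubit, a CNOT chain, then $\langle Z \rangle$ on the last qubit), and the helper names `apply_1q`, `apply_cnot`, and `expval_z` are illustrative assumptions. The parameter-shift loop calls the circuit twice per parameter, while `jax.grad` gets every component from one forward and one backward pass.

```python
import jax
import jax.numpy as jnp

N = 3  # number of qubits in this toy circuit (illustrative)

def ry(theta):
    # RY(theta) = exp(-i * theta * Y / 2) as a 2x2 complex matrix
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.stack([jnp.stack([c, -s]), jnp.stack([s, c])]).astype(jnp.complex64)

def apply_1q(state, gate, q):
    # View the statevector as an N-index tensor and contract the gate into axis q
    psi = jnp.tensordot(gate, state.reshape((2,) * N), axes=([1], [q]))
    return jnp.moveaxis(psi, 0, q).reshape(-1)

def apply_cnot(state, control, target):
    # Flip the target axis only where the control index equals 1
    psi = state.reshape((2,) * N)
    ctrl = jnp.arange(2).reshape([2 if i == control else 1 for i in range(N)])
    psi = jnp.where(ctrl == 1, jnp.flip(psi, axis=target), psi)
    return psi.reshape(-1)

def expval_z(state, q):
    # <Z_q> = P(qubit q = 0) - P(qubit q = 1)
    probs = jnp.abs(state.reshape((2,) * N)) ** 2
    p = jnp.sum(probs, axis=tuple(i for i in range(N) if i != q))
    return p[0] - p[1]

def circuit(params):
    state = jnp.zeros(2 ** N, dtype=jnp.complex64).at[0].set(1.0)
    for q in range(N):
        state = apply_1q(state, ry(params[q]), q)
    state = apply_cnot(state, 0, 1)
    state = apply_cnot(state, 1, 2)
    return expval_z(state, N - 1)

def parameter_shift_grad(f, params):
    # Exact for gates generated by Pauli/2 (such as RY): two shifted evaluations per parameter
    grads = []
    for k in range(params.shape[0]):
        shift = jnp.zeros_like(params).at[k].set(jnp.pi / 2)
        grads.append((f(params + shift) - f(params - shift)) / 2)
    return jnp.stack(grads)

params = jnp.array([0.3, -1.1, 0.7])
g_psr = parameter_shift_grad(circuit, params)  # 2 * P = 6 circuit evaluations
g_ad = jax.grad(circuit)(params)               # one forward + one backward pass
print(g_psr, g_ad)
```

For this particular circuit $\langle Z_2 \rangle = \cos\theta_0 \cos\theta_1 \cos\theta_2$, so both gradient vectors can also be checked analytically.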
By building our simulator in pure JAX (jax.numpy and jax.lax), we lower quantum gate operations directly into XLA's (Accelerated Linear Algebra) HLO (High Level Operations) intermediate representation, which XLA then compiles for CPU, GPU, or TPU. This grants us four major capabilities out of the box:
- Monolithic JIT Compilation (@jax.jit): Traces the entire circuit once and compiles it into a single XLA program, bypassing per-gate Python interpreter overhead.
- Reverse-Mode Automatic Differentiation (jax.grad): Computes all $P$ gradients in a single backward pass, at a cost that is a small constant multiple of one forward evaluation regardless of $P$.
- Seamless Vectorization (jax.vmap): Parallelizes execution over batches of parameters or quantum trajectories (e.g., Monte Carlo noise simulation) within a single compiled XLA program.
- Out-of-the-Box Multi-Device Distribution (jax.sharding.PositionalSharding): Splits massive statevectors across the HBM of multiple devices without writing a single line of C++ or MPI code.
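The jit, grad, and vmap transformations compose freely. The sketch below is a hypothetical two-qubit circuit written directly in jax.numpy rather than through the suite's Circuit API. It evaluates the energy and its gradient for a whole batch of parameter vectors with one compiled call. The qubit ordering (qubit 0 as the most significant bit) is an assumption of this example.

```python
import jax
import jax.numpy as jnp

def ry(theta):
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.stack([jnp.stack([c, -s]), jnp.stack([s, c])]).astype(jnp.complex64)

# CNOT with qubit 0 (most significant bit) as control and qubit 1 as target
CNOT = jnp.array([[1, 0, 0, 0],
                  [0, 1, 0, 0],
                  [0, 0, 0, 1],
                  [0, 0, 1, 0]], dtype=jnp.complex64)
# Diagonal of Z acting on qubit 1 (least significant bit)
Z1 = jnp.array([1.0, -1.0, 1.0, -1.0])

def energy(params):
    zero = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    psi = jnp.kron(ry(params[0]) @ zero, ry(params[1]) @ zero)
    psi = CNOT @ psi
    return jnp.sum(Z1 * jnp.abs(psi) ** 2)  # <Z_1>, a real scalar

# vmap over the leading (batch) axis, then compile the whole batched program once
batched_value_and_grad = jax.jit(jax.vmap(jax.value_and_grad(energy)))

batch = jax.random.uniform(jax.random.PRNGKey(0), (8, 2), minval=-jnp.pi, maxval=jnp.pi)
values, grads = batched_value_and_grad(batch)
print(values.shape, grads.shape)  # (8,) (8, 2)
```

Because the CNOT maps $Z_1$ to $Z_0 Z_1$, each value equals $\cos p_0 \cos p_1$, which gives an easy analytic check on both outputs.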
Three Key Engineering Contributions for Scale
Simulating large systems or deep circuits on TPUs introduces severe compiler and memory limitations. We solved these with three targeted engineering optimizations:
1. Multi-Device Positional Sharding
At single precision (complex64, 8 bytes per amplitude), a 33-qubit statevector requires 64 GiB (68.72 GB) of memory, while a 36-qubit statevector requires 512 GiB (549.76 GB). To scale beyond the memory of a single device, we partition the statevector's leading dimensions across physical TPU chips using PositionalSharding. Gates that act only on unsharded qubits run locally at full HBM bandwidth. Gates that touch sharded qubits are handled by XLA collective communication over the TPU Inter-Chip Interconnect (ICI).
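The memory figures follow from $2^n$ amplitudes at 8 bytes each. The sketch below checks that arithmetic and then shards a small statevector's leading axis across whatever devices are available. It uses the Mesh/NamedSharding API, which newer JAX releases recommend in place of the deprecated PositionalSharding. The qubit count, mesh axis name, and helper names are illustrative assumptions. On a single-device machine the "sharding" is trivially one shard.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def statevector_bytes(num_qubits, bytes_per_amplitude=8):
    # complex64 = two float32 values = 8 bytes per amplitude (assumed precision)
    return (2 ** num_qubits) * bytes_per_amplitude

print(statevector_bytes(33) / 2**30, "GiB")  # 64.0 GiB
print(statevector_bytes(36) / 1e9, "GB")     # about 549.76 GB

# One mesh axis over all local devices; a power-of-two count splits qubit axes evenly
devices = np.array(jax.devices())
num_devices = devices.size
assert num_devices & (num_devices - 1) == 0, "device count must be a power of two"
mesh = Mesh(devices, axis_names=("shards",))
sharding = NamedSharding(mesh, P("shards", None))  # split rows, keep columns local

n = 10  # toy qubit count
# Row index = the most significant log2(num_devices) qubits, which are distributed
state = jnp.zeros((num_devices, 2 ** n // num_devices), jnp.complex64).at[0, 0].set(1.0)
state = jax.device_put(state, sharding)

H = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / jnp.sqrt(2.0)

@jax.jit
def hadamard_on_last_qubit(state):
    # The least significant qubit lives in the unsharded axis: no cross-device traffic
    rows, cols = state.shape
    psi = state.reshape(rows, cols // 2, 2)
    psi = jnp.einsum("ab,rcb->rca", H, psi)
    return jax.lax.with_sharding_constraint(psi.reshape(rows, cols), sharding)

out = hadamard_on_last_qubit(state)
```

At 36 qubits on 64 chips, each shard would hold $2^{30}$ amplitudes, or 8 GiB, roughly a quarter of the ~32 GB per chip implied by the ~2.0 TB total for the v6e-64.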
2. $O(1)$ Compiler Graph Size via lax.fori_loop
Python loops inside a jax.jit-compiled function are unrolled at trace time. For deep circuits this produces millions of HLO nodes, and the host CPU running the compiler can run out of memory (OOM) before execution even begins. By expressing the layer loop with jax.lax.fori_loop, we trace and compile the per-layer step logic once, so the compiled graph size stays $O(1)$ with respect to circuit depth.
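The difference is easy to observe in the traced program. The sketch below builds a hypothetical ansatz layer (RY on every qubit followed by a fixed CZ ladder) and compares a Python loop with jax.lax.fori_loop by counting top-level jaxpr equations, used here as a proxy for the size of the graph handed to XLA. The layer structure and helper names are assumptions for illustration, not the suite's implementation.

```python
import numpy as np
import jax
import jax.numpy as jnp

N = 4  # toy qubit count

def ry(theta):
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.stack([jnp.stack([c, -s]), jnp.stack([s, c])]).astype(jnp.complex64)

def apply_1q(state, gate, q):
    psi = jnp.tensordot(gate, state.reshape((2,) * N), axes=([1], [q]))
    return jnp.moveaxis(psi, 0, q).reshape(-1)

# Diagonal of a CZ ladder on neighbouring qubits (qubit 0 = most significant bit)
bits = (np.arange(2 ** N)[:, None] >> np.arange(N - 1, -1, -1)) & 1
CZ_LADDER = jnp.asarray(np.prod(np.where(bits[:, :-1] & bits[:, 1:], -1, 1), axis=1),
                        dtype=jnp.complex64)

def layer(state, thetas):
    # One hypothetical ansatz layer: RY on every qubit, then the fixed CZ ladder
    for q in range(N):
        state = apply_1q(state, ry(thetas[q]), q)
    return state * CZ_LADDER

def init_state():
    return jnp.zeros(2 ** N, dtype=jnp.complex64).at[0].set(1.0)

def run_fori(params):  # params: (depth, N)
    body = lambda i, state: layer(state, params[i])
    return jax.lax.fori_loop(0, params.shape[0], body, init_state())

def run_unrolled(params):
    state = init_state()
    for i in range(params.shape[0]):  # Python loop: one copy of the layer per iteration
        state = layer(state, params[i])
    return state

def graph_size(fn, depth):
    # Number of top-level equations in the traced program
    return len(jax.make_jaxpr(fn)(jnp.zeros((depth, N))).jaxpr.eqns)

print(graph_size(run_fori, 10), graph_size(run_fori, 200))          # equal
print(graph_size(run_unrolled, 10), graph_size(run_unrolled, 200))  # grows with depth

params = jax.random.normal(jax.random.PRNGKey(0), (5, N))
```

Both versions compute the same state; only the unrolled one grows with depth.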
3. $O(1)$ Per-Layer Backpropagation Memory via jax.checkpoint
Reverse-mode autodiff must store intermediate states (residuals) from the forward pass for use in the backward pass. For large multi-layer variational circuits, keeping every intermediate statevector quickly exhausts High Bandwidth Memory (HBM). By wrapping each layer in jax.checkpoint, the per-gate intermediates inside a layer are discarded during the forward pass and recomputed on the fly during the backward pass. Only the statevector at each layer boundary is kept, so the stored memory per layer drops to $O(1)$ statevectors instead of growing with the number of gates in the layer, at the cost of roughly one extra forward pass.
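A minimal sketch of the pattern, assuming a deliberately trivial layer (RY on every qubit) so the result can be checked analytically. Each scan step is wrapped in jax.checkpoint, so only the layer's input statevector is saved and the per-gate intermediates are recomputed during the backward pass. jax.lax.scan is used here because it is what fori_loop lowers to when the trip count is static. The helper names and the choice of $\langle Z_0 \rangle$ as the objective are assumptions.

```python
import jax
import jax.numpy as jnp

N = 4  # toy qubit count

def ry(theta):
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.stack([jnp.stack([c, -s]), jnp.stack([s, c])]).astype(jnp.complex64)

def apply_1q(state, gate, q):
    psi = jnp.tensordot(gate, state.reshape((2,) * N), axes=([1], [q]))
    return jnp.moveaxis(psi, 0, q).reshape(-1)

def layer(state, thetas):
    # Hypothetical layer: one RY per qubit (each gate output would normally be a residual)
    for q in range(N):
        state = apply_1q(state, ry(thetas[q]), q)
    return state

def make_energy(use_checkpoint):
    step = lambda state, thetas: (layer(state, thetas), None)
    if use_checkpoint:
        # Save only the layer input; recompute per-gate intermediates in the backward pass
        step = jax.checkpoint(step)

    def energy(params):  # params: (depth, N)
        state0 = jnp.zeros(2 ** N, dtype=jnp.complex64).at[0].set(1.0)
        final, _ = jax.lax.scan(step, state0, params)
        probs = jnp.abs(final.reshape((2,) * N)) ** 2
        p0 = probs.sum(axis=tuple(range(1, N)))  # marginal distribution of qubit 0
        return p0[0] - p0[1]                     # <Z_0>
    return energy

params = jax.random.normal(jax.random.PRNGKey(1), (6, N))
g_plain = jax.grad(make_energy(False))(params)
g_ckpt = jax.grad(make_energy(True))(params)
```

Checkpointing changes what is stored, not what is computed: both gradients agree. Since consecutive RY rotations add, $\langle Z_0 \rangle = \cos(\sum_\ell \theta_{\ell,0})$ here.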
Benchmark Highlights
We validated our simulator across three hardware tiers: an entry-level consumer GPU (NVIDIA RTX 2050, 4 GB VRAM), a 16-chip Cloud TPU v5e mesh, and a 64-chip Cloud TPU v6e cluster.
1. Algorithmic Gradient Speedups
On a 15-qubit Hardware-Efficient Ansatz (120 trainable parameters) on CPU:
- jax.grad (Reverse-Mode AD): Computes all 120 gradients in 37.5 ms (stable post-JIT mean).
- Parameter-Shift Rule (PSR): Requires 1,826 ms to compute the same gradients (240 circuit evaluations).
- Result: A 48.7× gradient speedup over PSR. On a smaller 50-parameter GPU circuit, this gap widens to ~75×.
- vs. PennyLane JAX: Our simulator achieves a 4× speedup over PennyLane's own JAX backend, because our monolithic, fully compiled kernel avoids Python-level dispatch overhead.
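The timings above are reported as a "stable post-JIT mean", but the exact protocol is not described. The following is only one reasonable way to measure such a number: warm up to exclude compilation, then block on each result because JAX dispatches work asynchronously. The warmup and repeat counts and the stand-in loss are assumptions.

```python
import time
import jax
import jax.numpy as jnp

def post_jit_mean_ms(fn, *args, warmup=3, repeats=20):
    """Mean wall-clock time of a jitted call in milliseconds, excluding compilation."""
    compiled = jax.jit(fn)
    for _ in range(warmup):
        # The first call traces and compiles; the remaining warmup calls are discarded
        jax.block_until_ready(compiled(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        # Block so the timer measures the computation, not just the asynchronous dispatch
        jax.block_until_ready(compiled(*args))
    return (time.perf_counter() - start) * 1e3 / repeats

# Hypothetical stand-in objective; substitute the real ansatz's loss function
def toy_loss(params):
    return jnp.sum(jnp.cos(params) ** 2)

params = jnp.linspace(0.0, 1.0, 120)  # 120 parameters, matching the ansatz size above
ms = post_jit_mean_ms(jax.grad(toy_loss), params)
print(f"grad: {ms:.3f} ms per call")
```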
2. Maximum Qubit Thresholds by Hardware
By distributing the statevector across device memory, we pushed statevector simulation close to the memory limits of each platform:
| Hardware | Max Qubits Simulated | Statevector Memory (complex64) | Algorithm / Experiment |
|---|---|---|---|
| NVIDIA RTX 2050 (4 GB VRAM) | 29 | 4.29 GB (4 GiB) | Statevector Baseline |
| Cloud TPU v5e-16 (256 GB HBM) | 33 | 68.72 GB (64 GiB) | Shor's Order-Finding QFT |
| Cloud TPU v6e-64 (~2.0 TB HBM) | 36 | 549.76 GB (512 GiB) | Grover's Search Algorithm |
(Note: For random circuit sampling (RCS) at 37 qubits, we utilize tensor-network amplitude sampling via TensorCircuit with a JAX backend, scaling to 2,048 samples per batch on TPU v6e-64.)
3. Matrix Product States (MPS) & SVD Stability
For systems beyond statevector limits (512 to 1024 qubits), we implemented a differentiable Matrix Product State (MPS) simulator. Differentiating through the Singular Value Decomposition (SVD) in JAX is notoriously unstable: the SVD backward pass contains factors of the form $1/(s_i^2 - s_j^2)$, which diverge when singular values are degenerate or nearly so, and for complex-valued tensors these terms enter the Wirtinger (complex) gradients. We resolved this by implementing:
- SVD Jitter: Adding complex noise of magnitude $10^{-9}$ to break singular-value degeneracy.
- Site-Level Normalization: Renormalizing each site to prevent exponential drift of the amplitude norm.
- Wirtinger Gradient Clipping: Keeping the real components of the complex gradients bounded to avoid NaN divergence.
These fixes enabled stable convergence of a 1024-qubit MPS-VQE over 10,000 training epochs.
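The three fixes can be sketched for a single MPS truncation step as follows. This is a minimal illustration under stated assumptions, not the suite's implementation. The function names, the Gaussian form of the jitter, normalizing the kept singular values as the "site-level" normalization, and clipping both real and imaginary parts are all choices made here for concreteness.

```python
import jax
import jax.numpy as jnp

def stabilized_truncated_svd(theta, chi, key, jitter=1e-9):
    """Truncated SVD of a two-site MPS tensor already reshaped to a matrix (sketch)."""
    k_re, k_im = jax.random.split(key)
    real_dtype = theta.real.dtype
    # Assumed jitter: small complex Gaussian noise to split degenerate singular values.
    # Note: with complex64 inputs, 1e-9 is below float32 resolution for O(1) entries,
    # so it only has an effect in complex128 (requires jax_enable_x64).
    noise = (jax.random.normal(k_re, theta.shape, real_dtype)
             + 1j * jax.random.normal(k_im, theta.shape, real_dtype))
    u, s, vh = jnp.linalg.svd(theta + jitter * noise, full_matrices=False)
    # Keep the chi largest singular values (jnp.linalg.svd returns them in descending order)
    u, s, vh = u[:, :chi], s[:chi], vh[:chi, :]
    # Assumed site-level normalization: rescale the kept spectrum to unit norm
    s = s / jnp.linalg.norm(s)
    return u, s, vh

def clip_wirtinger_grad(g, bound=1.0):
    # Assumed clipping rule: bound real and imaginary parts separately, preserving dtype
    if jnp.iscomplexobj(g):
        return jax.lax.complex(jnp.clip(g.real, -bound, bound), jnp.clip(g.imag, -bound, bound))
    return jnp.clip(g, -bound, bound)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
theta = (jax.random.normal(k1, (8, 8)) + 1j * jax.random.normal(k2, (8, 8))).astype(jnp.complex64)
u, s, vh = stabilized_truncated_svd(theta, chi=4, key=k3)

# Apply the clipping to a gradient pytree before the optimizer update
grads = jax.tree_util.tree_map(clip_wirtinger_grad,
                               {"a": jnp.array([3.0, -0.5]), "b": jnp.array([2 - 5j])})
```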
Get Started in 10 Lines of Code
Here is how simple it is to build a circuit, compile it, and compute exact gradients using the JAX Quantum Research Suite:
```python
import jax
import jax.numpy as jnp
from jax_qsim.circuit import Circuit

# Initialize a 4-qubit circuit
c = Circuit(num_qubits=4)
c.h(0).cnot(0, 1).ry(2, param_index=0).rz(3, param_index=1)

# Define the loss; jax.jit compiles the full circuit into one XLA program
@jax.jit
def loss_fn(params):
    state = c.run(params, state_type='statevector')
    # Toy scalar objective: real part of the amplitude at basis index 2.
    # (Not a physical observable such as <Z_1>.)
    return jnp.real(state[2])

# Compile the gradient function
grad_fn = jax.jit(jax.grad(loss_fn))

# Compute all parameter gradients in a single backward pass
params = jnp.array([0.5, 1.2])
gradients = grad_fn(params)
print("Gradients:", gradients)
```
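For a physically meaningful objective, here is a sketch of the same 4-qubit circuit written directly in jax.numpy, without jax_qsim, returning the expectation value $\langle Z_2 \rangle$. It assumes qubit 0 is the leftmost tensor axis; jax_qsim's own ordering convention may differ. The helper names are hypothetical. In this circuit $\langle Z_2 \rangle = \cos\theta_0$, and the RZ on qubit 3 only adds a phase, so the second gradient component is zero.

```python
import jax
import jax.numpy as jnp

N = 4

def apply_1q(state, gate, q):
    # Contract a 2x2 gate into tensor axis q of the statevector
    psi = jnp.tensordot(gate, state.reshape((2,) * N), axes=([1], [q]))
    return jnp.moveaxis(psi, 0, q).reshape(-1)

def apply_cnot(state, control, target):
    # Flip the target axis wherever the control index is 1
    psi = state.reshape((2,) * N)
    ctrl = jnp.arange(2).reshape([2 if i == control else 1 for i in range(N)])
    return jnp.where(ctrl == 1, jnp.flip(psi, axis=target), psi).reshape(-1)

def ry(t):
    c, s = jnp.cos(t / 2), jnp.sin(t / 2)
    return jnp.stack([jnp.stack([c, -s]), jnp.stack([s, c])]).astype(jnp.complex64)

def rz(t):
    phase = jnp.exp(-0.5j * t)  # RZ(t) = diag(e^{-it/2}, e^{it/2})
    return jnp.diag(jnp.stack([phase, jnp.conj(phase)]))

H = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / jnp.sqrt(2.0)

@jax.jit
def loss_fn(params):
    state = jnp.zeros(2 ** N, dtype=jnp.complex64).at[0].set(1.0)
    state = apply_1q(state, H, 0)
    state = apply_cnot(state, 0, 1)
    state = apply_1q(state, ry(params[0]), 2)
    state = apply_1q(state, rz(params[1]), 3)
    probs = jnp.abs(state.reshape((2,) * N)) ** 2
    p = probs.sum(axis=(0, 1, 3))  # marginal distribution of qubit 2
    return p[0] - p[1]             # <Z_2>

grad_fn = jax.jit(jax.grad(loss_fn))
gradients = grad_fn(jnp.array([0.5, 1.2]))
print("Gradients:", gradients)  # approximately [-sin(0.5), 0.0]
```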
Acknowledgements & Open Source
This project is fully open-source and released under the MIT License.
We extend our deepest gratitude to the Google TPU Research Cloud (TRC) Program for providing the TPU v5e and TPU v6e resources that made our distributed scaling benchmarks possible.