Differentiable Quantum Simulation at TPU Scale: Inside the JAX Quantum Research Suite


Classical simulation of quantum systems is the bedrock of quantum algorithm design. However, as researchers push the limits of the Noisy Intermediate-Scale Quantum (NISQ) era, they face two major roadblocks: memory scaling limits for large qubit counts, and the gradient bottleneck in training variational quantum algorithms.

To address these challenges, we introduce the JAX Quantum Research Suite—a high-performance, differentiable quantum circuit simulator designed to run seamlessly from entry-level consumer GPUs up to distributed Cloud TPU VM clusters.

This work and its large-scale benchmarks are supported by the Google TPU Research Cloud (TRC) Program.


The Core Philosophy: Why JAX & Cloud TPUs?

Traditional simulators like Google's qsim or NVIDIA's cuQuantum are highly optimized in C++/CUDA but lack native, end-to-end differentiability. On the other hand, frameworks like PennyLane often rely on the Parameter-Shift Rule (PSR) to compute gradients. For a circuit with $P$ parameters, PSR requires $2P$ distinct circuit evaluations, creating an $O(P)$ bottleneck that makes training deep variational circuits extremely slow.
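To make the $2P$ cost concrete, here is a minimal sketch comparing the Parameter-Shift Rule with jax.grad on a toy single-qubit circuit. The helper names (ry, expval_z, psr_grad) are hypothetical and are not part of the suite; the shift of $\pi/2$ with prefactor $1/2$ is the standard rule for gates generated by a Pauli operator.

```python
import jax
import jax.numpy as jnp

def ry(theta):
    # Real 2x2 matrix of the RY(theta) rotation.
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]])

def expval_z(params):
    # Toy circuit (illustrative only): apply RY(theta_k) in sequence to |0>
    # and return <Z> = |a0|^2 - |a1|^2.
    state = jnp.array([1.0, 0.0])
    for theta in params:
        state = ry(theta) @ state
    return state[0] ** 2 - state[1] ** 2

def psr_grad(f, params):
    # Parameter-shift rule: two circuit evaluations per parameter (2P total).
    grads = []
    for k in range(params.shape[0]):
        shift = jnp.zeros_like(params).at[k].set(jnp.pi / 2)
        grads.append(0.5 * (f(params + shift) - f(params - shift)))
    return jnp.stack(grads)

params = jnp.array([0.3, -1.1, 0.7])
g_ad = jax.grad(expval_z)(params)    # one forward and one backward pass
g_psr = psr_grad(expval_z, params)   # 2 * len(params) forward passes
print(g_ad, g_psr)
```

Both routes return the same gradient here; the difference is only in how many circuit evaluations are needed as the parameter count grows.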

By building our simulator in pure JAX (jax.numpy and jax.lax), we compile quantum gate operations directly into HLO (High-Level Operations), the intermediate representation of the XLA (Accelerated Linear Algebra) compiler. This grants us four major capabilities out of the box:

  1. Monolithic JIT Compilation (@jax.jit): Bypasses Python interpreter overhead, running at compiled speeds.
  2. Reverse-Mode Automatic Differentiation (jax.grad): Computes all $P$ gradients in a single backward pass, at a cost that is a small constant multiple of one forward evaluation and independent of $P$.
  3. Seamless Vectorization (jax.vmap): Parallelizes execution over batches of parameters or quantum trajectories (e.g., Monte Carlo noise simulation) in a single fused hardware kernel.
  4. Out-of-the-Box Multi-Device Distribution (jax.sharding.PositionalSharding): Splits massive statevectors across distributed HBM without writing a single line of C++ or MPI code.
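The first three transformations in the list above compose directly. The sketch below is illustrative and does not use the suite's API: the toy function energy is an assumption, standing in for a full circuit.

```python
import jax
import jax.numpy as jnp

def energy(theta):
    # Toy stand-in for a circuit: <Z> after RY(theta) acts on |0>.
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    state = jnp.array([[c, -s], [s, c]]) @ jnp.array([1.0, 0.0])
    return state[0] ** 2 - state[1] ** 2

# value_and_grad differentiates, vmap batches over the leading axis,
# and jit compiles the whole batched computation into one XLA program.
batched_value_and_grad = jax.jit(jax.vmap(jax.value_and_grad(energy)))

thetas = jnp.linspace(0.0, jnp.pi, 5)
values, grads = batched_value_and_grad(thetas)
print(values)  # matches cos(thetas)
print(grads)   # matches -sin(thetas)
```

The same pattern applies when the batch axis indexes noise trajectories instead of parameter settings.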

Three Key Engineering Contributions for Scale

Simulating large systems or deep circuits on TPUs introduces severe compiler and memory limitations. We solved these with three targeted engineering optimizations:

1. Multi-Device Positional Sharding

At 8 bytes per amplitude (complex64), a 33-qubit statevector requires $2^{36}$ bytes (64 GiB, about 68.72 GB), while a 36-qubit statevector requires $2^{39}$ bytes (512 GiB, about 549.76 GB). To scale beyond the memory limits of a single device, we partition the statevector's leading dimensions across physical TPU chips using PositionalSharding. Local gates run at maximum memory bandwidth, and cross-shard operations are handled via XLA collective communication over TPU Inter-Chip Interconnects (ICI).
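A minimal sketch of the partitioning idea follows. It is not the suite's implementation: it swaps in jax.sharding.NamedSharding over a one-dimensional Mesh instead of PositionalSharding, and the function apply_to_last_qubit and the axis name "shard" are hypothetical. It runs on however many devices are available, including a single CPU.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_qubits = 10
devices = np.array(jax.devices())
# Use the largest power-of-two device count so the state splits evenly.
n_dev = 1 << (len(devices).bit_length() - 1)
mesh = Mesh(devices[:n_dev], ("shard",))
sharding = NamedSharding(mesh, P("shard", None))

# Leading axis indexes the shard (most-significant bits of the basis index).
state = jnp.zeros((n_dev, 2**n_qubits // n_dev), dtype=jnp.complex64)
state = state.at[0, 0].set(1.0)            # the |00...0> state
state = jax.device_put(state, sharding)

H = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / jnp.sqrt(2)

@jax.jit
def apply_to_last_qubit(state, gate):
    # The last qubit is the fastest-varying index, so both amplitudes of
    # every pair live in the same shard and no communication is needed.
    local = state.reshape(state.shape[0], -1, 2)
    out = jnp.einsum("ab,skb->ska", gate, local)
    return out.reshape(state.shape)

new_state = apply_to_last_qubit(state, H)
print(new_state.sharding)
```

A gate on one of the most-significant qubits would instead pair amplitudes held by different shards, which is where the compiler has to insert collective communication.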

2. $O(1)$ Compiler Graph Size via lax.fori_loop

Standard Python loops inside JIT compilation unroll fully. For deep circuits, this results in millions of HLO nodes, causing the compiler host CPU to crash with Out-Of-Memory (OOM) errors before execution even begins. By wrapping our layers in jax.lax.fori_loop, we compile the step logic once, keeping the compiler graph size $O(1)$ regardless of circuit depth.
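The effect on graph size can be checked on a toy circuit by counting equations in the traced jaxpr. This is a sketch under assumptions: the single-qubit layer and the helper names are hypothetical, and the jaxpr equation count is used only as a rough proxy for HLO size.

```python
import jax
import jax.numpy as jnp

def ry(theta):
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]])

def expval_unrolled(params):
    # A Python loop is unrolled at trace time: one copy of the layer per step.
    state = jnp.array([1.0, 0.0])
    for i in range(params.shape[0]):
        state = ry(params[i]) @ state
    return state[0] ** 2 - state[1] ** 2

def expval_looped(params):
    # fori_loop traces the layer body once. The trip count is static here,
    # which is what allows reverse-mode differentiation through the loop.
    def layer(i, state):
        return ry(params[i]) @ state
    state = jax.lax.fori_loop(0, params.shape[0], layer, jnp.array([1.0, 0.0]))
    return state[0] ** 2 - state[1] ** 2

def num_eqns(fn, depth):
    # Number of top-level equations in the traced program.
    return len(jax.make_jaxpr(fn)(jnp.zeros(depth)).jaxpr.eqns)

for depth in (4, 64):
    print(depth, num_eqns(expval_unrolled, depth), num_eqns(expval_looped, depth))

params = jnp.full(8, 0.1)
value = jax.jit(expval_looped)(params)
grads = jax.jit(jax.grad(expval_looped))(params)
```

The unrolled count grows with depth while the looped count stays fixed, and both versions return the same value and gradient.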

3. $O(1)$ Backpropagation Memory via jax.checkpoint

Evaluating gradients in reverse-mode autodiff requires storing intermediate states for the backward pass. For large-scale multi-layer variational circuits, this quickly exhausts High Bandwidth Memory (HBM). By wrapping layers in jax.checkpoint, intermediate states are discarded during the forward pass and recomputed on the fly during the backward pass, reducing the memory footprint from $O(\text{depth})$ to $O(1)$ at the cost of roughly one extra forward pass.
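A minimal sketch of the pattern, assuming a toy single-qubit layer and using jax.lax.scan to iterate over per-layer angles (the names layer and make_loss are hypothetical). In this arrangement the values computed inside each wrapped layer are not saved; the layer is re-run from its input when the backward pass reaches it. Checkpointing changes memory use and run time, not the result.

```python
import jax
import jax.numpy as jnp

def layer(state, theta):
    # One toy "layer": an RY rotation on a single qubit.
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]]) @ state

def make_loss(use_checkpoint):
    # jax.checkpoint marks the layer for rematerialization in the backward pass.
    layer_fn = jax.checkpoint(layer) if use_checkpoint else layer

    def loss(params):
        def step(state, theta):
            return layer_fn(state, theta), None
        final, _ = jax.lax.scan(step, jnp.array([1.0, 0.0]), params)
        return final[0] ** 2 - final[1] ** 2

    return loss

params = jnp.linspace(0.05, 0.4, 16)
g_plain = jax.jit(jax.grad(make_loss(False)))(params)
g_remat = jax.jit(jax.grad(make_loss(True)))(params)
print(jnp.max(jnp.abs(g_plain - g_remat)))
```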


Benchmark Highlights

We validated our simulator across three hardware tiers: an entry-level consumer GPU (NVIDIA RTX 2050, 4 GB VRAM), a 16-chip Cloud TPU v5e mesh, and a 64-chip Cloud TPU v6e cluster.

1. Algorithmic Gradient Speedups

On a 15-qubit Hardware-Efficient Ansatz (120 trainable parameters) on CPU:

  • jax.grad (Reverse-Mode AD): Computes all gradients in 37.5 ms (stable post-JIT mean).
  • Parameter-Shift Rule (PSR): Requires 1,826 ms to compute the same gradients.
  • Result: A 48.7× gradient speedup over PSR. On a smaller 50-parameter GPU circuit, this gap expands to ~75×.
  • vs. PennyLane JAX: Our simulator achieves a 4× speedup over PennyLane's own JAX backend because our monolithic kernel unrolling avoids Python-level dispatch overheads.
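For readers who want to reproduce a "post-JIT mean" timing of this kind, one common approach is sketched below. It is not the harness behind the numbers above; the function post_jit_mean_ms and the toy loss are assumptions. The two details that matter are a warm-up call to exclude compilation time and block_until_ready to wait for JAX's asynchronous dispatch.

```python
import time
import jax
import jax.numpy as jnp

def post_jit_mean_ms(fn, *args, repeats=20):
    # Warm-up call: triggers tracing and compilation, excluded from timing.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # block_until_ready waits for the device computation to finish.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats * 1e3

# Toy differentiable function with 120 parameters (not a quantum circuit).
grad_fn = jax.jit(jax.grad(lambda p: jnp.sum(jnp.cos(p) ** 2)))
params = jnp.linspace(0.0, 1.0, 120)
mean_ms = post_jit_mean_ms(grad_fn, params)
print(f"mean gradient time: {mean_ms:.3f} ms")
```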

2. Maximum Qubit Thresholds by Hardware

By utilizing distributed memory, we pushed statevector simulation limits to their theoretical boundaries:

Hardware                         Max Qubits Simulated   Statevector Memory (8 bytes per amplitude)   Algorithm / Experiment
NVIDIA RTX 2050 (4 GB VRAM)      29                     4.29 GB (4 GiB)                              Statevector Baseline
Cloud TPU v5e-16 (256 GB HBM)    33                     68.72 GB (64 GiB)                            Shor's Order-Finding QFT
Cloud TPU v6e-64 (~2.0 TB HBM)   36                     549.76 GB (512 GiB)                          Grover's Search Algorithm

(Note: For random circuit sampling (RCS) at 37 qubits, we utilize tensor-network amplitude sampling via TensorCircuit with a JAX backend, scaling to 2,048 samples per batch on TPU v6e-64.)
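The statevector memory figures in the table above follow from $2^n$ amplitudes at 8 bytes each (single-precision complex). A short check, with the helper name statevector_bytes chosen here for illustration:

```python
def statevector_bytes(n_qubits, bytes_per_amplitude=8):
    # 2**n complex amplitudes; 8 bytes each for complex64.
    return (2 ** n_qubits) * bytes_per_amplitude

for n in (29, 33, 36):
    b = statevector_bytes(n)
    print(f"{n} qubits: {b / 1e9:.2f} GB = {b / 2**30:.0f} GiB")
# 29 qubits: 4.29 GB = 4 GiB
# 33 qubits: 68.72 GB = 64 GiB
# 36 qubits: 549.76 GB = 512 GiB
```

Each additional qubit doubles the requirement, so double-precision amplitudes (16 bytes) cost exactly one qubit of capacity.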

3. Matrix Product States (MPS) & SVD Stability

For systems exceeding statevector limits (512 to 1024 qubits), we implemented a differentiable Matrix Product State (MPS) simulator. Differentiating through Singular Value Decomposition (SVD) in JAX is notoriously unstable: the SVD derivative, taken here with complex (Wirtinger) calculus, contains factors of the form $1/(s_i^2 - s_j^2)$, which diverge when two singular values $s_i$ and $s_j$ are degenerate. We resolved this by implementing:

  • SVD Jitter: Injecting $10^{-9}$ complex noise to break singular value degeneracy.
  • Site-Level Normalization: Preventing exponential amplitude drift.
  • Wirtinger Gradient Clipping: Keeping real components bounded to avoid NaN divergence.

These fixes enabled stable convergence of a 1024-qubit MPS-VQE over 10,000 training epochs.
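The jitter idea can be sketched as follows. This is an illustration under assumptions, not the suite's code: the function names are hypothetical, Gaussian noise is assumed for the perturbation, and double precision is enabled because a $10^{-9}$ perturbation is below single-precision resolution for entries of order one.

```python
import jax
import jax.numpy as jnp

# Assumption: run in double precision so the 1e-9 jitter is representable.
jax.config.update("jax_enable_x64", True)

def jittered_svd(matrix, key, eps=1e-9):
    # Add small complex noise so that no two singular values coincide exactly.
    k_re, k_im = jax.random.split(key)
    noise = (jax.random.normal(k_re, matrix.shape)
             + 1j * jax.random.normal(k_im, matrix.shape))
    return jnp.linalg.svd(matrix + eps * noise, full_matrices=False)

def truncated_loss(matrix, key):
    # Keep only the leading singular triplet, as an MPS bond truncation would,
    # and reduce to a real scalar so jax.grad applies.
    u, s, vh = jittered_svd(matrix, key)
    approx = (u[:, :1] * s[:1]) @ vh[:1, :]
    return jnp.real(approx[0, 0])

# The identity has two exactly degenerate singular values.
matrix = jnp.eye(2, dtype=jnp.complex128)
grad = jax.grad(truncated_loss)(matrix, jax.random.PRNGKey(0))
print(grad)
```

The gradient is finite but can be very large, because the gap between the perturbed singular values is itself of order $10^{-9}$; this is one reason jitter is typically paired with some form of gradient clipping.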


Get Started in 10 Lines of Code

Here is how simple it is to build a circuit, compile it, and compute exact gradients using the JAX Quantum Research Suite:

from jax_qsim.circuit import Circuit
import jax, jax.numpy as jnp

# Initialize a 4-qubit circuit
c = Circuit(num_qubits=4)
c.h(0).cnot(0, 1).ry(2, param_index=0).rz(3, param_index=1)

# Compile the execution and gradient functions
@jax.jit
def loss_fn(params):
    state = c.run(params, state_type='statevector')
    # Toy scalar loss: real part of the amplitude at basis index 2.
    # This is a placeholder, not a true Z expectation value on qubit 1.
    return jnp.real(state[2])

grad_fn = jax.jit(jax.grad(loss_fn))

# Evaluate gradients instantly in a single backward pass
params = jnp.array([0.5, 1.2])
gradients = grad_fn(params)
print("Gradients:", gradients)

Acknowledgements & Open Source

This project is fully open-source and released under the MIT License.

We extend our deepest gratitude to the Google TPU Research Cloud (TRC) Program for providing the TPU v5e and TPU v6e resources that made our distributed scaling benchmarks possible.
