
GitHub Actions runners for JAX

JAX ships accelerator-specific wheels (CUDA or TPU builds), and GitHub's default hosted runners have neither a GPU nor a TPU, so `pjit` tests cannot exercise real accelerators there. Cirun runs your JAX CI on a right-sized VM in your own cloud account: pick AWS, GCP, Azure, Oracle, DigitalOcean, OpenStack, or on-prem. Your cloud provider bills the compute; you never pay per CI minute.

Why this fits

  • JAX wheels are accelerator-specific: a CUDA or TPU build finds no matching hardware on a default runner, which breaks `pjit` tests.
  • Cirun is cloud-neutral: every supported cloud offers at least one instance type that fits JAX, so use whichever account you already have.
  • Ephemeral by default: each job gets a fresh VM, so no state from a previous PR leaks into yours and there are no flaky-cache surprises.
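A workflow can fail fast when JAX cannot reach the GPU, instead of letting `pjit` tests quietly run on a CPU fallback. The step below is a minimal sketch, not part of Cirun itself. The step name is hypothetical, and it assumes a CUDA build of JAX is already installed on the runner.

```yaml
# Hypothetical workflow step: abort early if JAX did not pick up a GPU backend.
- name: Check that JAX sees the GPUs
  run: python -c "import jax; print(jax.devices()); assert jax.default_backend() == 'gpu', jax.default_backend()"
```

Putting this check before the test step turns an accelerator mismatch into one clear failure rather than a set of confusing `pjit` errors.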

.cirun.yml

```yaml
runners:
  - name: jax-runner
    cloud: aws
    # p4d.24xlarge provides 8x NVIDIA A100 GPUs for multi-device pjit tests.
    instance_type: p4d.24xlarge
    # Use the AWS Deep Learning AMI GPU PyTorch on Ubuntu 22.04, or the
    # Cirun-published NVIDIA AMI. AMI IDs are region-specific.
    machine_image: ami-04823729c75214919
    labels:
      - cirun-jax
```

Commit this file to your repository root. When a workflow job requests the `cirun-jax` label in `runs-on`, Cirun launches a VM with this configuration in your cloud account.
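A minimal workflow that targets this runner could look like the sketch below. It assumes the `cirun-jax` label from the config above and the `--${{ github.run_id }}` label suffix that appears in Cirun's documentation examples; check the current Cirun docs before relying on it. The workflow file name, Python version, and `tests/` path are hypothetical placeholders.

```yaml
# .github/workflows/jax-gpu.yml (hypothetical file name)
name: JAX GPU tests
on: [push, pull_request]

jobs:
  test:
    # Matches the "cirun-jax" label in .cirun.yml; the run_id suffix follows
    # the pattern in Cirun's docs (assumption: verify against current docs).
    runs-on: "cirun-jax--${{ github.run_id }}"
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install JAX with CUDA 12 wheels
        # jax[cuda12] installs CUDA and cuDNN from pip, so the image only needs the NVIDIA driver.
        run: pip install -U "jax[cuda12]" pytest
      - name: Run tests
        run: pytest -q tests/
```

Because the VM is created for the job and torn down afterward, each run starts from the same machine image.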

Ready to run your CI here?

Cirun is free for open source. Private repositories use flat monthly plans priced by repository count, never per CI minute.