jax-ml/jax
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Python functions with NumPy arrays flow through tracers that record operations into the Jaxpr intermediate representation, are then lowered to MLIR, and are compiled by XLA for execution on accelerators.
Under the hood, the system uses two feedback loops, two data pools, and three control points to manage its runtime behavior.
Structural Verdict
An 8-component ML training library with 7 connections. 1218 files analyzed. Well-connected — clear data flow between components.
How Data Flows Through the System
- Function Tracing — Python function called with Tracer inputs to record operation graph
- Jaxpr Construction — Traced operations converted to Jaxpr intermediate representation with typed variables
- MLIR Lowering — Jaxpr lowered to MLIR HLO dialect with shape/type information
- XLA Compilation — MLIR compiled to optimized XLA computation for target hardware
- Execution — Compiled function executed on CPU/GPU/TPU with actual array data
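The stages above can be walked through by hand with a current JAX install: `jax.make_jaxpr` exposes the traced Jaxpr, and the ahead-of-time path `jax.jit(...).lower(...).compile()` makes the MLIR lowering and XLA compilation steps explicit. A minimal sketch (the function `f` is illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + 1.0

x = jnp.ones((3,))

# Stages 1-2: tracing records operations into a typed Jaxpr.
jaxpr = jax.make_jaxpr(f)(x)
print(jaxpr)

# Stage 3: lower the traced function to MLIR text.
lowered = jax.jit(f).lower(x)
mlir_text = lowered.as_text()

# Stages 4-5: compile with XLA and execute on the default backend.
compiled = lowered.compile()
out = compiled(x)
```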
System Behavior
How the system actually operates at runtime — where data accumulates, what loops, what waits, and what controls what.
Data Pools
- XLA Executable Cache — Compiled XLA computations cached by function signature
- Tracer Stack — Stack of active transformation contexts during tracing
Feedback Loops
- Recompilation Trigger (cache-invalidation, balancing) — Trigger: Function called with new shapes/dtypes. Action: Retrace and recompile function. Exit: Matching signature found in cache.
- Gradient Computation (recursive, balancing) — Trigger: grad() transformation applied. Action: Trace forward pass then construct reverse computation. Exit: All derivatives computed.
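The recompilation loop can be observed directly, assuming standard `jax.jit` caching behavior: a Python side effect in the function body fires only while tracing, so it counts cache misses.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def double(x):
    global trace_count
    trace_count += 1   # runs only during tracing, not on cached calls
    return x * 2

double(jnp.ones((2,)))  # new (shape, dtype) signature: trace + compile
double(jnp.ones((2,)))  # matching signature found in cache: no retrace
double(jnp.ones((3,)))  # new shape triggers retrace and recompile
print(trace_count)      # 2
```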
Delays & Async Processing
- JIT Compilation (async-processing, ~seconds) — First call to jitted function has compilation overhead
- Device Transfer (async-processing, ~milliseconds) — Data transfer between host and accelerator memory
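Both delays are visible from Python; the sketch below (assuming any JAX backend) uses `block_until_ready` to wait out the asynchronous dispatch and `jax.device_get` for the device-to-host transfer.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((512, 512))

y = jnp.dot(x, x)        # dispatched asynchronously; returns immediately
y.block_until_ready()    # explicitly wait for the device computation

host = jax.device_get(y)  # device -> host transfer; yields a NumPy array
print(host[0, 0])         # 512.0
```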
Control Points
- jax_enable_x64 (env-var) — Controls: Whether to use 64-bit precision by default. Default: False
- jax_platform_name (env-var) — Controls: Which backend to use (cpu/gpu/tpu)
- jax_jit_pjit_api_merge (feature-flag) — Controls: Whether to merge jit and pjit APIs
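These control points are also reachable from Python via `jax.config.update`; the sketch below toggles `jax_enable_x64` (equivalently settable through the `JAX_ENABLE_X64` environment variable).

```python
import jax
import jax.numpy as jnp

# Default (jax_enable_x64=False): Python floats become float32 arrays.
assert jnp.asarray(1.0).dtype == jnp.float32

jax.config.update("jax_enable_x64", True)
assert jnp.asarray(1.0).dtype == jnp.float64

jax.config.update("jax_enable_x64", False)  # restore the default
```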
Technology Stack
- XLA — Accelerated Linear Algebra compiler backend
- MLIR — Multi-Level Intermediate Representation for compilation
- NumPy — Array API compatibility layer
- Triton — GPU kernel compilation for Pallas
- pytest — Testing framework
- nanobind — Python/C++ bindings for FFI
Key Components
- core (module, jax/_src/core.py) — Central tracing system with Tracer, Jaxpr, and abstract values for program transformation
- jit (function, jax/_src/api.py) — Just-in-time compilation transformation that compiles functions to XLA
- grad (function, jax/_src/api.py) — Automatic differentiation transformation for computing gradients
- vmap (function, jax/_src/api.py) — Vectorization transformation that batches operations over array axes
- MLIRLoweringContext (class, jax/_src/interpreters/mlir.py) — Manages lowering of JAX operations to MLIR for XLA compilation
- PallasCall (class, jax/_src/pallas/core.py) — Custom kernel system for GPU programming with block-level operations
- ffi_call (function, jax/_src/ffi.py) — Foreign function interface for calling external C/C++/CUDA code
- ShapedArray (class, jax/_src/abstract_arrays.py) — Abstract value representing arrays with known shape and dtype
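The listed transformations compose freely; a hedged sketch (the `predict` and `loss` functions are illustrative, not from the repo) combining all three:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    return jnp.tanh(jnp.dot(x, w))

# vmap batches predict over the leading axis of xs, grad differentiates
# the scalar loss with respect to w, and jit compiles the composition.
batched = jax.vmap(predict, in_axes=(None, 0))

def loss(w, xs):
    return jnp.sum(batched(w, xs) ** 2)

step = jax.jit(jax.grad(loss))

w = jnp.zeros((4,))
xs = jnp.ones((8, 4))
g = step(w, xs)
print(g.shape)  # (4,)
```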
Sub-Modules
- Pallas — GPU kernel programming framework with block-level operations and memory management
- FFI examples — Demonstration of foreign function interface for C/C++/CUDA integration
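A minimal Pallas kernel sketch; `interpret=True` runs it in interpreter mode on CPU, no GPU/TPU compiler needed (exact API details may vary across JAX versions):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Kernel bodies read and write Refs; o_ref[...] writes the whole block.
def add_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # run in interpreter mode (no accelerator needed)
    )(x, y)

out = add(jnp.arange(4.0), jnp.ones(4))
print(out)  # [1. 2. 3. 4.]
```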
Configuration
- docs/autodidax2_part1.py (python-dataclass) — primal (float), tangent (float)
- docs/autodidax2_part1.py (python-dataclass) — interpreter (Interpreter), primal (float), tangent (float)
- docs/the-training-cookbook.py (python-dataclass) — seq_length (int) — default: 128
- jax/_src/api.py (python-dataclass) — idx (int), primal (bool)
Science Pipeline
- Array Creation — jnp.array() or device_put() converts Python data to JAX arrays [arbitrary Python nested structure → JAX ShapedArray with inferred shape/dtype] (jax/_src/array.py)
- Function Tracing — Replace arrays with Tracers to record computation graph [concrete arrays → abstract ShapedArray values] (jax/_src/core.py)
- Shape Inference — Propagate abstract shapes through operations without computing values [abstract shapes → inferred output shapes] (jax/_src/abstract_arrays.py)
- Backend Lowering — Convert JAX operations to MLIR HLO with explicit tensor shapes [JAX abstract arrays → MLIR tensor types] (jax/_src/interpreters/mlir.py)
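The tracing and shape-inference stages can run on their own via `jax.eval_shape`, which propagates abstract shapes without allocating data or executing any FLOPs:

```python
import jax
import jax.numpy as jnp

def f(x, w):
    return jnp.dot(x, w).reshape(-1)

# ShapeDtypeStruct stands in for a real array: shape and dtype, no data.
x = jax.ShapeDtypeStruct((32, 128), jnp.float32)
w = jax.ShapeDtypeStruct((128, 10), jnp.float32)

out = jax.eval_shape(f, x, w)
print(out.shape, out.dtype)  # (320,) float32
```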
Assumptions & Constraints
- [warning] Assumes contracting dimensions have matching sizes but only checks at runtime during execution (shape)
- [info] Assumes total size is preserved during reshape but allows -1 inference with potential size mismatches (shape)
- [warning] Broadcasting assumes compatible dimensions but shape errors only surface during tracing (shape)
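The broadcasting constraint above surfaces during tracing rather than silently at execution; a minimal repro:

```python
import jax
import jax.numpy as jnp

@jax.jit
def add(x, y):
    return x + y  # broadcasting compatibility is checked while tracing

error = None
try:
    add(jnp.ones((3,)), jnp.ones((4,)))  # (3,) vs (4,): incompatible
except Exception as e:                    # JAX raises at trace time
    error = e
print(type(error).__name__)
```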
Frequently Asked Questions
What is jax used for?
jax-ml/jax provides composable transformations for Python+NumPy programs: differentiate, vectorize, and JIT-compile to GPU/TPU. It is an 8-component ML training library written in Python, well-connected with clear data flow between components. The codebase contains 1218 files.
How is jax architected?
jax is organized into 4 architecture layers: Public API, Core Tracing, Interpreters, and Backend Integration. This layered structure keeps the components well-connected, with clear data flow between them.
How does data flow through jax?
Data moves through 5 stages: Function Tracing → Jaxpr Construction → MLIR Lowering → XLA Compilation → Execution. Python functions with NumPy arrays flow through tracers that record operations into the Jaxpr intermediate representation, then get lowered to MLIR and compiled by XLA for execution on accelerators. This pipeline design reflects a complex multi-stage processing system.
What technologies does jax use?
The core stack includes XLA (Accelerated Linear Algebra compiler backend), MLIR (Multi-Level Intermediate Representation for compilation), NumPy (Array API compatibility layer), Triton (GPU kernel compilation for Pallas), pytest (Testing framework), nanobind (Python/C++ bindings for FFI). A focused set of dependencies that keeps the build manageable.
What system dynamics does jax have?
jax exhibits 2 data pools (XLA Executable Cache, Tracer Stack), 2 feedback loops, 3 control points, and 2 delays. The feedback loops handle cache invalidation and recursive gradient computation. These runtime behaviors shape how the system responds to load, failures, and configuration changes.
What design patterns does jax use?
4 design patterns detected: Tracer Pattern, Transformation Composition, Backend Dispatch, Abstract Interpretation.
Analyzed on March 31, 2026 by CodeSea. Written by Karolina Sarna.