Hidden Assumptions in jax

13 assumptions this code never checks · 5 critical · spanning Environment, Shape, Resource, Domain, Ordering, Scale, Contract, Temporal

Every codebase relies on things it never checks, and most of them are routine. CodeSea analyzed jax-ml/jax and picked out the 3 assumptions most likely to cause trouble here. The remaining 10 are minor; they are listed under "Show everything".

Worth your attention first

Silent failures or crashes deep in XLA when jaxlib C++ extensions have incompatible ABIs, making debugging extremely difficult
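
A version comparison cannot prove ABI compatibility, but it can run before the first XLA call and narrow the window in which a mismatched pair goes unnoticed. The sketch below is a minimal illustration using only the standard library. The function names and the policy (same major.minor, jaxlib not newer than jax) are assumptions made for this example, not jax's actual compatibility rule.

```python
import re
from importlib import metadata


def parse_version(text):
    """Return the first three numeric components of a version string."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.groups())


def check_version_pair(jax_version, jaxlib_version):
    """Hypothetical policy: same major.minor, and jaxlib not newer than jax."""
    jax_v = parse_version(jax_version)
    lib_v = parse_version(jaxlib_version)
    if jax_v[:2] != lib_v[:2]:
        return False
    return lib_v <= jax_v


def installed_pair_ok():
    """Check the installed packages; None means one of them is not installed."""
    try:
        return check_version_pair(
            metadata.version("jax"), metadata.version("jaxlib")
        )
    except metadata.PackageNotFoundError:
        return None
```

Calling installed_pair_ok() at startup turns a possible crash deep in compilation into an early, readable failure, although a pair that passes this check can still be ABI-incompatible.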

GPU kernels may fail at runtime with cryptic CUDA errors if the wrong plugin version is loaded for the actual hardware/driver combination

ctypes.cdll.LoadLibrary succeeds but pycapsule creation fails with segfaults or undefined symbols at FFI call time
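
One way to surface this earlier is to resolve every required symbol immediately after the library loads. The sketch below assumes a hypothetical helper named missing_symbols. It only proves that a symbol exists, not that its signature matches what the caller expects.

```python
import ctypes


def missing_symbols(lib, names):
    """Return the names that cannot be resolved in an already-loaded library."""
    missing = []
    for name in names:
        try:
            getattr(lib, name)  # ctypes resolves the symbol on attribute access
        except AttributeError:
            missing.append(name)
    return missing


# Demonstrated against the interpreter's own C API, which ctypes exposes
# as ctypes.pythonapi; a real check would pass the freshly loaded library.
print(missing_symbols(ctypes.pythonapi, ["Py_IsInitialized", "no_such_symbol_xyz"]))
```

Failing at load time with a list of missing names is easier to diagnose than an undefined-symbol error at the first FFI call.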

Show everything (10 more)

Shape

Assumes input arrays a and b already have the memory layout, strides, and contiguity that the underlying CUDA kernel expects; only dtype (float32) and shape equality are validated

If this fails: CUDA kernel may read incorrect memory locations or crash if arrays have non-contiguous strides or unexpected memory layout

examples/ffi/src/jax_ffi_example/cuda_examples.py:foo_fwd
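
A layout check of this kind can be sketched with the buffer protocol, which reports format, shape, strides, and contiguity for any object that exports a buffer. The helper name check_kernel_inputs is hypothetical, and whether the arrays reaching foo_fwd can be inspected this way is an assumption. The example uses array.array so that it runs without third-party packages.

```python
from array import array


def check_kernel_inputs(a, b):
    """Raise ValueError unless a and b are float32, same-shape, C-contiguous buffers."""
    views = [memoryview(a), memoryview(b)]
    for name, view in zip("ab", views):
        if view.format != "f":
            raise ValueError(f"{name}: expected float32 (format 'f'), got {view.format!r}")
        if not view.c_contiguous:
            raise ValueError(f"{name}: buffer is not C-contiguous (strides={view.strides})")
    if views[0].shape != views[1].shape:
        raise ValueError(f"shape mismatch: {views[0].shape} vs {views[1].shape}")


dense = array("f", [1.0, 2.0, 3.0, 4.0])
check_kernel_inputs(dense, dense)      # passes silently
strided = memoryview(dense)[::2]       # every other element: not contiguous
```

Passing strided to check_kernel_inputs raises a ValueError that names the offending strides, which is the failure the kernel would otherwise hit as a bad memory read.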
Resource

Assumes cudaMemcpyAsync will complete before the CUDA stream is used by subsequent operations, but doesn't add explicit synchronization or check for CUDA errors from the memcpy operation

If this fails: Race conditions where downstream GPU operations read uninitialized data if memcpy hasn't completed, leading to wrong results

examples/ffi/src/jax_ffi_example/gpu_examples.cc:StateExecute

Domain

Assumes the input array attribute contains int32_t elements that can be safely summed into an int64_t without overflow, but never checks the array size or validates that sum won't exceed int64_t max value

If this fails: Integer overflow in the sum calculation leads to wrong results when processing large arrays or arrays with large values

examples/ffi/src/jax_ffi_example/cpu_examples.cc:ArrayAttrImpl
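
To put a number on the array size at which this overflow becomes possible, the worst case can be worked out with Python's arbitrary-precision integers. This is a conservative bound that assumes every element has the largest magnitude an int32 can hold; the function name is hypothetical.

```python
INT32_MAX_MAGNITUDE = 2**31   # |INT32_MIN|, the largest magnitude an int32 can hold
INT64_MAX = 2**63 - 1


def sum_provably_fits_int64(num_elements):
    """Conservative test: True means no int32 input of this length can overflow."""
    return num_elements * INT32_MAX_MAGNITUDE <= INT64_MAX


largest_safe_length = INT64_MAX // INT32_MAX_MAGNITUDE
print(largest_safe_length)    # 4294967295, i.e. 2**32 - 1
```

Under this bound the int64 accumulator is safe for any array shorter than 2**32 elements, so a length check against that limit is sufficient to rule the overflow out.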
Ordering

Assumes input and output arrays have the same memory layout and that it's safe to read from x and write to y in the same sequential order, but doesn't validate pointer alignment or check for memory overlap

If this fails: Memory corruption or undefined behavior if input and output arrays overlap in memory or have different alignment requirements

examples/ffi/src/jax_ffi_example/rms_norm.cc:ComputeRmsNorm
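
The overlap and alignment checks described above reduce to simple address arithmetic. The sketch below shows that arithmetic on raw addresses; the helper names are hypothetical, and the ctypes allocation only stands in for the x and y buffers the kernel would receive.

```python
import ctypes


def buffers_overlap(addr_x, nbytes_x, addr_y, nbytes_y):
    """True when the half-open byte ranges [addr, addr + nbytes) intersect."""
    if nbytes_x == 0 or nbytes_y == 0:
        return False  # an empty range cannot overlap anything
    return addr_x < addr_y + nbytes_y and addr_y < addr_x + nbytes_x


def is_aligned(addr, alignment):
    """True when addr is a multiple of alignment (alignment must be positive)."""
    return addr % alignment == 0


# Two windows into one 64-byte allocation stand in for the kernel's x and y.
storage = (ctypes.c_float * 16)()
base = ctypes.addressof(storage)
print(buffers_overlap(base, 32, base + 16, 32))   # True: bytes 16..31 are shared
print(buffers_overlap(base, 32, base + 32, 32))   # False: the ranges only touch
```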
Scale

Assumes the size parameter fits in int64_t and that the variance calculation (sm / size) won't cause numerical instability, but doesn't validate size > 0 or handle the case where all elements are zero

If this fails: Division by zero or numerical overflow when size is 0 or when sum of squares is extremely large, leading to NaN or Inf results

examples/ffi/src/jax_ffi_example/rms_norm.cc:ComputeRmsNorm
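
A pure-Python reference makes the missing guards concrete. The sketch below assumes the computation is scale = 1 / sqrt(sum_of_squares / size + eps), as the description of sm / size and the eps parameter suggests; it is not a transcription of ComputeRmsNorm, and the function name is hypothetical.

```python
import math


def rms_norm_reference(values, eps=1e-5):
    """RMS-normalize a sequence of floats, rejecting empty or overflowing input."""
    size = len(values)
    if size == 0:
        raise ValueError("rms_norm requires at least one element")
    mean_square = sum(v * v for v in values) / size
    if not math.isfinite(mean_square):
        raise OverflowError("sum of squares is not finite")
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [v * scale for v in values]
```

With a positive eps, all-zero input is harmless in this formulation because the scale stays finite; the cases that need explicit guards are size 0 and a sum of squares that overflows.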
Contract

Assumes the num parameter is a non-negative integer small enough for np.arange(num) to allocate, but never validates num >= 0 or bounds its size

If this fails: Memory exhaustion when num is very large; a negative num makes np.arange return an empty array, so downstream code sees empty input instead of an error

examples/ffi/src/jax_ffi_example/cpu_examples.py:array_attr
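
The missing validation could look like the sketch below, which would run before num is handed to np.arange. The helper name validate_num and the cap MAX_ELEMENTS are hypothetical; the right cap depends on the memory available where the code runs.

```python
import operator

MAX_ELEMENTS = 1 << 24   # hypothetical cap; choose one that fits the deployment


def validate_num(num, max_elements=MAX_ELEMENTS):
    """Return num as an int, or raise if it is not a usable array length."""
    try:
        count = operator.index(num)   # rejects floats and strings
    except TypeError:
        raise TypeError(f"num must be an integer, got {type(num).__name__}") from None
    if count < 0:
        raise ValueError(f"num must be non-negative, got {count}")
    if count > max_elements:
        raise ValueError(f"num={count} exceeds the cap of {max_elements} elements")
    return count
```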
Contract

Assumes the input x is a JAX array with a numeric dtype compatible with the C++ implementation, but never validates that x.dtype maps to a supported C++ type (T in the template)

If this fails: Type mismatch between Python array dtype and C++ template instantiation leads to incorrect memory interpretation and wrong results

examples/ffi/src/jax_ffi_example/rms_norm.py:rms_norm

Environment

Assumes jaxlib.mlir.dialects module exists and contains all expected dialect submodules (arith, mhlo, etc.), but uses lazy loading without validating module completeness

If this fails: AttributeError at runtime when trying to access missing dialect modules, breaking MLIR lowering for certain operations

jax/_src/lib/mlir/dialects/__init__.py:lazy_loading
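
A completeness check for a lazily loaded package can be sketched with importlib.util.find_spec, which locates a submodule without importing it (it does import the parent package). The helper name missing_submodules is hypothetical. Finding a spec shows that the module exists on disk, not that importing it will succeed.

```python
import importlib.util


def missing_submodules(package, names):
    """Return the submodule names that the import system cannot locate."""
    missing = []
    for name in names:
        try:
            spec = importlib.util.find_spec(f"{package}.{name}")
        except ModuleNotFoundError:   # the parent package itself is absent
            spec = None
        if spec is None:
            missing.append(name)
    return missing


# Demonstrated on a standard-library package. For the case above the call
# would be missing_submodules("jaxlib.mlir.dialects", ["arith", "mhlo"]).
print(missing_submodules("json", ["decoder", "no_such_submodule"]))
```

Running such a check once at import time would report every missing dialect together, instead of raising AttributeError at whichever lowering touches one first.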
Temporal

Assumes the State object lifetime extends beyond the StateExecute call and that the memory pointed to by state->value remains valid during the asynchronous CUDA memcpy

If this fails: Use-after-free or reading invalid memory if the State object is destroyed before the CUDA stream completes the memcpy operation

examples/ffi/src/jax_ffi_example/gpu_examples.cc:State

Domain

Assumes eps=1e-5 is an appropriate epsilon value for numerical stability across all input scales and dtypes, but doesn't adjust epsilon based on the working precision (float32 vs float64)

If this fails: Numerical instability for very small values with float32 or insufficient precision for float64 computations

examples/ffi/src/jax_ffi_example/rms_norm.py:eps_parameter
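
The scale sensitivity can be seen with a short calculation. If every input has the same magnitude m, RMS normalization with epsilon maps it to m / sqrt(m^2 + eps), which should be close to 1. The sketch below evaluates that closed form; the function name is hypothetical, and the formula assumes eps is added to the mean square before the square root.

```python
import math


def output_magnitude(input_magnitude, eps=1e-5):
    """Per-element output when every |x_i| equals input_magnitude."""
    mean_square = input_magnitude ** 2
    return input_magnitude / math.sqrt(mean_square + eps)


for magnitude in (1.0, 1e-2, 1e-4):
    print(f"{magnitude:g} -> {output_magnitude(magnitude):.4f}")
```

Inputs of magnitude 1 come out essentially unchanged, while inputs of magnitude 1e-4 come out near 0.03 rather than 1, because eps = 1e-5 dominates a mean square of 1e-8.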

Frequently Asked Questions

What does jax assume that could break in production?

The one most likely to cause trouble: jax assumes jaxlib is properly installed and version-compatible, but it checks the version only after a successful import. An incompatible jaxlib that imports successfully but has an ABI mismatch will cause cryptic runtime errors during XLA compilation. If this fails: silent failures or crashes deep in XLA when jaxlib C++ extensions have incompatible ABIs, making debugging extremely difficult.

How many hidden assumptions does jax have?

CodeSea found 13 assumptions jax relies on but never validates, 5 of them critical, spanning Environment, Shape, Resource, Domain, Ordering, Scale, Contract, Temporal. Most are routine; the analysis flags the 3 most likely to actually bite.

What is a hidden assumption?

Something the code depends on but never checks: a data shape, an ordering, an environment condition, a scale limit, or a contract with another service. It holds until the world it runs in changes, then fails silently.