Hidden Assumptions in jax
13 assumptions this code never checks · 5 critical · spanning Environment, Shape, Resource, Domain, Ordering, Scale, Contract, Temporal
Every codebase relies on things it never checks, and most of them are routine. CodeSea looked at jax-ml/jax and picked out the three most likely to cause trouble; they come first below, and the remaining 10 follow them.
Silent failures or crashes deep in XLA when jaxlib C++ extensions have incompatible ABIs, making debugging extremely difficult
GPU kernels may fail at runtime with cryptic CUDA errors if the wrong plugin version is loaded for the actual hardware/driver combination
ctypes.cdll.LoadLibrary succeeds, but PyCapsule creation or the FFI call itself later fails with a segfault or an undefined-symbol error
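For the last item, one way to turn an undefined-symbol failure at FFI call time into an error at load time is to resolve every handler name immediately after loading the library. The sketch below assumes the handlers are exported under names known in advance; `missing_symbols` and `resolve_symbols` are hypothetical helpers, not part of jax. It only catches missing exports, not an ABI mismatch between symbols that do exist.

```python
from typing import Any, Dict, Iterable, List


def missing_symbols(lib: Any, names: Iterable[str]) -> List[str]:
    """Return the requested names that `lib` does not expose.

    For a ctypes.CDLL, attribute access raises AttributeError when the shared
    library does not export the symbol, so hasattr() doubles as a lookup check.
    """
    return [name for name in names if not hasattr(lib, name)]


def resolve_symbols(lib: Any, names: Iterable[str]) -> Dict[str, Any]:
    """Resolve every required export up front, reporting all missing ones at once."""
    names = list(names)
    missing = missing_symbols(lib, names)
    if missing:
        raise RuntimeError(f"shared library is missing symbols: {', '.join(missing)}")
    return {name: getattr(lib, name) for name in names}
```

In use, this would run right after lib = ctypes.cdll.LoadLibrary(path), with the library's real handler names; the path and names depend on how the extension is built.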
The other 10 assumptions
Assumes input arrays a and b have identical shapes and dtype float32, but only validates dtype and shape equality without checking memory layout, strides, or contiguity that the underlying CUDA kernel expects
If this fails: CUDA kernel may read incorrect memory locations or crash if arrays have non-contiguous strides or unexpected memory layout
examples/ffi/src/jax_ffi_example/cuda_examples.py:foo_fwd
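A guard for the foo_fwd item could check layout as well as shape and dtype. The sketch below is hypothetical: it assumes NumPy-style objects exposing .shape, .dtype, .strides and .itemsize (NumPy itself offers the same contiguity test as arr.flags["C_CONTIGUOUS"]), and spells out what "dense row-major" means in terms of byte strides.

```python
from typing import Sequence


def is_c_contiguous(shape: Sequence[int], strides: Sequence[int], itemsize: int) -> bool:
    """True if byte `strides` describe a dense row-major (C-order) buffer.

    Size-1 dimensions are skipped because their stride is never used, and an
    array with a zero-length dimension holds no data, so it counts as dense.
    """
    if len(shape) != len(strides):
        raise ValueError("shape and strides must have the same length")
    if any(dim == 0 for dim in shape):
        return True
    expected = itemsize
    for dim, stride in zip(reversed(shape), reversed(strides)):
        if dim != 1 and stride != expected:
            return False
        expected *= dim
    return True


def check_kernel_operands(a, b) -> None:
    """Hypothetical pre-launch guard for a kernel like foo_fwd's."""
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    for name, arr in (("a", a), ("b", b)):
        if str(arr.dtype) != "float32":
            raise TypeError(f"{name} must be float32, got {arr.dtype}")
        if not is_c_contiguous(arr.shape, arr.strides, arr.itemsize):
            raise ValueError(f"{name} is not C-contiguous (strides={arr.strides})")
```

A transposed 3x2 float32 view with strides (4, 12) fails this check, while the original 2x3 array with strides (12, 4) passes.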
Assumes the cudaMemcpyAsync has completed before anything reads its result; work queued later on the same stream is ordered after it, but the code adds no explicit synchronization for consumers on other streams or the host and doesn't check the CUDA error returned by the memcpy
If this fails: Race conditions where downstream GPU operations read uninitialized data if memcpy hasn't completed, leading to wrong results
examples/ffi/src/jax_ffi_example/gpu_examples.cc:StateExecute
Assumes the input array attribute contains int32_t elements that can be safely summed into an int64_t without overflow, but never checks the array size or validates that the sum stays within the int64_t range
If this fails: Integer overflow in the sum calculation leads to wrong results when processing large arrays or arrays with large values
examples/ffi/src/jax_ffi_example/cpu_examples.cc:ArrayAttrImpl
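The overflow risk in ArrayAttrImpl can be bounded with a little arithmetic. Each int32 element has magnitude at most 2^31 and an int64 reaches down to -2^63, so a sum of int32 values cannot overflow unless the array has more than 2^32 (about 4.29 billion) elements. The sketch below, a hypothetical Python model of an int64 accumulator rather than the C++ code, computes that bound and shows a checked sum.

```python
from typing import Iterable

INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1
INT64_MIN, INT64_MAX = -(2**63), 2**63 - 1


def max_overflow_free_length() -> int:
    """Largest element count whose int64 sum cannot overflow for ANY int32 inputs.

    Worst cases: every element is INT32_MIN (most negative) or INT32_MAX.
    """
    fits_negative = INT64_MIN // INT32_MIN  # = 2**32
    fits_positive = INT64_MAX // INT32_MAX  # = 2**32 + 2
    return min(fits_negative, fits_positive)


def checked_int64_sum(values: Iterable[int]) -> int:
    """Sum int32 values, raising where a C++ int64_t accumulator would overflow."""
    total = 0
    for v in values:
        if not INT32_MIN <= v <= INT32_MAX:
            raise ValueError(f"{v} is outside the int32 range")
        total += v
        if not INT64_MIN <= total <= INT64_MAX:
            raise OverflowError("int64 accumulator overflowed")
    return total
```

Under these assumptions, a size check against max_overflow_free_length() is enough to rule out overflow without inspecting the values.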
Assumes input and output arrays have the same memory layout and that it's safe to read from x and write to y in the same sequential order, but doesn't validate pointer alignment or check for memory overlap
If this fails: Memory corruption or undefined behavior if input and output arrays overlap in memory or have different alignment requirements
examples/ffi/src/jax_ffi_example/rms_norm.cc:ComputeRmsNorm
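The overlap question for ComputeRmsNorm comes down to comparing two byte ranges. The sketch below is hypothetical and works on raw addresses (for NumPy arrays, arr.ctypes.data gives one, and np.shares_memory(a, b) answers the same question directly). It assumes the kernel reads all of x before writing y[n] = x[n] * scale, which makes a fully in-place call harmless while a partial overlap is not.

```python
def ranges_overlap(x_addr: int, x_nbytes: int, y_addr: int, y_nbytes: int) -> bool:
    """True if [x_addr, x_addr + x_nbytes) and [y_addr, y_addr + y_nbytes) intersect."""
    if x_nbytes <= 0 or y_nbytes <= 0:
        return False
    return x_addr < y_addr + y_nbytes and y_addr < x_addr + x_nbytes


def classify_aliasing(x_addr: int, y_addr: int, nbytes: int) -> str:
    """Classify how an input buffer and an equally sized output buffer relate.

    "identical": fully in place. If the kernel finishes reading x (to compute
                 the scale) before writing y[n] = x[n] * scale, each element is
                 read before it is overwritten, so this case is safe.
    "partial":   the ranges overlap at an offset; writes to y can clobber parts
                 of x that have not been read yet, so treat it as unsafe.
    "disjoint":  no shared bytes.
    """
    if x_addr == y_addr:
        return "identical"
    if ranges_overlap(x_addr, nbytes, y_addr, nbytes):
        return "partial"
    return "disjoint"
```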
Assumes the size parameter fits in int64_t and that the mean-of-squares calculation (sm / size) won't cause numerical instability, but doesn't validate size > 0 or handle the case where all elements are zero
If this fails: Division by zero or numerical overflow when size is 0 or when sum of squares is extremely large, leading to NaN or Inf results
examples/ffi/src/jax_ffi_example/rms_norm.cc:ComputeRmsNorm
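A reference version makes the size > 0 precondition explicit. The sketch below assumes the usual RMS-norm formula, scale = 1 / sqrt(sum(x^2) / size + eps), and is a hypothetical pure-Python model in float64, not the C++ template. With a floating-point division, size == 0 gives 0 / 0 = NaN; with eps > 0, an all-zero row stays finite.

```python
import math
from typing import List, Sequence, Tuple


def rms_norm_reference(x: Sequence[float], eps: float = 1e-5) -> Tuple[float, List[float]]:
    """Pure-Python RMS norm over one row, returning (scale, normalized row)."""
    size = len(x)
    if size == 0:
        # Without this check, sum_sq / size would be a 0 / 0 division.
        raise ValueError("rms_norm: cannot normalize an empty row (size == 0)")
    if eps <= 0:
        # eps is what keeps an all-zero row away from 1 / sqrt(0).
        raise ValueError("rms_norm: eps must be positive")
    sum_sq = math.fsum(v * v for v in x)
    scale = 1.0 / math.sqrt(sum_sq / size + eps)
    return scale, [v * scale for v in x]
```

For an all-zero row the scale is 1 / sqrt(eps) and the output stays all zeros.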
Assumes the num parameter is a non-negative integer that can be used to create a valid NumPy array range, but never validates num >= 0 or checks that np.arange(num) won't exhaust memory
If this fails: Memory exhaustion when num is very large; a negative num makes np.arange(num) return an empty array, so the call silently runs on zero elements instead of failing
examples/ffi/src/jax_ffi_example/cpu_examples.py:array_attr
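A validation step before np.arange(num) could cover both failure modes. The sketch below is hypothetical: the helper name and the cap of 2^20 elements are illustrative choices, not values from the jax example.

```python
import operator

# Hypothetical cap for illustration; pick one that suits the deployment.
DEFAULT_MAX_LEN = 1 << 20


def validate_array_attr_len(num, max_len: int = DEFAULT_MAX_LEN) -> int:
    """Validate `num` before building np.arange(num) as an FFI array attribute.

    np.arange with a negative stop quietly returns an empty array, so a negative
    count is rejected explicitly; a very large count is refused before any
    allocation happens.
    """
    n = operator.index(num)  # accepts Python/NumPy integers, rejects floats
    if n < 0:
        raise ValueError(f"num must be non-negative, got {n}")
    if n > max_len:
        raise ValueError(f"num={n} exceeds the cap of {max_len} elements")
    return n
```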
Assumes the input x is a JAX array with a numeric dtype compatible with the C++ implementation, but never validates that x.dtype maps to a supported C++ type (T in the template)
If this fails: Type mismatch between Python array dtype and C++ template instantiation leads to incorrect memory interpretation and wrong results
examples/ffi/src/jax_ffi_example/rms_norm.py:rms_norm
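One way to keep the Python dtype and the C++ template type in step is an explicit dispatch table that fails in Python for anything not registered. The table and target names below are hypothetical; they would have to match whatever the C++ side actually registers.

```python
from typing import Any, Mapping

# Hypothetical dtype-name -> FFI target-name table; the target names are
# illustrative and must match whatever the C++ side actually registers.
RMS_NORM_TARGETS: Mapping[str, str] = {
    "float32": "rms_norm_f32",
    "float64": "rms_norm_f64",
}


def select_ffi_target(dtype: Any, targets: Mapping[str, str] = RMS_NORM_TARGETS) -> str:
    """Map x.dtype (a NumPy dtype, or a dtype name string) to a registered target.

    Failing here gives a clear error instead of letting the C++ side interpret
    the buffer with the wrong element type.
    """
    name = str(dtype)  # str(np.dtype("float32")) == "float32"
    try:
        return targets[name]
    except KeyError:
        supported = ", ".join(sorted(targets))
        raise TypeError(f"unsupported dtype {name}; supported: {supported}") from None
```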
Assumes the jaxlib.mlir.dialects package exists and contains every expected dialect submodule (arith, mhlo, etc.), but loads them lazily without checking that the set is complete
If this fails: AttributeError at runtime when trying to access missing dialect modules, breaking MLIR lowering for certain operations
jax/_src/lib/mlir/dialects/__init__.py:lazy_loading
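Lazy loading and a completeness check can coexist. The sketch below is a hypothetical stand-in, not jax's loader: a small class that imports listed submodules on first access, converts failures into a clear AttributeError, and offers an eager missing() check that could run once at startup. A module-level __getattr__ (PEP 562) would be the more common way to attach this to a package.

```python
import importlib
from types import ModuleType
from typing import Dict, Iterable, List


class LazySubmodules:
    """Import submodules of `package` on first attribute access."""

    def __init__(self, package: str, names: Iterable[str]) -> None:
        self._package = package
        self._names = frozenset(names)
        self._cache: Dict[str, ModuleType] = {}

    def __getattr__(self, name: str) -> ModuleType:
        # Only called when normal lookup fails. Refusing private names avoids
        # infinite recursion if _names is looked up before __init__ has run.
        if name.startswith("_") or name not in self._names:
            raise AttributeError(f"{self._package!r} has no listed submodule {name!r}")
        if name not in self._cache:
            try:
                self._cache[name] = importlib.import_module(f"{self._package}.{name}")
            except ImportError as exc:
                raise AttributeError(
                    f"{self._package}.{name} is listed but failed to import: {exc}"
                ) from exc
        return self._cache[name]

    def missing(self) -> List[str]:
        """Eagerly try every listed submodule and return the names that fail."""
        failed = []
        for name in sorted(self._names):
            try:
                getattr(self, name)
            except AttributeError:
                failed.append(name)
        return failed
```

For example, LazySubmodules("json", ["decoder", "encoder"]) behaves like a lazy view of those two stdlib submodules.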
Assumes the State object's lifetime extends beyond the StateExecute call and that the memory pointed to by state->value remains valid until the asynchronous CUDA memcpy completes
If this fails: Use-after-free or reading invalid memory if the State object is destroyed before the CUDA stream completes the memcpy operation
examples/ffi/src/jax_ffi_example/gpu_examples.cc:State
Assumes eps=1e-5 is an appropriate epsilon value for numerical stability across all input scales and dtypes, but doesn't adjust epsilon based on the working precision (float32 vs float64)
If this fails: Numerical instability for very small values with float32 or insufficient precision for float64 computations
examples/ffi/src/jax_ffi_example/rms_norm.py:eps_parameter
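Whether a fixed eps=1e-5 is appropriate depends on how it compares with mean(x^2), assuming the usual denominator mean(x^2) + eps. For reference, machine epsilon is about 1.19e-7 for float32 and 2.22e-16 for float64, so 1e-5 sits well above both. The sketch below uses Python floats (float64) to isolate the input-scale half of the concern; the helper names are hypothetical.

```python
import math
from typing import Sequence


def mean_square(x: Sequence[float]) -> float:
    return math.fsum(v * v for v in x) / len(x)


def eps_share(x: Sequence[float], eps: float = 1e-5) -> float:
    """Fraction of the RMS-norm denominator mean(x**2) + eps contributed by eps."""
    ms = mean_square(x)
    return eps / (ms + eps)


def output_rms(x: Sequence[float], eps: float = 1e-5) -> float:
    """RMS of x * scale with scale = 1 / sqrt(mean(x**2) + eps); ideally close to 1."""
    ms = mean_square(x)
    return math.sqrt(ms) / math.sqrt(ms + eps)
```

For inputs of magnitude 1, eps contributes about 0.001% of the denominator and the output RMS is essentially 1. For inputs of magnitude 1e-3, mean(x^2) is 1e-6, eps makes up about 91% of the denominator, and the output RMS drops to roughly 0.30.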
Frequently Asked Questions
What does jax assume that could break in production?
The one most likely to cause trouble: jax assumes jaxlib is properly installed and version-compatible, but it checks the version only after a successful import. A jaxlib that imports cleanly but has an ABI mismatch will cause cryptic runtime errors during XLA compilation. If this fails, the result is silent failures or crashes deep in XLA, which makes debugging extremely difficult.
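One mitigation is to read jaxlib's version from package metadata before importing it, since reading metadata does not load any C++ extension. The sketch below uses the standard importlib.metadata API; the helper names and the minimum version passed in are hypothetical. It catches version drift that metadata can see, not a broken build that reports a matching version.

```python
import importlib.metadata
import re
from typing import Callable, Tuple


def parse_release(version: str) -> Tuple[int, ...]:
    """Leading numeric release segment: '0.4.31.dev20240101' -> (0, 4, 31)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def at_least(installed: str, minimum: str) -> bool:
    a, b = parse_release(installed), parse_release(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so '0.5' and '0.5.0' compare equal.
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_jaxlib_before_import(
    minimum: str,
    get_version: Callable[[str], str] = importlib.metadata.version,
) -> str:
    """Read jaxlib's version from package metadata without importing jaxlib."""
    try:
        installed = get_version("jaxlib")
    except importlib.metadata.PackageNotFoundError:
        raise RuntimeError("jaxlib is not installed") from None
    if not at_least(installed, minimum):
        raise RuntimeError(f"jaxlib {installed} is older than required {minimum}")
    return installed
```

The get_version parameter exists so the check can be exercised without jaxlib installed; in normal use it reads the installed distribution's metadata.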
How many hidden assumptions does jax have?
CodeSea found 13 assumptions jax relies on but never validates, 5 of them critical, spanning Environment, Shape, Resource, Domain, Ordering, Scale, Contract, Temporal. Most are routine; the analysis flags the three most likely to actually bite.
What is a hidden assumption?
Something the code depends on but never checks: a data shape, an ordering, an environment condition, a scale limit, or a contract with another service. It holds until the world it runs in changes, then fails silently.