Lineax: JAX-native Linear Solvers
Lineax provides high-performance, JIT-compilable, and differentiable linear solvers for systems of the form Ax=b, built on JAX and designed to integrate seamlessly with Equinox. It enables defining custom linear operators and choosing various direct or iterative solvers. The current version is 0.1.0, indicating an early stage of development, with a release cadence that tends to follow updates in the broader JAX and Equinox ecosystem.
Common errors
- ValueError: lu_factor: Input matrix must be non-singular.
cause The linear operator (e.g., `lx.MatrixLinearOperator`) passed to a direct solver is singular, meaning it has no inverse, or is so ill-conditioned that it is numerically singular.
fix Confirm that the matrix is non-singular by checking its condition number (`jnp.linalg.cond`) or determinant (`jnp.linalg.det`). If the system is inherently singular, consider an alternative problem formulation. For ill-conditioned but non-singular systems, iterative solvers or preconditioning may be more robust.
- TypeError: Abstract tracer value encountered where a concrete value was expected.
cause A JAX array is used in a Python control flow statement (such as `if` or `while`) inside a `jax.jit`-compiled function. During tracing the array has no concrete value, so Python cannot branch on it.
fix Replace the Python control flow with JAX control flow primitives such as `jax.lax.cond`, `jax.lax.while_loop`, or `jax.lax.fori_loop`. Alternatively, move any control flow that depends on JAX array values outside the `jax.jit`-decorated function.
- ValueError: Incompatible shapes. Expected output shape (X,) but got (Y,).
cause The shape of the right-hand side vector `b` does not match the output shape of the linear operator, or the operator's input and output structures are inconsistent with each other.
fix Review the `in_structure()` and `out_structure()` methods of your `AbstractLinearOperator` and make sure the vector passed to `lx.linear_solve` has a shape compatible with `operator.out_structure()`.
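The first two fixes above can be sketched in plain JAX. This is a minimal illustration, not part of Lineax: the helper names `is_well_conditioned` and `safe_normalize` are hypothetical, and the `max_cond` threshold is an arbitrary value chosen for the example.

```python
import jax
import jax.numpy as jnp


def is_well_conditioned(matrix, max_cond=1e4):
    # max_cond is an illustrative threshold, not a Lineax or JAX default.
    return jnp.linalg.cond(matrix) < max_cond


@jax.jit
def safe_normalize(x):
    # A Python `if norm > 0:` would fail here because `norm` is a tracer.
    # jax.lax.cond selects a branch at run time and is traceable.
    norm = jnp.linalg.norm(x)
    return jax.lax.cond(norm > 0, lambda: x / norm, lambda: x)


good = jnp.array([[1.0, 2.0], [3.0, 4.0]])
singular = jnp.array([[1.0, 2.0], [2.0, 4.0]])  # second row is 2x the first

print("good is well conditioned:", bool(is_well_conditioned(good)))
print("singular is well conditioned:", bool(is_well_conditioned(singular)))
print("normalized:", safe_normalize(jnp.array([3.0, 4.0])))
```

The conditioning check runs outside `jax.jit`, so converting its result with `bool(...)` is safe there. Inside a jitted function the same decision would have to go through `jax.lax.cond`.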
Warnings
- gotcha Lineax operations are designed to run under JAX's JIT compilation. Without `jax.jit`, operations are dispatched one at a time and may be traced repeatedly, which can be significantly slower for anything beyond small problems.
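- gotcha JAX arrays are immutable. In-place modification of an array raises an error; use functional updates such as `x.at[i].set(value)` instead. Mutating Python objects (like lists) inside JIT-compiled functions can lead to errors or unexpected behavior, because the Python code only runs while the function is being traced.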
- gotcha Lineax solvers, especially direct ones, may fail or produce inaccurate results for singular, ill-conditioned, or poorly scaled linear operators. This is a fundamental limitation of numerical linear algebra.
- gotcha Lineax extensively uses `jaxtyping` for static shape and dtype annotations. While not strictly mandatory to run, ignoring them can make debugging shape/dtype mismatches more challenging.
Install
- pip install lineax
Imports
- lineax
import lineax as lx
- AbstractLinearOperator
from lineax import AbstractLinearOperator
- MatrixLinearOperator
from lineax import MatrixLinearOperator
- linear_solve
from lineax import linear_solve
- LU (a direct solver)
from lineax import LU
Quickstart
import jax
import jax.numpy as jnp
import lineax as lx
# Define a linear operator (e.g., a matrix)
matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
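operator = lx.MatrixLinearOperator(matrix)
# Define the right-hand side vector
vector = jnp.array([5.0, 6.0])
# Solve the linear system Ax = b using the default (direct) solver
solution = lx.linear_solve(operator, vector)
# linear_solve returns a Solution object; the solved vector is in `.value`
print("Solution:", solution.value)
# Expected output (up to floating-point rounding): Solution: [-4. 4.5]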