{"id":8285,"library":"lineax","title":"Lineax: JAX-native Linear Solvers","description":"Lineax provides high-performance, JIT-compilable, and differentiable linear solvers for systems of the form Ax=b, built on JAX and designed to integrate seamlessly with Equinox. It lets you define custom linear operators and choose among direct and iterative solvers. The current version is 0.1.0, indicating an early stage of development; releases tend to follow updates in the broader JAX and Equinox ecosystem.","status":"active","version":"0.1.0","language":"en","source_language":"en","source_url":"https://github.com/google/lineax","tags":["jax","equinox","linear algebra","solver","numerical","machine learning"],"install":[{"cmd":"pip install lineax","lang":"bash","label":"Install latest version"}],"dependencies":[{"reason":"Core dependency for numerical computation and automatic differentiation.","package":"jax"},{"reason":"Lineax operators and solvers are Equinox modules, and hence JAX PyTrees, so they can be passed as arguments to JAX-transformed functions such as those wrapped in `jax.jit`.","package":"equinox"},{"reason":"Used for shape and dtype annotations throughout the library; installed automatically as a dependency of Lineax.","package":"jaxtyping"}],"imports":[{"symbol":"lineax","correct":"import lineax as lx"},{"note":"Base class for defining custom linear operators.","symbol":"AbstractLinearOperator","correct":"from lineax import AbstractLinearOperator"},{"note":"Wraps a dense JAX array as a linear operator; commonly used as `lx.MatrixLinearOperator` after `import lineax as lx`.","symbol":"MatrixLinearOperator","correct":"from lineax import MatrixLinearOperator"},{"note":"The primary function for solving linear systems; typically used as `lx.linear_solve`. It returns a `Solution` object whose `.value` attribute holds x.","symbol":"linear_solve","correct":"from lineax import linear_solve"},{"note":"LU decomposition, one of several direct solvers (others include `lx.QR()`, `lx.SVD()`, and `lx.Cholesky()`); typically used as `lx.LU()`. Iterative solvers such as `lx.CG`, `lx.BiCGStab`, and `lx.GMRES` require `rtol` and `atol` arguments.","symbol":"LU","correct":"from lineax import LU"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\nimport lineax as lx\n\n# Wrap a dense matrix A as a linear operator\nmatrix = jnp.array([[1.0, 2.0], [3.0, 
4.0]])\noperator = lx.MatrixLinearOperator(matrix)\n\n# Define the right-hand side vector b\nvector = jnp.array([5.0, 6.0])\n\n# Solve Ax = b with the default solver, AutoLinearSolver(well_posed=True),\n# which selects LU for a square matrix\nsolution = lx.linear_solve(operator, vector)\n\n# linear_solve returns a Solution object; the solution x is in .value\nprint(\"Solution:\", solution.value)\n# Expected output (approximately, due to float32 rounding): Solution: [-4.   4.5]","lang":"python","description":"This quickstart wraps a dense JAX array in `lx.MatrixLinearOperator` and calls `lx.linear_solve` to solve the linear system Ax=b. The returned `Solution` object holds the solution in `.value`, along with a `.result` status code and solver statistics in `.stats`."},"warnings":[{"fix":"Wrap your core computation (especially loops or repeated solves) in `jax.jit` (or `eqx.filter_jit`) for good performance. `jax.jit` requires pure functions, and it re-traces and recompiles whenever input shapes or dtypes change, so keep them stable across calls.","message":"Lineax operations are designed to be JIT-compiled. Running them eagerly, outside `jax.jit`, dispatches each primitive operation separately and can be significantly slower.","severity":"gotcha","affected_versions":"0.1.0+"},{"fix":"Embrace the functional programming paradigm: operations on JAX arrays return new arrays, and indexed updates use `x.at[idx].set(value)` rather than `x[idx] = value`. Use JAX transformations and primitives such as `jax.vmap`, `jax.lax.scan`, `jax.lax.fori_loop`, or `jax.lax.cond` for batching and control flow.","message":"JAX arrays are immutable. In-place item assignment raises an error, and side effects on mutable Python objects (such as appending to a list) inside a JIT-compiled function run only once, during tracing, rather than on every call.","severity":"gotcha","affected_versions":"0.1.0+"},{"fix":"Before solving, check the properties of your operator (for a dense `lx.MatrixLinearOperator`, e.g. its condition number via `jnp.linalg.cond`). For singular or rank-deficient systems, use `lx.SVD()` or `lx.AutoLinearSolver(well_posed=False)`, which return a least-squares or minimum-norm solution. For ill-conditioned systems, rescale the problem or apply a preconditioner; iterative solvers such as `lx.GMRES` tend to converge slowly on ill-conditioned operators without one.","message":"Lineax solvers may fail or return inaccurate results for singular, ill-conditioned, or poorly scaled linear operators. 
This is a fundamental limitation of numerical linear algebra, not something specific to Lineax.","severity":"gotcha","affected_versions":"0.1.0+"},{"fix":"Use `jaxtyping` annotations in your own functions to document expected shapes and dtypes. To have them checked at runtime, decorate functions with `jaxtyping.jaxtyped(typechecker=...)` together with a runtime type checker such as beartype; this catches dimension and dtype errors early and improves readability in complex JAX code.","message":"Lineax uses `jaxtyping` extensively for shape and dtype annotations. These annotations are not enforced unless you enable a runtime type checker, so shape and dtype mismatches may otherwise surface only as less informative errors deep inside a solve.","severity":"gotcha","affected_versions":"0.1.0+"}],"env_vars":null,"last_verified":"2026-04-16T00:00:00.000Z","next_check":"2026-07-15T00:00:00.000Z","problems":[{"fix":"Check whether the matrix is singular or nearly so using its condition number (`jnp.linalg.cond`); the determinant (`jnp.linalg.det`) is an unreliable test of near-singularity in floating point. If the system is genuinely singular, reformulate the problem or request a least-squares or minimum-norm solution with `lx.SVD()` or `lx.AutoLinearSolver(well_posed=False)`. For ill-conditioned but non-singular systems, rescaling or preconditioning may help.","cause":"The linear operator (e.g., `lx.MatrixLinearOperator`) passed to a direct solver is singular, meaning it has no unique inverse, or is so ill-conditioned that the factorization breaks down.","error":"ValueError: lu_factor: Input matrix must be non-singular."},{"fix":"Replace Python control flow that depends on array values with JAX control-flow primitives such as `jax.lax.cond`, `jax.lax.while_loop`, or `jax.lax.fori_loop`. Alternatively, move that control flow outside the `jax.jit`-decorated function, or, if the value is a plain Python constant, mark the argument as static with `static_argnums` or `static_argnames`.","cause":"You are using a traced JAX array in a Python control-flow statement (such as `if` or `while`) inside a `jax.jit`-compiled function. 
JAX's tracing mechanism needs a concrete Python value to decide which branch to take, but inside `jax.jit` the array is an abstract tracer. Recent JAX versions raise the subclass `TracerBoolConversionError` for `if` statements on traced booleans.","error":"jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected"},{"fix":"Check the `in_structure()` and `out_structure()` methods of your `AbstractLinearOperator` subclass, and make sure the right-hand side passed to `lx.linear_solve` has the same structure (shape and dtype) as `operator.out_structure()`.","cause":"The shape of the right-hand side vector `b` does not match the operator's output structure, or a custom operator's `in_structure()` and `out_structure()` do not match what its `mv` method actually consumes and produces.","error":"ValueError: Incompatible shapes. Expected output shape (X,) but got (Y,)."}]}