Lineax: JAX-native Linear Solvers

0.1.0 · active · verified Thu Apr 16

Lineax provides high-performance, JIT-compilable, and differentiable linear solvers for systems of the form Ax=b. It is built on JAX and designed to integrate with Equinox. You can define custom linear operators, either from an explicit matrix or from a matrix-free function. You can then choose between direct solvers (such as LU, QR, and Cholesky) and iterative solvers (such as CG, GMRES, and BiCGStab). The current version is 0.1.0, which indicates an early stage of development. Releases tend to follow updates in the broader JAX and Equinox ecosystem.

Install

pip install lineax

Imports

import lineax as lx

Quickstart

This quickstart shows how to wrap a matrix in an `lx.MatrixLinearOperator` and use `lx.linear_solve` to solve the linear system Ax=b. `lx.linear_solve` returns an `lx.Solution` object, and the solution vector itself is stored in its `.value` attribute.

import jax.numpy as jnp
import lineax as lx

# Define a linear operator from an explicit matrix A
matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
operator = lx.MatrixLinearOperator(matrix)

# Define the right-hand side vector b
vector = jnp.array([5.0, 6.0])

# Solve Ax = b with the default solver, lx.AutoLinearSolver(well_posed=True),
# which selects a direct LU solve for a square matrix with no special structure
solution = lx.linear_solve(operator, vector)

# The solution vector x is stored in solution.value
print("Solution:", solution.value)
# Expected output (up to floating-point rounding): Solution: [-4.   4.5]
