KLUJAX

Version: 0.5.0 | Verified: Mon Apr 27 | Auth: no | Language: Python

A KLU sparse direct solver (sparse LU factorization) for JAX, suited to linear systems such as those arising in circuit simulation and power-system analysis. Version 0.5.0 is pre-1.0, so breaking changes are possible. Release cadence: irregular.

pip install klujax
error ModuleNotFoundError: No module named 'klujax'
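cause The package is not installed, or it is installed in a different Python environment from the one running your code.
fix Run `python -m pip install klujax` with the same interpreter that runs your code, then verify with `python -c 'import klujax; print(klujax.__version__)'`.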
error TypeError: Expected jnp.ndarray, got list
cause Python lists were passed instead of JAX arrays.
fix Convert the inputs to JAX arrays, e.g. `row = jnp.array(row_list, dtype=jnp.int32)`.
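A minimal sketch of that conversion, assuming the COO triplets arrive as plain Python lists. The helper name `to_coo_arrays` is hypothetical (not part of klujax); it also checks that the three arrays have matching lengths, which a list-based caller can easily get wrong.

```python
import jax.numpy as jnp


def to_coo_arrays(rows, cols, vals):
    """Hypothetical helper: convert Python lists of COO triplets to JAX arrays."""
    row = jnp.asarray(rows, dtype=jnp.int32)  # integer row indices
    col = jnp.asarray(cols, dtype=jnp.int32)  # integer column indices
    data = jnp.asarray(vals)  # default float dtype (float32 unless 64-bit mode is on)
    if not (row.shape == col.shape == data.shape):
        raise ValueError("row, col and data must have the same length")
    return row, col, data


row, col, data = to_coo_arrays([0, 1, 0, 1], [0, 0, 1, 1], [2.0, 1.0, 1.0, 3.0])
print(row.dtype, col.dtype, data.dtype)
```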
error AssertionError: nonsquare matrix
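cause The solver requires a square matrix (M = N).
fix Make sure the matrix is square. For rectangular (e.g. overdetermined) systems, use a least-squares approach instead, for example by solving the square normal equations A^T A x = A^T b.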
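The sketch below checks squareness from the COO indices and, for an overdetermined system, builds the square normal equations densely with plain JAX. The helpers `coo_shape` and `normal_equations` are hypothetical illustrations, not klujax functions; `coo_shape` infers the shape from the largest index, so it assumes the last row and column each contain at least one stored entry.

```python
import jax.numpy as jnp


def coo_shape(row, col):
    """Hypothetical helper: infer (n_rows, n_cols) from the largest COO indices."""
    return int(row.max()) + 1, int(col.max()) + 1


def normal_equations(row, col, data, b, shape):
    """Hypothetical helper: return (A^T A, A^T b), a square system, built densely."""
    A = jnp.zeros(shape, dtype=data.dtype).at[row, col].add(data)
    return A.T @ A, A.T @ b


# 3x2 overdetermined system: A = [[1, 0], [0, 1], [1, 1]]
row = jnp.array([0, 1, 2, 2])
col = jnp.array([0, 1, 0, 1])
data = jnp.array([1.0, 1.0, 1.0, 1.0])
b = jnp.array([1.0, 2.0, 3.0])

n_rows, n_cols = coo_shape(row, col)
print(n_rows == n_cols)  # False: not square, so a square-only solver cannot take it

AtA, Atb = normal_equations(row, col, data, b, (n_rows, n_cols))
x = jnp.linalg.solve(AtA, Atb)  # least-squares solution of A x = b
print(x)
```

The square matrix `AtA` could then be turned back into COO triplets and passed to the sparse solver. Forming A^T A squares the condition number, so for small or ill-conditioned problems a dense `jnp.linalg.lstsq` on A may be the safer choice.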
breaking The API is pre-1.0 and may change without notice; function signatures and return types can change between releases.
fix Pin the version (`klujax==0.5.0`) and review the release notes before upgrading.
gotcha The solver expects row and col as JAX integer arrays (int32/int64) and data as a JAX floating-point array (float32/float64). Python lists or incorrect dtypes may cause silent errors or unintended type promotion.
fix Pass JAX arrays with explicit dtypes, e.g. `row = jnp.array(..., dtype=jnp.int32)` and `data = jnp.array(..., dtype=jnp.float64)`.
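One common source of the silent dtype issues described above: JAX uses 32-bit precision by default, so requesting `dtype=jnp.float64` without enabling 64-bit mode yields a float32 array, with only a warning rather than an error. A minimal sketch of enabling 64-bit mode through JAX's standard `jax_enable_x64` flag at startup and asserting the dtypes before any solve call:

```python
import jax
import jax.numpy as jnp

# Set at startup, before creating arrays; enables float64/int64 support in JAX
jax.config.update("jax_enable_x64", True)

row = jnp.array([0, 1, 0, 1], dtype=jnp.int32)
data = jnp.array([2.0, 1.0, 1.0, 3.0], dtype=jnp.float64)

# Fail fast if the arrays do not have the precision the solve call expects
assert row.dtype == jnp.int32
assert data.dtype == jnp.float64
print(data.dtype)  # float64
```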
gotcha The solver is not JIT-compatible by default; the current implementation uses Python loops, so JIT compilation may fail or produce incorrect results.
fix Do not wrap the solve call in `jax.jit`; call it outside any JIT context.

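Example: solve a sparse linear system with KLU from JAX. The matrix is given in COO (coordinate) format as three arrays (row, col, data), where entry k places data[k] at position (row[k], col[k]); b is the right-hand side.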

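import jax
import jax.numpy as jnp
import klujax

# Enable 64-bit mode; otherwise JAX truncates float64 arrays to float32
jax.config.update("jax_enable_x64", True)

# 2x2 matrix A = [[2, 1], [1, 3]] in COO format: A[row[k], col[k]] = data[k]
row = jnp.array([0, 1, 0, 1], dtype=jnp.int32)
col = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
data = jnp.array([2.0, 1.0, 1.0, 3.0], dtype=jnp.float64)
b = jnp.array([1.0, 2.0], dtype=jnp.float64)

# Solve A x = b (exact solution: x = [0.2, 0.6])
x = klujax.solve(row, col, data, b)
print(x)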
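To sanity-check a sparse solve on a small test case, one can rebuild the dense matrix from the same COO triplets and compare against `jnp.linalg.solve`. This sketch uses only JAX; `coo_to_dense` is a hypothetical helper, and because it scatters with `.at[...].add`, duplicate (row, col) entries are summed. Dense reconstruction is only practical for small matrices.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def coo_to_dense(row, col, data, n):
    """Hypothetical helper: scatter COO triplets into a dense n x n matrix."""
    # .at[...].add accumulates, so duplicate (row, col) pairs are summed
    return jnp.zeros((n, n), dtype=data.dtype).at[row, col].add(data)


row = jnp.array([0, 1, 0, 1], dtype=jnp.int32)
col = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
data = jnp.array([2.0, 1.0, 1.0, 3.0], dtype=jnp.float64)
b = jnp.array([1.0, 2.0], dtype=jnp.float64)

A = coo_to_dense(row, col, data, n=2)  # [[2, 1], [1, 3]]
x_ref = jnp.linalg.solve(A, b)  # dense reference; exact solution is [0.2, 0.6]
print(x_ref)
```

A check such as `jnp.allclose(x, x_ref)` against the sparse solver's result then confirms that the COO input was interpreted as intended.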