
Linear solvers in JAX and Equinox.

Lineax is a JAX library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$, even when $A$ is rectangular or the problem is ill-posed (for example, when $A$ is singular).
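
To make the rectangular and ill-posed cases concrete, the generalised notion of "solution" can be written out. This is a sketch of the standard least-squares formulation, given here as an assumption about what solving means in those cases rather than a quotation from the library:

$$
x^\star = \operatorname*{arg\,min}_{x} \lVert Ax - b \rVert_2^2,
\qquad
A^\top A \, x^\star = A^\top b .
$$

When $A$ is square and invertible this reduces to $x^\star = A^{-1} b$. When several minimisers exist, the usual convention is to take the one of smallest norm, $x^\star = A^\dagger b$, where $A^\dagger$ is the Moore-Penrose pseudoinverse.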

Features include:

  • PyTree-valued matrices and vectors;
  • General linear operators for Jacobians, transposes, etc.;
  • Efficient linear least squares (e.g. QR solvers);
  • Numerically stable gradients through linear least squares;
  • Support for structured (e.g. symmetric) matrices;
  • Improved compilation times;
  • Improved runtime of some algorithms;
  • Support for both real-valued and complex-valued inputs;
  • All the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support, etc.


pip install lineax

Requires Python 3.9+, JAX 0.4.13+, and Equinox 0.11.0+.


Available at

Quick examples

Lineax can solve a least squares problem with an explicit matrix operator:

import jax.random as jr
import lineax as lx

matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 8))
vector = jr.normal(vector_key, (10,))
operator = lx.MatrixLinearOperator(matrix)
solution = lx.linear_solve(operator, vector, solver=lx.QR())

or Lineax can solve a problem without ever materializing a matrix, as done in this quadratic solve:

import jax
import lineax as lx

key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (10,))

def quadratic_fn(y, args):
  return jax.numpy.sum((y - 1)**2)

gradient_fn = jax.grad(quadratic_fn)
hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-6, atol=1e-6)
out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)
minimum = y - out.value


If you found this library to be useful in academic work, then please cite: (arXiv link)

@article{lineax2023,
    title={Lineax: unified linear solves and linear least-squares in JAX and Equinox},
    author={Jason Rader and Terry Lyons and Patrick Kidger},
    journal={AI for science workshop at Neural Information Processing Systems 2023},
    year={2023},
}

(Also consider starring the project on GitHub.)

See also: other libraries in the JAX ecosystem

Always useful

  • Equinox: neural networks and everything not already in core JAX!
  • jaxtyping: type annotations for shape/dtype of arrays.

Deep learning

  • Optax: first-order gradient (SGD, Adam, ...) optimisers.
  • Orbax: checkpointing (async/multi-host/multi-device).
  • Levanter: scalable+reliable training of foundation models (e.g. LLMs).

Scientific computing

  • Diffrax: numerical differential equation solvers.
  • Optimistix: root finding, minimisation, fixed points, and least squares.
  • BlackJAX: probabilistic+Bayesian sampling.
  • sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
  • PySR: symbolic regression. (Non-JAX honourable mention!)

Awesome JAX

  • Awesome JAX: a longer list of other JAX projects.