Gram-Newton-Schulz

0.1.4 · verified Sat May 09 · auth: no · python

A fast implementation of the Newton-Schulz algorithm for computing matrix square roots and inverses of Gram matrices, with optional JIT-compiled kernels via Quack. Current version: 0.1.4; requires Python >= 3.10. Published by the Dao-AILab organization and updated occasionally.
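For intuition about what the package accelerates, the classic coupled Newton-Schulz iteration for a symmetric positive-definite (SPD) matrix can be sketched in a few lines of NumPy. This is a minimal textbook sketch, not the package's Gram-based kernel; `newton_schulz_sqrt` is a hypothetical name. The loop uses only matrix multiplications, which is why Newton-Schulz methods map well onto GPUs.

```python
import numpy as np


def newton_schulz_sqrt(A: np.ndarray, num_iters: int = 30) -> tuple[np.ndarray, np.ndarray]:
    """Coupled Newton-Schulz iteration for an SPD matrix A (illustrative sketch).

    Returns (S, S_inv) with S @ S ~= A and S_inv ~= A^(-1/2).
    Hypothetical helper; not the gram_newton_schulz implementation.
    """
    n = A.shape[0]
    I = np.eye(n, dtype=A.dtype)
    # Scale so every eigenvalue lies in (0, 1]; this keeps the iteration convergent.
    c = np.linalg.norm(A, "fro")
    Y = A / c      # converges to (A / c)^(1/2)
    Z = I.copy()   # converges to (A / c)^(-1/2)
    for _ in range(num_iters):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y = Y @ T
        Z = T @ Z
    # Undo the scaling: (A / c)^(1/2) * sqrt(c) = A^(1/2).
    return Y * np.sqrt(c), Z / np.sqrt(c)


rng = np.random.default_rng(0)
B = rng.standard_normal((4, 4))
A = B @ B.T + 4.0 * np.eye(4)  # well-conditioned SPD test matrix
S, S_inv = newton_schulz_sqrt(A)
print(np.allclose(S @ S, A))                       # checks that S is a square root of A
print(np.allclose(S_inv @ S_inv @ A, np.eye(4)))  # checks that S_inv @ S_inv inverts A
```

Because `S_inv @ S_inv` approximates `A^-1`, the same loop also yields the inverse. The Frobenius-norm scaling is enough for convergence on SPD input in exact arithmetic, but ill-conditioned matrices need more iterations before the quadratic phase kicks in.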

pip install gram-newton-schulz
error ModuleNotFoundError: No module named 'gram_newton_schulz'
cause The package is not installed in the active environment, or the import uses the wrong module name.
fix
Run `pip install gram-newton-schulz`. The distribution name uses hyphens, but the import name uses underscores: `from gram_newton_schulz import ...`.
error RuntimeError: Expected all tensors to be on the same device, but found at least two devices
cause The solver was initialized on the CPU while the input tensor is on the GPU, or vice versa.
fix
Move the solver (if it holds state) and the input tensors to the same device, e.g. `solver = GramNewtonSchulz().to(device)` and `A = A.to(device)`.
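gotcha The input matrix must be symmetric positive-definite (SPD); non-SPD matrices may cause convergence failure or incorrect results.
fix Ensure the matrix is symmetric (e.g. `A = 0.5 * (A + A.T)`) and has strictly positive eigenvalues. Consider adding a small regularization term such as `A + 1e-6 * torch.eye(A.shape[0], device=A.device, dtype=A.dtype)`, so that the identity matrix is created on the same device and with the same dtype as `A`.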
gotcha torch.compile is required for good performance; without it, the fallback implementation may be very slow or unsupported.
fix Wrap the method with torch.compile, e.g. `solver.sqrt = torch.compile(solver.sqrt)`, or use the `compile_kwargs` argument.
deprecated The `StandardNewtonSchulz` class was merged into `GramNewtonSchulz` in v0.1.0. Old imports will break.
fix Replace `from gram_newton_schulz import StandardNewtonSchulz` with `from gram_newton_schulz import GramNewtonSchulz`.
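The SPD requirement from the gotcha above can be checked before calling the solver. The NumPy sketch below is a variant of the suggested regularization: it symmetrizes the input and adds only as much diagonal shift as needed to bring the smallest eigenvalue up to `eps`. `make_spd` is a hypothetical helper, not part of the package.

```python
import numpy as np


def make_spd(A: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Symmetrize A and shift its spectrum so the smallest eigenvalue is at least eps.

    Hypothetical helper for debugging; not part of gram_newton_schulz.
    """
    A_sym = 0.5 * (A + A.T)
    # eigvalsh assumes a symmetric input and returns eigenvalues in ascending order.
    lam_min = np.linalg.eigvalsh(A_sym)[0]
    if lam_min < eps:
        # Adding c * I shifts every eigenvalue by exactly c.
        A_sym = A_sym + (eps - lam_min) * np.eye(A.shape[0], dtype=A_sym.dtype)
    return A_sym


M = np.array([[1.0, 2.0], [0.0, -3.0]])  # neither symmetric nor positive-definite
P = make_spd(M)
print(np.allclose(P, P.T), np.linalg.eigvalsh(P)[0] > 0)
```

The eigenvalue computation costs O(n³) time, the same order as the iteration itself, so it is best treated as a debugging check rather than something to run on every call.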

Compute a matrix square root with the Gram-Newton-Schulz algorithm:

import torch
from gram_newton_schulz import GramNewtonSchulz

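# Create a symmetric positive-definite matrix
device = 'cuda' if torch.cuda.is_available() else 'cpu'
A = torch.randn(4, 4, device=device)
# A @ A.T is only positive semi-definite; the small diagonal shift makes it strictly positive-definite
A = A @ A.T + 1e-3 * torch.eye(4, device=device)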

# Initialize the solver with default settings
solver = GramNewtonSchulz()

# Compute the matrix square root (X such that X @ X ≈ A)
X = solver.sqrt(A)
print(X)
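To sanity-check a result on small matrices, one option is to compare it against a reference square root computed from an eigendecomposition. This sketch uses NumPy rather than the library; `eigh_sqrt` and `relative_residual` are hypothetical helper names.

```python
import numpy as np


def eigh_sqrt(A: np.ndarray) -> np.ndarray:
    """Reference square root of a symmetric PSD matrix via A = V diag(w) V^T."""
    w, V = np.linalg.eigh(A)
    # Clip tiny negative eigenvalues that rounding can produce before taking the sqrt.
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


def relative_residual(X: np.ndarray, A: np.ndarray) -> float:
    """Return ||X @ X - A||_F / ||A||_F; small values mean X @ X is close to A."""
    return float(np.linalg.norm(X @ X - A) / np.linalg.norm(A))


rng = np.random.default_rng(0)
B = rng.standard_normal((4, 4))
A = B @ B.T + np.eye(4)
print(relative_residual(eigh_sqrt(A), A))  # tiny for the eigendecomposition reference
```

For the torch example, convert the tensors first, e.g. `relative_residual(X.cpu().numpy(), A.cpu().numpy())`. The eigendecomposition is exact up to rounding but slower and less GPU-friendly, which is why it suits testing rather than production use.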