entmax

A family of sparse alternatives to softmax for PyTorch, including entmax, sparsemax, normmax, and budget, with corresponding loss functions. Current version: 1.3. Releases are irregular, with updates shipped as needed for bug fixes and new activations.

pip install entmax
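
To make the sparsity concrete, here is a minimal sketch contrasting softmax with sparsemax on the same input (the logit values are arbitrary):

import torch
from entmax import sparsemax

logits = torch.tensor([[1.0, 0.8, 0.1, -1.0]])
print(torch.softmax(logits, dim=-1))  # dense: every entry strictly positive
print(sparsemax(logits, dim=-1))      # sparse: low-scoring entries are exactly zero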
error ModuleNotFoundError: No module named 'torch'
cause entmax depends on PyTorch but does not declare it as a dependency; install fails if torch is missing.
fix
Install PyTorch first: pip install torch. Then install entmax.
error RuntimeError: [...] torch.cuda.is_available() false
cause The library works on CPU as well as GPU; CPU is merely slower. This error appears when code explicitly requests CUDA, for example by moving tensors to 'cuda', on a machine where torch.cuda.is_available() is False.
fix
Select the device at runtime rather than hard-coding 'cuda'. The library supports both CPU and GPU; if you don't need a GPU, run on CPU and ignore CUDA-related warnings.
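
A minimal device-agnostic sketch (entmax15 stands in for any of the activations):

import torch
from entmax import entmax15

# Pick the device at runtime instead of assuming CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 5, device=device)
p = entmax15(x, dim=-1)  # runs on whichever device x lives on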
error ImportError: cannot import name 'entmax_loss' from 'entmax'
cause Older releases (pre-1.3) may not export entmax_loss, or the import path changed between versions.
fix
Upgrade: pip install --upgrade entmax. Then use from entmax import entmax_loss.
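
A rough sketch of using a loss after upgrading. Note the hedges: some releases export the 1.5-entmax loss as entmax15_loss rather than entmax_loss, and the signature assumed here (logits plus gold-label indices) should be checked against the installed version.

import torch
from entmax import entmax15_loss  # export name assumed; adjust to your version

logits = torch.randn(4, 10, requires_grad=True)  # batch of 4 examples, 10 classes
target = torch.tensor([1, 0, 3, 9])              # gold class indices
loss = entmax15_loss(logits, target).mean()      # assumed signature: (logits, target)
loss.backward()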
gotcha entmax 1.x requires PyTorch to be present before it is installed; attempting to install entmax without torch fails. Install PyTorch with conda or pip first.
fix Install PyTorch first: pip install torch. Then proceed with pip install entmax.
gotcha The bisection-based functions (entmax_bisect, etc.) are slower than the closed-form variants. They are intended for non-standard alpha values (not 1.5 or 2).
fix Use entmax15 or sparsemax for the standard alphas; use bisection only for custom alphas, as in the sketch below.
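
A minimal sketch for a non-standard alpha (1.7 is an arbitrary illustrative value):

import torch
from entmax import entmax_bisect

x = torch.randn(2, 5)
p = entmax_bisect(x, alpha=1.7, dim=-1)  # no closed form for alpha = 1.7, so bisect
print(p)

Recent releases also accept alpha as a tensor, which makes it learnable; verify this against the installed version before relying on it.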
deprecated The 'alpha' parameter in some older functions is deprecated; use the specific entmax15 or Sparsemax functions instead.
fix Migrate to entmax15 (alpha=1.5) or Sparsemax (alpha=2); see the sketch below.
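
A minimal migration sketch, assuming the module classes Entmax15 and Sparsemax are exported and take a dim argument (check your installed version):

import torch
from entmax import Entmax15, Sparsemax  # module classes; availability assumed

x = torch.randn(2, 5)
p15 = Entmax15(dim=-1)(x)   # fixed alpha = 1.5
p20 = Sparsemax(dim=-1)(x)  # fixed alpha = 2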

Apply the entmax-1.5 activation and the generic bisection-based entmax function; with alpha=1.5 the two outputs should closely match.

import torch
from entmax import entmax15, entmax_bisect

x = torch.randn(2, 5)
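
# Closed-form entmax with alpha = 1.5: each row sums to 1 and may contain exact zeros.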
y = entmax15(x, dim=-1)
print(y)
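
# Generic bisection-based solver; with alpha = 1.5 the result closely matches entmax15.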
z = entmax_bisect(x, alpha=1.5, dim=-1)
print(z)