Cut Cross Entropy
Cut Cross Entropy (CCE) is a highly memory-efficient implementation of the linear-cross-entropy loss for training large-vocabulary language models. Rather than materializing the full (num_tokens, vocab_size) logit matrix, its CUDA/Triton kernels compute the loss directly from the hidden states and the classifier weights, which is where the memory savings come from. It originates from Apple's ml-cross-entropy project; the current version is 25.1.1, reflecting a date-based (year.month.patch) release scheme, and the fast kernels target NVIDIA GPUs.
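To make "memory-efficient" concrete: the object CCE avoids materializing is the logit matrix, whose size grows with batch size, sequence length, and vocabulary. A back-of-envelope calculation (the model shape below is illustrative, not tied to any particular checkpoint):

```python
# Back-of-envelope: memory for a fully materialized fp16 logit matrix
# with a Llama-3-scale vocabulary. Sizes here are illustrative.
batch, seq_len, vocab = 8, 4096, 128_256
bytes_per_el = 2  # float16

logit_bytes = batch * seq_len * vocab * bytes_per_el
print(f"{logit_bytes / 2**30:.1f} GiB just for the logits")  # → 7.8 GiB
```

That buffer exists only to compute a scalar loss; CCE's kernels never allocate it.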
Warnings
- gotcha The fast kernels require a CUDA-capable NVIDIA GPU. The library will not run on CPU-only systems, even if PyTorch is installed.
- gotcha The PyPI package name is `cut-cross-entropy` (using hyphens), but the Python module you import is `cut_cross_entropy` (using underscores). Incorrect module import paths are a common mistake.
- gotcha The CCE kernels are designed around 16-bit inputs (`torch.float16` or `torch.bfloat16`), which is where the memory and throughput benefits come from. Passing `float32` inputs may not be supported by the fast path and negates much of the library's advantage.
- gotcha The version number (e.g., 25.1.1) follows a date-based scheme (year.month.patch) rather than strict semantic versioning, so the API can change more frequently between releases than semver would imply.
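The quantity CCE computes is the standard softmax cross-entropy over the classifier outputs: per token, the logsumexp of the logits minus the logit of the true class. The sketch below spells out that math in plain Python on toy data; the function name and shapes are ours for illustration and do not mirror the library's API, which fuses this computation inside a GPU kernel:

```python
import math

def linear_cross_entropy_ref(hidden, classifier, labels):
    """Per-token loss: logsumexp of the logits minus the true-class logit.

    hidden:     list of token embeddings, each of length D
    classifier: list of V class-weight vectors, each of length D
    labels:     list of true class indices
    """
    losses = []
    for e, y in zip(hidden, labels):
        # One row of logits at a time -- the full matrix never exists.
        logits = [sum(cj * ej for cj, ej in zip(c, e)) for c in classifier]
        m = max(logits)  # stabilize logsumexp
        lse = m + math.log(sum(math.exp(z - m) for z in logits))
        losses.append(lse - logits[y])
    return sum(losses) / len(losses)

# Toy example: 2 tokens, 3 classes, 2-dim embeddings
hidden = [[1.0, 0.0], [0.0, 1.0]]
classifier = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
labels = [0, 1]
print(linear_cross_entropy_ref(hidden, classifier, labels))  # ≈ 0.4076
```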
Install
-
pip install cut-cross-entropy "torch>=2.0.0"
Imports
- linear_cross_entropy
from cut_cross_entropy import linear_cross_entropy
- LinearCrossEntropy
from cut_cross_entropy import LinearCrossEntropy
Quickstart
import torch
from cut_cross_entropy import linear_cross_entropy

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using CUDA device: {device}")

    # CCE takes the pre-classifier hidden states and the classifier weight
    # matrix; the (num_tokens, vocab_size) logit matrix is never materialized.
    num_tokens, hidden_dim, vocab_size = 8, 64, 128

    # 16-bit inputs (float16/bfloat16) are the intended use case
    embeddings = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
    classifier = torch.randn(vocab_size, hidden_dim, device=device, dtype=torch.bfloat16)
    labels = torch.randint(0, vocab_size, (num_tokens,), device=device)

    # Mean cross-entropy over the batch of tokens
    loss = linear_cross_entropy(embeddings, classifier, labels)
    print(f"Calculated loss: {loss.item():.4f}")
else:
    print("CUDA is not available. This library is designed for NVIDIA GPUs.")
    print("Please ensure you have a CUDA-enabled GPU and the correct PyTorch installation.")
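For causal language modeling, the label at position i is the token at position i + 1; upstream CCE describes a shift option that performs this alignment internally (check the signature of your installed version). The alignment itself is simple and is sketched below in plain Python, with a hypothetical helper name of our choosing:

```python
def shift_for_causal_lm(tokens):
    """Align positions with next-token labels for causal LM training.

    Position i predicts tokens[i + 1], so the final position has no
    label and is dropped. Returns (input_positions, labels).
    """
    inputs = tokens[:-1]
    labels = tokens[1:]
    return inputs, labels

tokens = [101, 7, 42, 9, 102]
inputs, labels = shift_for_causal_lm(tokens)
print(inputs)  # [101, 7, 42, 9]
print(labels)  # [7, 42, 9, 102]
```

With the library's shift option enabled you pass the unshifted sequences and let the kernel handle this internally, avoiding the extra sliced copies.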