Cut Cross Entropy

25.1.1 · active · verified Tue Apr 14

Cut Cross Entropy (CCE) provides a highly memory-efficient implementation of the linear-cross-entropy loss for large-vocabulary language models. Rather than materializing the full `(num_tokens, vocab_size)` logit matrix, it computes the loss directly from the token embeddings and the classifier weight matrix, which cuts the memory footprint of loss computation during training. The project is developed by Apple (the `ml-cross-entropy` repository) and accompanies the paper "Cut Your Losses in Large-Vocabulary Language Models". The version number (25.1.1) follows a date-based release scheme; the custom GPU kernels require a CUDA-enabled NVIDIA GPU.
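The memory saving rests on a simple observation: the cross-entropy loss only needs each token's correct-class logit and the log-sum-exp over all logits, and the latter can be accumulated over vocabulary chunks so the full logit matrix is never stored. The following is a plain NumPy sketch of that idea (an illustration of the technique, not the library's actual kernel; the function name and chunk size are made up for this example):

```python
import numpy as np

def chunked_linear_cross_entropy(emb, clf, labels, chunk=32):
    """Mean cross-entropy computed from embeddings `emb` (T, D) and
    classifier weights `clf` (V, D), streaming log-sum-exp over
    vocabulary chunks so no (T, V) logit matrix is ever stored."""
    num_tokens = emb.shape[0]
    vocab_size = clf.shape[0]
    # Correct-class logits: one dot product per token
    correct_logit = np.einsum("td,td->t", emb, clf[labels])
    # Running max/sum for a numerically stable streaming log-sum-exp
    running_max = np.full(num_tokens, -np.inf)
    running_sum = np.zeros(num_tokens)
    for start in range(0, vocab_size, chunk):
        logits = emb @ clf[start:start + chunk].T  # (T, chunk) only
        chunk_max = logits.max(axis=1)
        new_max = np.maximum(running_max, chunk_max)
        # Rescale the old sum to the new max, then add this chunk
        running_sum = running_sum * np.exp(running_max - new_max) \
            + np.exp(logits - new_max[:, None]).sum(axis=1)
        running_max = new_max
    logsumexp = running_max + np.log(running_sum)
    return np.mean(logsumexp - correct_logit)
```

Peak extra memory here is `(num_tokens, chunk)` rather than `(num_tokens, vocab_size)`; the library applies the same idea in fused GPU kernels, including the backward pass.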

Warnings

Install
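This section is empty in the card; assuming the standard PyPI package name used by the project README, a typical install is:

```shell
pip install cut-cross-entropy
```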

Imports
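The library's main functional entry point is `linear_cross_entropy`. A minimal import sketch, guarded so it degrades gracefully when the package is not installed:

```python
try:
    # Primary entry point, per the project README
    from cut_cross_entropy import linear_cross_entropy
except ImportError:
    # Package not installed; see: pip install cut-cross-entropy
    linear_cross_entropy = None
```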

Quickstart

This quickstart computes the loss with `linear_cross_entropy`, the library's main entry point. Unlike `torch.nn.functional.cross_entropy`, it takes the token embeddings and the classifier weight matrix rather than precomputed logits, which is what allows the full `(num_tokens, vocab_size)` logit matrix to never be materialized. The example first checks for CUDA availability, since the fused kernels require a CUDA-enabled NVIDIA GPU, and uses `float16` inputs, as the kernels expect 16-bit (`float16`/`bfloat16`) tensors. It also shows the `shift` parameter, which handles the label shifting used in causal language modeling.

import torch
from cut_cross_entropy import linear_cross_entropy

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using CUDA device: {device}")

    # Example shapes: embeddings (num_tokens, hidden_dim),
    # classifier (vocab_size, hidden_dim), labels (num_tokens,)
    num_tokens, hidden_dim, vocab_size = 8, 64, 128

    # The kernels expect 16-bit inputs (float16 or bfloat16)
    embeddings = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.float16)
    classifier = torch.randn(vocab_size, hidden_dim, device=device, dtype=torch.float16)
    labels = torch.randint(0, vocab_size, (num_tokens,), device=device)

    # Equivalent to cross_entropy(embeddings @ classifier.T, labels),
    # but the logit matrix is never materialized
    loss = linear_cross_entropy(embeddings, classifier, labels)
    print(f"Calculated loss: {loss.item():.4f}")

    # For causal language modeling, shift=True aligns each position's
    # embedding with the next position's label internally
    loss_shifted = linear_cross_entropy(embeddings, classifier, labels, shift=True)
    print(f"Shifted loss: {loss_shifted.item():.4f}")
else:
    print("CUDA is not available. This library is designed for NVIDIA GPUs.")
    print("Please ensure you have a CUDA-enabled GPU and the correct PyTorch installation.")
