CLU (Common Loop Utils)

0.0.12 · active · verified Mon Apr 13

CLU (Common Loop Utils) is a set of opinionated utility libraries for building machine learning training loops in JAX. It provides modules for metrics, parameter overviews, checkpointing, profiling, and data loading. The library is actively maintained by Google, with releases occurring a few times a year, focusing on JAX compatibility and usability improvements.

Warnings

Install

Imports

Quickstart

This quickstart demonstrates how to define a custom metric with `clu.metrics` and track accuracy and average loss over multiple training steps in JAX. It shows how to create a `metrics.Collection`, build per-step metrics from the outputs of a simulated `train_step`, merge them across steps, and compute the final aggregate results.

import flax
import jax.numpy as jnp
from clu import metrics

# Define a custom metric: accuracy as the average of per-example correctness.
@flax.struct.dataclass
class Accuracy(metrics.Average):
  """Computes accuracy from 'predictions' and 'labels' model outputs."""

  @classmethod
  def from_model_output(cls, *, predictions, labels, **kwargs):
    correct = ((predictions > 0.5) == labels).astype(jnp.float32)
    return super().from_model_output(values=correct, **kwargs)

# Collection.create returns a new Collection *class* holding all metrics.
TrainMetrics = metrics.Collection.create(
    accuracy=Accuracy,
    loss=metrics.Average.from_output('batch_loss'),
)

# Simulate a training step that returns model outputs for the metrics.
def train_step(params, batch, loss):
  # In a real scenario, this would involve model inference and loss calculation.
  predictions = jnp.array([0.8, 0.1, 0.9])  # Example predictions
  labels = jnp.array([1, 0, 1])             # Example true labels
  return {'predictions': predictions, 'labels': labels, 'batch_loss': loss}

# Accumulate metrics over two simulated batches.
params_dummy = {'w': jnp.array([1.0])}
accumulated = TrainMetrics.empty()
for loss in (jnp.array(0.1), jnp.array(0.05)):  # Example per-batch losses
  step_outputs = train_step(params_dummy, batch={}, loss=loss)
  accumulated = accumulated.merge(
      TrainMetrics.single_from_model_output(**step_outputs))

# Compute and print the final aggregate results.
results = accumulated.compute()
print(f"Accuracy: {results['accuracy']:.2f}")
print(f"Average loss: {results['loss']:.3f}")
