CLU (Common Loop Utils)
CLU (Common Loop Utils) is a set of opinionated utility libraries for building machine learning training loops in JAX. It provides modules for metrics, parameter overviews, checkpointing, profiling, and data loading. The library is actively maintained by Google, with releases occurring a few times a year, focusing on JAX compatibility and usability improvements.
Warnings
- breaking CLU v0.0.10 and newer require Python 3.10 or higher. Earlier releases supported Python 3.7/3.8, but support for those versions was dropped incrementally.
- breaking The `asynclib` module was moved from `clu.internal.asynclib` to `clu.asynclib`. Direct imports from `clu.internal` will fail.
- gotcha The `clu.profile` module switched its backend from TensorFlow Profiler to JAX Profiler. If you were relying on `clu.profile` to integrate with TensorFlow profiling tools, this will no longer work as expected.
- gotcha CLU v0.0.12 internally updated its usage of `jax.tree_map` to `jax.tree_util.tree_map` due to JAX deprecation. While this is an internal change, users interacting with CLU's tree structures or custom JAX types might need to be aware of JAX's evolving tree utilities.
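The `jax.tree_map` deprecation mentioned above can be illustrated with a minimal sketch (plain JAX, no CLU required):

```python
import jax

# jax.tree_map is deprecated; jax.tree_util.tree_map is the replacement.
tree = {'a': [1, 2], 'b': 3}
doubled = jax.tree_util.tree_map(lambda x: x * 2, tree)
# The pytree structure is preserved; only the leaves are transformed:
# {'a': [2, 4], 'b': 6}
```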
Install
- pip install clu
Imports
- MetricWriter
from clu import metric_writers
- Average
from clu import metrics
- ParameterOverview
from clu import parameter_overview
- Profile
from clu import profile
- Pool
from clu import asynclib
Quickstart
import jax.numpy as jnp
import flax
from clu import metrics

# Define a custom metric: accuracy is the average of per-example correctness.
@flax.struct.dataclass
class Accuracy(metrics.Average):
    @classmethod
    def from_model_output(cls, *, predictions, labels, **kwargs):
        correct = (predictions > 0.5) == labels
        return super().from_model_output(values=correct.astype(jnp.float32))

# Group metrics into a Collection.
TrainMetrics = metrics.Collection.create(
    accuracy=Accuracy,
    loss=metrics.Average.from_output('batch_loss'))

# Simulate a training step that returns model outputs.
def train_step(params, batch):
    # In a real scenario, this would involve model inference and loss calculation.
    predictions = jnp.array([0.8, 0.1, 0.9])  # Example predictions
    labels = jnp.array([1, 0, 1])  # Example true labels
    return {'predictions': predictions, 'labels': labels,
            'batch_loss': jnp.array(0.1)}  # Example loss

# Initialize and accumulate metrics over two simulated batches.
# Use single_from_model_output() outside pmap; gather_from_model_output()
# only works inside a pmapped function.
all_metrics = TrainMetrics.empty()
for _ in range(2):
    step_outputs = train_step({'w': jnp.array([1.0])}, {})
    all_metrics = all_metrics.merge(
        TrainMetrics.single_from_model_output(**step_outputs))

# Compute and print results
results = all_metrics.compute()
print(f"Average Accuracy: {results['accuracy']:.2f}")
print(f"Average Loss: {results['loss']:.2f}")