{"id":5155,"library":"clu","title":"CLU (Common Loop Utils)","description":"CLU (Common Loop Utils) is a set of opinionated utility libraries for building machine learning training loops in JAX. It provides modules for metrics, parameter overviews, checkpointing, profiling, and data loading. The library is actively maintained by Google, with releases occurring a few times a year, focusing on JAX compatibility and usability improvements.","status":"active","version":"0.0.12","language":"en","source_language":"en","source_url":"https://github.com/google/CommonLoopUtils","tags":["jax","machine-learning","training-loops","metrics","profiling","checkpointing"],"install":[{"cmd":"pip install clu","lang":"bash","label":"Latest stable version"}],"dependencies":[{"reason":"Core dependency for JAX-based ML workflows.","package":"jax","optional":false},{"reason":"Required by `clu.metrics`, which defines metrics as `flax.struct` dataclasses.","package":"flax","optional":false},{"reason":"Commonly used for data loading and input pipelines, especially with `clu.deterministic_data`.","package":"tensorflow-datasets","optional":true}],"imports":[{"symbol":"MetricWriter","correct":"from clu import metric_writers"},{"symbol":"Average","correct":"from clu import metrics"},{"symbol":"ParameterOverview","correct":"from clu import parameter_overview"},{"symbol":"Profile","correct":"from clu import profile"},{"note":"The `asynclib` module was moved out of `clu.internal` in v0.0.7.","wrong":"from clu.internal import asynclib","symbol":"AsyncPool","correct":"from clu import asynclib"}],"quickstart":{"code":"import flax\nimport jax.numpy as jnp\nfrom clu import metrics\n\n# Define a simple custom metric. CLU metrics are flax dataclasses that\n# implement `from_model_output`, `merge`, and `compute`.\n@flax.struct.dataclass\nclass Accuracy(metrics.Metric):\n  num_correct: jnp.ndarray\n  num_total: jnp.ndarray\n\n  @classmethod\n  def from_model_output(cls, *, correct, total, **_):\n    return cls(num_correct=correct, num_total=total)\n\n  def merge(self, other):\n    return type(self)(\n        num_correct=self.num_correct + other.num_correct,\n        num_total=self.num_total + other.num_total)\n\n  def compute(self):\n    return self.num_correct / self.num_total\n\n# Simulate a training step\ndef train_step(params, batch):\n  # In a real scenario, this would involve model inference and loss calculation\n  predictions = jnp.array([0.8, 0.1, 0.9])  # Example predictions\n  labels = jnp.array([1, 0, 1])  # Example true labels\n  correct = (predictions > 0.5) == labels\n  return {'correct': correct.sum(), 'total': correct.size}\n\n# Define a metrics collection\nMetrics = metrics.Collection.create(\n    accuracy=Accuracy,\n    loss=metrics.Average.from_output('batch_loss'))\n\nparams_dummy = {'w': jnp.array([1.0])}\nbatch_dummy = {}\n\n# First batch: build a per-step update. Outside of pmap, use\n# `single_from_model_output` (not `gather_from_model_output`).\nstep_outputs = train_step(params_dummy, batch_dummy)\nstep_outputs['batch_loss'] = jnp.array(0.1)  # Example loss\nall_metrics = Metrics.single_from_model_output(**step_outputs)\n\n# Second batch: accumulate with `merge`\nstep_outputs_2 = train_step(params_dummy, batch_dummy)\nstep_outputs_2['batch_loss'] = jnp.array(0.05)  # Example loss\nall_metrics = all_metrics.merge(\n    Metrics.single_from_model_output(**step_outputs_2))\n\n# Compute and print aggregate results\nresults = all_metrics.compute()\nprint(f\"Average Accuracy: {results['accuracy']:.2f}\")\nprint(f\"Average Loss: {results['loss']:.2f}\")","lang":"python","description":"This quickstart demonstrates how to use `clu.metrics` to track a custom accuracy metric and an average loss over multiple training steps in JAX. It defines a custom `metrics.Metric` as a `flax.struct` dataclass, creates a `metrics.Collection`, builds per-step updates from a simulated `train_step` with `single_from_model_output`, accumulates them with `merge`, and then computes the final aggregate results."},"warnings":[{"fix":"Ensure your Python environment is at least 3.10. If you need to use an older Python version, pin CLU to a version prior to 0.0.10 (e.g., `pip install 'clu<0.0.10'`).","message":"CLU versions v0.0.10 and newer require Python 3.10 or higher. Older versions might support Python 3.7 or 3.8, but support was progressively dropped.","severity":"breaking","affected_versions":">=0.0.10"},{"fix":"Update your import statements from `from clu.internal import asynclib` to `from clu import asynclib`.","message":"The `asynclib` module was moved from `clu.internal.asynclib` to `clu.asynclib`. Direct imports from `clu.internal` will fail.","severity":"breaking","affected_versions":">=0.0.7"},{"fix":"Adapt your profiling setup to use JAX profiling tools and visualizations. Ensure `jaxlib` is installed with profiling support if needed.","message":"The `clu.profile` module switched its backend from TensorFlow Profiler to JAX Profiler. If you were relying on `clu.profile` to integrate with TensorFlow profiling tools, this will no longer work as expected.","severity":"gotcha","affected_versions":">=0.0.3"},{"fix":"Ensure your JAX installation is up-to-date and be mindful of JAX's `tree_util` module for tree manipulation functions.","message":"CLU v0.0.12 internally updated its usage of `jax.tree_map` to `jax.tree_util.tree_map` due to JAX deprecation. While this is an internal change, users interacting with CLU's tree structures or custom JAX types might need to be aware of JAX's evolving tree utilities.","severity":"gotcha","affected_versions":"0.0.12"}],"env_vars":null,"last_verified":"2026-04-13T00:00:00.000Z","next_check":"2026-07-12T00:00:00.000Z"}