TransformerLens


A library for training and analysing transformer models, focused on mechanistic interpretability. Provides tools to probe, edit, and visualise model internals. Current version is 3.1.0, with support for Python >=3.10, <4.0. Under active development.

pip install transformer-lens
error ModuleNotFoundError: No module named 'transformer_lens'
cause The PyPI package name uses a hyphen (transformer-lens) while the Python module name uses an underscore (transformer_lens); mixing the two up, or installing into a different environment than the one running Python.
fix
Install with 'pip install transformer-lens' and import with 'import transformer_lens' (underscore).
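A minimal sanity check, assuming a standard pip environment; importlib.metadata is standard library, so nothing is assumed about the package's own attributes:

# Shell: pip install transformer-lens   (hyphen in the distribution name)
import transformer_lens  # underscore in the module name; fails if misspelled

from importlib.metadata import version
print(version("transformer-lens"))  # query the installed version via stdlib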
error AttributeError: module 'transformer_lens' has no attribute 'utils'
cause The 'utils' module was removed in version 3.0.
fix
Import from the specific submodule instead, e.g. 'transformer_lens.loading_from_pretrained' for model-loading utilities.
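A sketch of the submodule pattern; get_pretrained_model_config is assumed here to live in loading_from_pretrained, so verify the location against your installed version:

from transformer_lens.loading_from_pretrained import get_pretrained_model_config

cfg = get_pretrained_model_config("gpt2")  # model name is illustrative
print(cfg.n_layers)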
error KeyError: 'blocks.0.hook_mlp' (when accessing cache)
cause The cache key may have a different naming format depending on the model and version.
fix
Print the cache keys to see the exact names available for your model and version: print(cache.keys()). The MLP output hook is usually 'blocks.0.hook_mlp_out' rather than 'hook_mlp'.
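A minimal sketch of that diagnostic, using a small model so it loads quickly (the model name is illustrative):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1L-21M")
_, cache = model.run_with_cache("The capital of France is")

# List every hook name containing 'mlp' to find the exact key to index.
print([name for name in cache.keys() if "mlp" in name])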
breaking In version 3.0, the import name was standardised to 'transformer_lens' (underscore); the PyPI distribution name remains 'transformer-lens' (hyphen). Imports written with the hyphen will fail (hyphens are not valid in Python import statements).
fix Change all imports to use underscores: 'import transformer_lens'.
deprecated The old 'utils' module (transformer_lens.utils) has been deprecated in favour of individual submodules. Functions like 'to_numpy' have moved.
fix Check the changelog for the new location of utility functions.
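For code that must run on both sides of the change, a defensive import is one option. The fallback below is a plain-torch stand-in written for illustration, not the library's own implementation:

try:
    from transformer_lens.utils import to_numpy  # pre-3.0 location
except ImportError:
    def to_numpy(tensor):
        # Stand-in with the same intent: detach, move to CPU, convert.
        return tensor.detach().cpu().numpy()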
gotcha When using 'model.run_with_cache', the returned cache is a dict-like ActivationCache keyed by hook names, and the tensor dimensions can be counterintuitive (batch, pos, d_model). Ensure you permute correctly for visualisation.
fix
Cache tensors have shape (batch, pos, d_model). Use cache['blocks.0.hook_mlp_out'].shape to verify.
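A sketch of the shape check and a permute for plotting, assuming the (batch, pos, d_model) layout described above (the model name is illustrative):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1L-21M")
_, cache = model.run_with_cache("The capital of France is")

acts = cache["blocks.0.hook_mlp_out"]  # (batch, pos, d_model)
print(acts.shape)

# For a heatmap with sequence position on the x-axis, move pos last.
acts_for_plot = acts[0].permute(1, 0)  # (d_model, pos)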

Load a pretrained model and run a forward pass.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1L-21M")
prompt = "The capital of France is"
logits = model(prompt)  # shape: (batch, pos, d_vocab)
# Take the logits at the final position: that is the next-token prediction.
# An argmax over all positions would return one token per position instead.
next_token = logits[0, -1].argmax(dim=-1)
print(model.tokenizer.decode(next_token.item()))
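Continuing from the example above, HookedTransformer also exposes a generate method for multi-token continuations; the keyword arguments below are assumptions worth checking against your installed version:

# Greedy multi-token continuation; do_sample=False disables sampling.
continuation = model.generate(prompt, max_new_tokens=8, do_sample=False)
print(continuation)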