Product Key Memory
A PyTorch implementation of Product Key Memory (PKM), an external memory module for neural networks with fast nearest-neighbor lookup via product quantization. Current version 0.3.0, requires Python >=3.6. Active development by lucidrains.
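The product-key trick behind the fast lookup can be sketched in a few lines of plain PyTorch. This is an illustrative toy, not the library's internal code, and all variable names here are made up: the query is split in half, each half is scored against a small sub-key table, and only the per-half top-k combinations are considered.

```python
import torch

# Toy sketch of product-key lookup: by combining per-half top-k results,
# num_keys**2 virtual memory slots are covered while scoring only
# 2 * num_keys sub-key rows.
num_keys, half_dim, topk = 8, 4, 2
q = torch.randn(2 * half_dim)                 # query vector
sub_keys_a = torch.randn(num_keys, half_dim)  # first half-key table
sub_keys_b = torch.randn(num_keys, half_dim)  # second half-key table

scores_a, idx_a = (sub_keys_a @ q[:half_dim]).topk(topk)
scores_b, idx_b = (sub_keys_b @ q[half_dim:]).topk(topk)

# Cartesian product of the two top-k sets: topk**2 candidate slots
cand_scores = (scores_a[:, None] + scores_b[None, :]).flatten()
cand_idx = (idx_a[:, None] * num_keys + idx_b[None, :]).flatten()
best_slot = cand_idx[cand_scores.argmax()]    # index into num_keys**2 slots
```

The selected `best_slot` would then index into a table of `num_keys ** 2` value vectors.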
Installation
pip install product-key-memory

Common errors
error ImportError: cannot import name 'ProductKeyMemory' from 'product_key_memory'
cause An older version (pre-0.2.0) with a different module layout is installed.
fix Upgrade to the latest release: pip install --upgrade product-key-memory
error RuntimeError: Expected tensor to be on the same device, but found at least two devices
cause Model and input tensors are on different devices (CPU vs. GPU).
fix Move both to the same device: model = model.to('cuda'); x = x.to('cuda')
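A minimal device-placement sketch. A stand-in torch.nn.Linear is used so the snippet is self-contained; with the real library, the model would be a ProductKeyMemory instance:

```python
import torch

# Pick GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-in module; substitute your ProductKeyMemory instance here.
model = torch.nn.Linear(512, 512).to(device)

# Creating the input directly on the target device avoids the mismatch.
x = torch.randn(1, 128, 512, device=device)

out = model(x)  # model parameters and input now share `device`
```

Creating tensors with `device=device` up front is generally cheaper than allocating on CPU and calling `.to(device)` afterwards.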
error TypeError: forward() got an unexpected keyword argument 'return_loss'
cause Passing a keyword that was removed after v0.1.x.
fix Drop return_loss; the loss is now always computed and returned as the second element.
Warnings
breaking In v0.3.0 the forward method's return changed: it previously returned (output, loss, aux_loss) and now returns (output, loss); aux_loss is no longer returned separately.
fix Expect two return values. Code that used aux_loss can rely on it being folded into loss.
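If code has to run against both pre-0.3.0 and current versions, the return value can be unpacked defensively. This is a hypothetical helper, not part of the library:

```python
def unpack_pkm_output(result):
    """Normalize the forward() return across versions.

    Pre-0.3.0 returned (output, loss, aux_loss); v0.3.0 returns
    (output, loss) with aux_loss already folded into loss.
    """
    if len(result) == 3:
        output, loss, aux_loss = result
        return output, loss + aux_loss  # mirror v0.3.0: fold aux into loss
    output, loss = result
    return output, loss
```

Usage would then be version-agnostic, e.g. output, loss = unpack_pkm_output(model(x, x)).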
gotcha The module expects input of shape (batch, seq_len, dim). Passing (batch, dim) for a single timestep is a common mistake and raises errors.
fix Make the input 3D, unsqueezing a seq_len axis if necessary.
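The shape fix is a one-liner:

```python
import torch

x = torch.randn(2, 512)   # (batch, dim): the seq_len axis is missing
x = x.unsqueeze(1)        # insert a seq_len axis of size 1 -> (batch, 1, dim)
```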
deprecated The heads argument may be deprecated in a future version as the library moves to a more efficient single-head implementation.
fix Use heads=1 or omit the argument; check the current documentation.
Imports
- ProductKeyMemory
from product_key_memory import ProductKeyMemory
Quickstart
import torch
from product_key_memory import ProductKeyMemory
model = ProductKeyMemory(
    dim=512,
    num_keys=512,
    topk=32,
    dim_head=64,
    heads=8,
    use_layernorm=True
)
x = torch.randn(1, 128, 512)
output, loss = model(x, x)
print(output.shape) # (1, 128, 512)
print(loss.shape) # (1,)