Product Key Memory


A PyTorch implementation of Product Key Memory (PKM), a learnable external memory module for neural networks with fast nearest-neighbor lookup via a product-key decomposition of the key space (in the spirit of product quantization). Current version 0.3.0; requires Python >=3.6. Actively developed by lucidrains.

pip install product-key-memory
error ImportError: cannot import name 'ProductKeyMemory' from 'product_key_memory'
cause Installed an older version (pre-0.2.0) that used a different module name or structure.
fix Upgrade to the latest version: pip install --upgrade product-key-memory
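To confirm what is actually installed before and after upgrading, a standard pip check works:

pip show product-key-memory  # the Version: field should read 0.2.0 or later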
error RuntimeError: Expected tensor to be on the same device, but found at least two devices
cause Model and input tensors on different devices (CPU vs GPU).
fix Ensure both are on the same device: model = model.to('cuda'); x = x.to('cuda'). See the sketch below.
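A minimal sketch of keeping the module and its inputs on one device; the device selection logic is illustrative, and the constructor arguments mirror the example at the end of this page:

import torch
from product_key_memory import ProductKeyMemory

# pick one device and use it for both parameters and inputs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ProductKeyMemory(dim=512, num_keys=512, topk=32, dim_head=64, heads=8).to(device)
x = torch.randn(1, 128, 512, device=device)  # created directly on the target device
output, loss = model(x)                      # no cross-device tensors remain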
error TypeError: forward() got an unexpected keyword argument 'return_loss'
cause Using deprecated keyword from older version (v0.1.x).
fix Remove return_loss; the loss is now always computed and returned as the second element (migration sketch below).
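A before/after sketch of the keyword removal; the old call is shown only for contrast and raises TypeError on current versions:

import torch
from product_key_memory import ProductKeyMemory

model = ProductKeyMemory(dim=512, num_keys=512, topk=32, dim_head=64, heads=8)
x = torch.randn(1, 128, 512)

# old (v0.1.x): output, loss = model(x, return_loss=True)
# current: the loss is always computed and returned
output, loss = model(x)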
breaking In v0.3.0 the forward method's return signature changed: it previously returned (output, loss, aux_loss) and now returns (output, loss); aux_loss is no longer returned separately.
fix Update code to expect only two return values. If you used aux_loss, it is now included in loss.
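Reusing model and x from the sketch above, migrating across the v0.3.0 return change looks like this (the pre-0.3.0 lines are shown as comments only):

# before v0.3.0:
#   output, loss, aux_loss = model(x)
#   total_loss = loss + aux_loss
# v0.3.0 and later: aux_loss is already folded into loss
output, loss = model(x)
total_loss = loss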
gotcha The module expects input of shape (batch, seq_len, dim). A common mistake is passing (batch, dim) for a single timestep, which leads to errors.
fix Ensure the input is 3D; unsqueeze a sequence dimension if necessary, as shown below.
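A small sketch of normalizing a single-timestep input; the shape check is illustrative, not part of the library:

import torch
from product_key_memory import ProductKeyMemory

model = ProductKeyMemory(dim=512, num_keys=512, topk=32, dim_head=64, heads=8)

x = torch.randn(4, 512)  # (batch, dim): a single timestep, missing the seq_len axis
if x.dim() == 2:
    x = x.unsqueeze(1)   # -> (4, 1, 512), i.e. (batch, seq_len=1, dim)

output, loss = model(x)
print(output.shape)      # expected: (4, 1, 512)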
deprecated The argument `heads` may be deprecated in future versions; the library is moving to a more efficient single-head implementation.
fix Use heads=1 or remove the argument; check the documentation.
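A forward-compatible construction sketch; whether heads can be omitted entirely depends on the installed version, so treat this as an assumption to verify against the docs:

from product_key_memory import ProductKeyMemory

# pin heads=1 explicitly so behavior matches a future single-head default
model = ProductKeyMemory(dim=512, num_keys=512, topk=32, dim_head=64, heads=1)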

The example below creates a PKM module, applies it to random input, and unpacks the output and the loss.

import torch
from product_key_memory import ProductKeyMemory

model = ProductKeyMemory(
    dim=512,            # input/output feature dimension
    num_keys=512,       # sub-keys per half; product keys give 512^2 addressable slots
    topk=32,            # memory slots retrieved per query
    dim_head=64,        # dimension of each head
    heads=8,            # number of heads (may be deprecated; see note above)
    use_layernorm=True  # enable layer normalization
)

x = torch.randn(1, 128, 512)  # (batch, seq_len, dim) -- input must be 3D
output, loss = model(x)       # v0.3.0 returns (output, loss); aux_loss is folded in
print(output.shape)  # (1, 128, 512)
print(loss.shape)    # (1,)