Mamba State-Space Model

Version 2.3.1, verified Mon Apr 27.

Mamba is a state-space model architecture designed for efficient sequence modeling, offering linear-time inference and parallelizable training. Currently at version 2.3.1, it requires Python >=3.9 and is under active development with frequent releases focused on ROCm and CUDA compatibility.

pip install mamba-ssm
error ImportError: cannot import name 'Mamba' from 'mamba_ssm'
cause The package does not expose Mamba directly if installed without the CUDA kernel build (e.g., CPU-only install).
fix Ensure a CUDA-enabled environment with torch installed, then build the package from source: pip install mamba-ssm --no-binary mamba-ssm
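A quick way to check whether the kernels made it into your install is to look for the compiled extension module. A minimal sketch, assuming the extension is named 'selective_scan_cuda' as in current mamba-ssm builds (verify against your installed version):

```python
import importlib.util

def cuda_kernels_installed() -> bool:
    """Return True if the compiled selective-scan CUDA extension is importable."""
    # mamba-ssm ships its GPU kernels as a compiled extension (assumed name:
    # 'selective_scan_cuda'); a CPU-only install lacks it, which triggers the
    # ImportError above.
    return importlib.util.find_spec("selective_scan_cuda") is not None

if not cuda_kernels_installed():
    print("mamba-ssm CUDA kernels missing; rebuild from source in a CUDA environment")
```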
error RuntimeError: CUDA error: no kernel image is available for execution on the device
cause The installed wheel was compiled for a different CUDA compute capability (e.g., sm_80 vs sm_75).
fix Reinstall from source with a matching CUDA architecture: TORCH_CUDA_ARCH_LIST='8.0' pip install mamba-ssm --no-binary mamba-ssm
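To find the right value for TORCH_CUDA_ARCH_LIST, read the device's compute capability (on a CUDA machine, via torch.cuda.get_device_capability()) and format it as major.minor. A minimal sketch with the formatting split out so it runs without a GPU:

```python
def arch_list(capability: tuple) -> str:
    """Format a CUDA compute capability tuple as a TORCH_CUDA_ARCH_LIST entry."""
    major, minor = capability
    return f"{major}.{minor}"

# On a CUDA machine you would obtain the tuple from
# torch.cuda.get_device_capability(), e.g. (7, 5) for a Turing GPU (sm_75):
print(arch_list((7, 5)))  # 7.5
print(arch_list((8, 0)))  # 8.0
```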
error AttributeError: module 'mamba_ssm' has no attribute 'Mamba'
cause You ran 'import mamba_ssm' and accessed mamba_ssm.Mamba, but the class is not exposed at module level in this install.
fix Use 'from mamba_ssm import Mamba' instead.
breaking Mamba v2.x requires PyTorch >=2.0 and CUDA 11.8+ for GPU support. Older PyTorch versions will fail with missing ops.
fix Upgrade PyTorch to 2.0+ and ensure CUDA toolkit 11.8+ is available.
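A quick preflight check can catch this before the missing-ops failure. A sketch that compares version strings the way you would read them from torch.__version__ and torch.version.cuda, written stdlib-only so it runs even without torch installed:

```python
def version_tuple(version: str) -> tuple:
    """Parse a dotted version string, ignoring local suffixes like '+cu118'."""
    return tuple(int(p) for p in version.split("+")[0].split(".")[:2])

def meets_mamba_v2_requirements(torch_version: str, cuda_version: str) -> bool:
    # Mamba v2.x needs PyTorch >= 2.0 and CUDA >= 11.8 for GPU support.
    return (version_tuple(torch_version) >= (2, 0)
            and version_tuple(cuda_version) >= (11, 8))

print(meets_mamba_v2_requirements("2.1.0+cu118", "11.8"))  # True
print(meets_mamba_v2_requirements("1.13.1", "11.7"))       # False
```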
breaking The causal-conv1d dependency is a separate package that may have ABI incompatibilities with different PyTorch versions. Mixing builds can cause silent incorrect results or crashes.
fix Install causal-conv1d from the same source (PyPI with matching CUDA version) or build from source using the same PyTorch build.
deprecated The 'pscan' parameter in Mamba is deprecated since v2.0 and will be removed in a future release. Setting pscan=True may lead to undefined behavior.
fix Remove the 'pscan' argument or set it to False (default).
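If layer configs are loaded from files that may still carry the deprecated key, strip it before constructing the layer. A minimal sketch; the helper name is ours, not part of mamba-ssm:

```python
DEPRECATED_KEYS = {"pscan"}  # deprecated since v2.0 per the note above

def clean_mamba_kwargs(kwargs: dict) -> dict:
    """Drop deprecated Mamba constructor arguments, reporting each one dropped."""
    cleaned = {k: v for k, v in kwargs.items() if k not in DEPRECATED_KEYS}
    for key in DEPRECATED_KEYS & kwargs.keys():
        print(f"dropping deprecated Mamba argument: {key}")
    return cleaned

cfg = clean_mamba_kwargs({"d_model": 16, "d_state": 16, "pscan": True})
print(cfg)  # {'d_model': 16, 'd_state': 16}
```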
gotcha The Mamba model expects input shape (batch, seq_len, d_model). Transposing or using (seq_len, batch, d_model) will not raise an error but will produce incorrect outputs due to dimension mismatch.
fix Ensure input shape is (B, L, D). Use x = x.transpose(0,1) if you have (L, B, D).
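Because the failure is silent, a cheap guard before the forward pass helps. A sketch that checks the shape tuple against the expected layout, so it is framework-agnostic (the helper name is ours):

```python
def check_mamba_input(shape, batch, seq_len, d_model):
    """Raise if a tensor shape is not the expected (batch, seq_len, d_model)."""
    expected = (batch, seq_len, d_model)
    if tuple(shape) != expected:
        raise ValueError(
            f"expected input shape {expected}, got {tuple(shape)}; "
            "a (L, B, D) tensor needs x = x.transpose(0, 1)"
        )

check_mamba_input((2, 128, 16), batch=2, seq_len=128, d_model=16)  # OK: (B, L, D)
```

Calling it with a transposed (128, 2, 16) tensor shape raises immediately instead of letting the model mix up batch and sequence dimensions.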
gotcha When using torch.compile, the custom CUDA kernels may not be compatible. Expect failures or performance degradation.
fix Avoid torch.compile with Mamba layers, or test thoroughly.
pip install mamba-ssm[causal-conv1d]

Instantiate a Mamba model and run a forward pass on GPU with float16 precision.

import torch
from mamba_ssm import Mamba

batch, seq_len, dim = 2, 128, 16
model = Mamba(
    d_model=dim,          # model (input/output) dimension D
    d_state=16,           # SSM state dimension N
    d_conv=4,             # local convolution width
    expand=2,             # block expansion factor
    dt_rank='auto',
    bias=False,
    conv_bias=True,
    device='cuda',
    dtype=torch.float16,
)
x = torch.randn(batch, seq_len, dim, device='cuda', dtype=torch.float16)
y = model(x)              # output keeps the input shape (B, L, D)
print(y.shape)
# Expected: torch.Size([2, 128, 16])