ViT PyTorch

Version: 1.20.4 | verified: Sat May 09 | auth: no | language: Python

A PyTorch implementation of the Vision Transformer (ViT) and related vision transformer architectures. Current version 1.20.4, actively maintained with frequent releases.

pip install vit-pytorch
error ImportError: cannot import name 'ViT' from 'vit_pytorch'
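cause An older version is installed where the API was different, the package was not installed properly, or a local file or folder named 'vit_pytorch' is shadowing the installed package.
fix Upgrade to the latest version with 'pip install --upgrade vit-pytorch', rename any local 'vit_pytorch.py' or 'vit_pytorch/' that is not the library, and use 'from vit_pytorch import ViT'.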
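To tell these causes apart, a small standard-library check can report which interpreter is running, which vit-pytorch version pip has recorded, and which file 'import vit_pytorch' would actually load. This is a sketch: 'diagnose' is a hypothetical helper, not part of vit-pytorch. An origin path outside site-packages (for example, a file in the current directory) usually means a local module is shadowing the installed package.

```python
import importlib.metadata
import importlib.util
import sys


def diagnose(import_name: str, dist_name: str) -> dict:
    """Hypothetical helper: report the running interpreter, the installed
    version of `dist_name`, and where `import_name` would be loaded from."""
    report = {"python": sys.executable}
    # Version recorded in this environment's package metadata (None if absent).
    try:
        report["version"] = importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        report["version"] = None
    # File the import system would load; None if the module cannot be found.
    spec = importlib.util.find_spec(import_name)
    report["origin"] = spec.origin if spec is not None else None
    return report


print(diagnose("vit_pytorch", "vit-pytorch"))
```

If "version" is None while "origin" is set, the module being imported is not the pip-installed package.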
error ModuleNotFoundError: No module named 'vit_pytorch'
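cause The package is not installed, or it was installed into a different Python environment than the one running the code. The hyphen/underscore difference is rarely the culprit on the install side, since pip treats 'vit_pytorch' and 'vit-pytorch' as the same name; on the import side, 'import vit-pytorch' is a syntax error rather than a ModuleNotFoundError.
fix Install with 'pip install vit-pytorch' (hyphen) and import with 'import vit_pytorch' (underscore). To install into the interpreter that runs your code, use 'python -m pip install vit-pytorch'.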
breaking In version 1.0.0, the package was renamed from 'vit_pytorch' to 'vit-pytorch' on PyPI, but the import remains 'vit_pytorch'. Ensure you install with 'vit-pytorch' but import with 'vit_pytorch'.
fix Use 'pip install vit-pytorch' and 'from vit_pytorch import ViT'.
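The hyphen/underscore split is between the distribution name that pip uses and the module name that Python imports. pip compares distribution names after normalization (PEP 503: runs of '-', '_' and '.' collapse to '-', then lowercase), sketched below. Both helpers are hypothetical, and 'to_import_name' only encodes the convention vit-pytorch happens to follow; many packages use unrelated import names.

```python
import re


def normalize_dist_name(name: str) -> str:
    """PEP 503 normalization: collapse runs of '-', '_' and '.' into a
    single '-' and lowercase the result."""
    return re.sub(r"[-_.]+", "-", name).lower()


def to_import_name(dist_name: str) -> str:
    """Hypothetical helper for this package's convention only:
    hyphens in the distribution name become underscores in the import name."""
    return normalize_dist_name(dist_name).replace("-", "_")


print(normalize_dist_name("vit_pytorch"))  # vit-pytorch
print(normalize_dist_name("vit-pytorch"))  # vit-pytorch
print(to_import_name("vit-pytorch"))       # vit_pytorch
```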
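gotcha 'from vit_pytorch import ViT' is the documented top-level import. The class itself is defined in the submodule 'vit_pytorch.vit', so 'from vit_pytorch.vit import ViT' also works in current releases; if neither import succeeds, the package is missing or outdated (see the errors above).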
fix Use 'from vit_pytorch import ViT'.
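To check whether two import paths give you the same class in your installed version, you can import both and compare the objects by identity. 'resolves_to_same' is a hypothetical helper; the standard-library call shows the re-export pattern (json re-exports JSONDecodeError from json.decoder), and the vit-pytorch call prints None if the package or its torch dependency is not installed.

```python
import importlib


def resolves_to_same(module_a: str, module_b: str, attr: str):
    """Return True if `attr` from both modules is the same object,
    False if both imports succeed but differ, None if either import fails."""
    try:
        obj_a = getattr(importlib.import_module(module_a), attr)
        obj_b = getattr(importlib.import_module(module_b), attr)
    except (ImportError, AttributeError):
        return None
    return obj_a is obj_b


# Standard-library analogue of a package re-exporting a submodule's class.
print(resolves_to_same("json", "json.decoder", "JSONDecodeError"))  # True
# Depends on what is installed locally.
print(resolves_to_same("vit_pytorch", "vit_pytorch.vit", "ViT"))
```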

Basic instantiation and forward pass of a Vision Transformer.

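import torch
from vit_pytorch import ViT

v = ViT(
    image_size=256,    # input images are 256 x 256
    patch_size=32,     # 256 / 32 = 8, so 8 x 8 = 64 patches per image
    num_classes=1000,  # size of the classification head output
    dim=1024,          # token embedding width
    depth=6,           # number of transformer blocks
    heads=16,          # attention heads per block
    mlp_dim=2048,      # hidden width of each block's feed-forward layer
    dropout=0.1,       # dropout inside the transformer blocks
    emb_dropout=0.1    # dropout applied to the embedded tokens
)

img = torch.randn(1, 3, 256, 256)  # (batch, channels, height, width)
preds = v(img)  # class logits, shape (1, 1000)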
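The image_size/patch_size pair fixes how many tokens the transformer sees. The NumPy sketch below shows the arithmetic and the patch split for the example above; 'num_patches' and 'patchify' are hypothetical helpers, and the flattening order is modeled on the einops pattern 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', so the exact layout inside a given vit-pytorch release may differ.

```python
import numpy as np


def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2


def patchify(images: np.ndarray, patch_size: int) -> np.ndarray:
    """Split (batch, channels, height, width) images into flattened patches
    of shape (batch, num_patches, patch_size * patch_size * channels)."""
    b, c, h, w = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, c, gh, patch_size, gw, patch_size)
    # Reorder to (b, grid_row, grid_col, p1, p2, c): patch grid first, channels last.
    x = x.transpose(0, 2, 4, 3, 5, 1)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


img = np.random.randn(1, 3, 256, 256).astype(np.float32)
patches = patchify(img, 32)
print(num_patches(256, 32))  # 64
print(patches.shape)         # (1, 64, 3072)
```

Each 3072-value patch vector is then projected to dim (1024 in the example) and a class token is prepended, so the transformer processes 65 tokens per image.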