EfficientNet (PyTorch)

0.7.1 · maintenance · verified Fri Apr 17

efficientnet-pytorch is a PyTorch implementation of the EfficientNet family (B0 through B7). These are convolutional neural networks for image classification that scale depth, width, and input resolution together, trading accuracy against compute. Version 0.7.1 provides pretrained ImageNet weights for these variants, including weights trained with AdvProp (adversarial propagation). The library has seen few updates since its last major release in 2020 and is in maintenance mode, but it still works and remains widely used.
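Each variant expects a different input resolution, and the AdvProp weights expect a different normalization from the standard ImageNet mean/std. The sketch below hard-codes the native resolutions from the library's per-model parameters; in real code, `EfficientNet.get_image_size(model_name)` should return the same values. It then applies either normalization to a single RGB pixel scaled to [0, 1]. The AdvProp rule `x * 2 - 1` follows the project README's recommendation for weights loaded with `from_pretrained(..., advprop=True)`. Keeping the quickstart's 256/224 resize-to-crop ratio for larger variants is an assumption, and `preprocess_sizes` and `normalize_pixel` are hypothetical helper names.

```python
from typing import Dict, Tuple

# Native input resolution per variant (mirrors the library's per-model parameters;
# EfficientNet.get_image_size(name) should report the same numbers).
NATIVE_RESOLUTION: Dict[str, int] = {
    "efficientnet-b0": 224,
    "efficientnet-b1": 240,
    "efficientnet-b2": 260,
    "efficientnet-b3": 300,
    "efficientnet-b4": 380,
    "efficientnet-b5": 456,
    "efficientnet-b6": 528,
    "efficientnet-b7": 600,
}

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# Assumption: reuse the quickstart's Resize(256) / CenterCrop(224) ratio for every variant.
CROP_RATIO = 224 / 256


def preprocess_sizes(model_name: str) -> Tuple[int, int]:
    """Hypothetical helper: return (resize, center_crop) sizes for a variant."""
    crop = NATIVE_RESOLUTION[model_name]
    resize = round(crop / CROP_RATIO)
    return resize, crop


def normalize_pixel(rgb: Tuple[float, float, float], advprop: bool = False) -> Tuple[float, ...]:
    """Normalize one RGB pixel already scaled to [0, 1], as ToTensor() produces."""
    if advprop:
        # AdvProp weights: map [0, 1] to [-1, 1] instead of mean/std normalization.
        return tuple(c * 2.0 - 1.0 for c in rgb)
    # Standard ImageNet weights: per-channel (x - mean) / std.
    return tuple((c - m) / s for c, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))
```

With torchvision, the same choices become `transforms.Resize(resize)` and `transforms.CenterCrop(crop)`, followed by either `transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)` or `transforms.Lambda(lambda img: img * 2.0 - 1.0)` for AdvProp weights.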

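Install

pip install efficientnet_pytorch

The quickstart below also needs torchvision and Pillow (pip install torchvision pillow).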

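To confirm that the installed release matches the 0.7.1 version this page describes, here is a minimal sketch using only the standard library's `importlib.metadata` (Python 3.8+). It assumes the installed distribution is registered as `efficientnet_pytorch`. `installed_version` and `matches_expected` are hypothetical helper names.

```python
from importlib import metadata
from typing import Optional

EXPECTED_VERSION = "0.7.1"  # release described on this page


def installed_version(dist_name: str = "efficientnet_pytorch") -> Optional[str]:
    """Return the installed version string, or None if the distribution is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def matches_expected(version: Optional[str], expected: str = EXPECTED_VERSION) -> bool:
    """True only when a version was found and it equals the expected release."""
    return version is not None and version == expected


if __name__ == "__main__":
    found = installed_version()
    if found is None:
        print("efficientnet_pytorch is not installed")
    elif matches_expected(found):
        print(f"efficientnet_pytorch {found} is installed")
    else:
        print(f"efficientnet_pytorch {found} is installed (this page describes {EXPECTED_VERSION})")
```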
Imports

from efficientnet_pytorch import EfficientNet

Quickstart

This quickstart loads a pretrained EfficientNet-B0 model, prepares an input image with standard ImageNet preprocessing, and runs a forward pass for inference. It puts the model in evaluation mode and uses a GPU when one is available.

```python
import torch
from efficientnet_pytorch import EfficientNet
from torchvision import transforms
from PIL import Image

# 1. Load a pretrained EfficientNet-B0 (weights are downloaded on first use)
model = EfficientNet.from_pretrained('efficientnet-b0')
model.eval()  # evaluation mode: disables dropout, uses running BatchNorm statistics

# 2. Define standard ImageNet preprocessing (224 is B0's native resolution)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 3. Create a dummy image (replace with real image loading)
# For a real image, use Image.open('path/to/image.jpg').convert('RGB')
img = Image.new('RGB', (256, 256), color='red')

# 4. Preprocess the image and add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0)

# 5. Put the model and the input on the same device (GPU if available, else CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
input_batch = input_batch.to(device)

# 6. Run inference without tracking gradients
with torch.no_grad():
    output = model(input_batch)

# `output` holds raw logits, one row per image: shape (1, 1000) for ImageNet weights
print(f"Output logits shape: {output.shape}")

# Predicted class index: argmax over the class dimension
predicted_idx = output.argmax(dim=1)
print(f"Predicted class index: {predicted_idx.item()}")
```
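The quickstart stops at raw logits. Turning them into class probabilities means applying a softmax over the class dimension; with tensors you would typically call `torch.softmax(output, dim=1)` and then `torch.topk(probs, k=5)`. To make the arithmetic explicit, the sketch below does the same on a plain Python list, for example `output[0].tolist()`. `softmax` and `top_k` are hypothetical helper names, and mapping the indices to human-readable ImageNet labels requires a separate label list that is not shown here.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the largest logit before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits: Sequence[float], k: int = 5) -> List[Tuple[int, float]]:
    """Return (class_index, probability) pairs for the k most likely classes."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


if __name__ == "__main__":
    # Toy logits for 4 classes; ImageNet-pretrained EfficientNet outputs 1000 per image.
    example_logits = [2.0, 0.5, 3.0, -1.0]
    for idx, p in top_k(example_logits, k=2):
        print(f"class {idx}: {p:.3f}")
```

Subtracting the maximum logit leaves the softmax result unchanged but keeps `math.exp` from overflowing on large logits, which is the same trick the tensor implementations rely on.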
