EfficientDet for PyTorch

0.4.1 · active · verified Tue Apr 14

effdet is a PyTorch implementation of the EfficientDet family of object detection models. It aims to faithfully reproduce the original TensorFlow models while offering the flexibility of PyTorch. The current version is 0.4.1. Releases are regular, often track updates to its `timm` dependency and to new PyTorch versions, and focus on performance and accuracy.

Quickstart

This quickstart loads a pre-trained EfficientDet model, creates a dummy image, and runs inference on it to detect objects. It uses the high-level `create_model` function together with standard torchvision image transforms.

import torch
from effdet import create_model
from effdet.data import resolve_input_config
from torchvision import transforms
from PIL import Image
import os

# Create a dummy image for a runnable example
dummy_image_path = "dummy_image.png"
if not os.path.exists(dummy_image_path):
    img = Image.new('RGB', (640, 640), color='red')
    img.save(dummy_image_path)

# 1. Load a pre-trained EfficientDet model
# 'tf_efficientdet_d0' is a small, fast variant.
# bench_task='predict' wraps the network in a prediction bench that decodes
# the anchor outputs and applies NMS, so the model returns final detections.
model_name = 'tf_efficientdet_d0'
model = create_model(model_name, pretrained=True, bench_task='predict')
model.eval()

# 2. Prepare the image for inference
img = Image.open(dummy_image_path).convert('RGB')

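# Resolve input size and normalization from the model config.
# The first argument holds user overrides; an empty dict keeps the model defaults.
input_config = resolve_input_config({}, model.config)

# Define the image transformation pipeline.
# input_size is (channels, height, width); Resize expects (height, width).
# Note: this plain resize does not preserve the image's aspect ratio.
transform = transforms.Compose([
    transforms.Resize(input_config['input_size'][-2:]),
    transforms.ToTensor(),
    transforms.Normalize(mean=input_config['mean'], std=input_config['std'])
])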

# Apply transformations and add a batch dimension
input_tensor = transform(img).unsqueeze(0)

# 3. Perform inference
with torch.no_grad():
    output = model(input_tensor)

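# Output shape is [batch, max_detections, 6]; each row is
# [x1, y1, x2, y2, score, class], sorted by score and zero-padded,
# with boxes in the pixel coordinates of the resized input image.
detections = output[0]
detections = detections[detections[:, 4] > 0.5]  # 0.5 is an arbitrary example threshold
if detections.numel() > 0:
    print(f"Detected objects (top 5):\n{detections[:5]}")
else:
    print("No objects detected.")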

# Clean up dummy image
os.remove(dummy_image_path)
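
Because the quickstart resizes the image without passing any image-info to the model, the returned boxes are in the coordinates of the resized network input rather than the original image. The helper below is a minimal, hypothetical sketch (`postprocess_detections` is not part of effdet) that drops low-score and zero-padding rows and maps boxes back to original-image pixels. It assumes the plain, non-aspect-preserving resize used above and plain Python lists, such as those produced by `output[0].tolist()`.

```python
from typing import List, Sequence, Tuple

# One detection row as printed above: [x1, y1, x2, y2, score, class]
Detection = Sequence[float]


def postprocess_detections(
    rows: Sequence[Detection],
    input_hw: Tuple[int, int],
    original_hw: Tuple[int, int],
    score_threshold: float = 0.5,
) -> List[Tuple[float, float, float, float, float, int]]:
    """Hypothetical helper: filter detections by score and rescale boxes.

    Assumes the image was stretched to input_hw with a plain resize,
    so x and y are scaled independently back to original_hw.
    """
    in_h, in_w = input_hw
    orig_h, orig_w = original_hw
    sx = orig_w / in_w  # horizontal scale: input pixels -> original pixels
    sy = orig_h / in_h  # vertical scale: input pixels -> original pixels
    kept = []
    for x1, y1, x2, y2, score, cls in rows:
        if score < score_threshold:
            continue  # also drops zero-score padding rows
        kept.append((x1 * sx, y1 * sy, x2 * sx, y2 * sy, float(score), int(cls)))
    # Highest-confidence detections first
    kept.sort(key=lambda d: d[4], reverse=True)
    return kept


if __name__ == "__main__":
    fake_rows = [
        [10.0, 20.0, 110.0, 220.0, 0.9, 1.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # zero padding row
        [50.0, 50.0, 100.0, 100.0, 0.3, 3.0],  # below threshold
    ]
    print(postprocess_detections(fake_rows, (512, 512), (640, 640)))
    # prints [(12.5, 25.0, 137.5, 275.0, 0.9, 1)]
```

With the quickstart's variables, a call might look like `postprocess_detections(output[0].tolist(), input_config['input_size'][-2:], (img.height, img.width))`. If you switch to an aspect-preserving resize with padding, a single uniform scale factor would replace the separate `sx` and `sy`.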