SAHI (Slicing Aided Hyper Inference)
SAHI (Slicing Aided Hyper Inference) is a lightweight Python library that improves object detection and instance segmentation performance, especially for small objects in large or high-resolution images. It works by dividing an image into smaller overlapping slices, running inference on each slice, and then merging the per-slice predictions back onto the original image. Currently at version 0.11.36, SAHI has a frequent release cadence, often issuing patch releases to fix bugs and introduce minor enhancements.
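The slicing idea can be sketched in a few lines of plain Python. This is an illustration only, not SAHI's actual implementation (which lives in `sahi.slicing` and handles more edge cases): tiles step across the image with a stride reduced by the overlap ratio, and tiles that would spill past the border are clamped back inside so every slice keeps its full size.

```python
def compute_slice_boxes(img_h, img_w, slice_h, slice_w, overlap_h, overlap_w):
    """Illustrative sketch: [x_min, y_min, x_max, y_max] boxes for
    overlapping slices covering an img_h x img_w image."""
    # Step sizes shrink as the overlap ratio grows.
    y_step = max(1, int(slice_h * (1 - overlap_h)))
    x_step = max(1, int(slice_w * (1 - overlap_w)))
    boxes = []
    y = 0
    while y < img_h:
        x = 0
        while x < img_w:
            x_max = min(x + slice_w, img_w)
            y_max = min(y + slice_h, img_h)
            # Clamp border slices back inside the image so each keeps full size.
            boxes.append([max(x_max - slice_w, 0), max(y_max - slice_h, 0), x_max, y_max])
            if x_max >= img_w:
                break
            x += x_step
        if y_max >= img_h:
            break
        y += y_step
    return boxes
```

With a 512x512 image, 256x256 slices, and 0.2 overlap, this yields a 3x3 grid of nine overlapping tiles.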
Warnings
- breaking The behavior of the `confidence_threshold` parameter in `AutoDetectionModel.from_pretrained` has shifted across releases. Previously it appeared to only filter detections by score; discussions around 0.11.27 and later suggest it can also influence bounding box size or shape, not just the score cutoff. Always verify the expected behavior for your SAHI version.
- gotcha The `BoundingBox` and `Category` objects were made immutable in versions 0.11.29 and 0.11.31 respectively. Direct modification of their attributes will now raise an error.
- gotcha When working in multi-GPU environments, especially with subprocesses or certain frameworks (like Detectron2), models might default to loading on 'cuda:0' causing imbalanced GPU utilization. This was specifically addressed in 0.11.34 for subprocesses.
- gotcha SAHI significantly increases inference time due to processing multiple slices. It is generally not recommended for real-time applications where latency is critical.
- gotcha Some users have reported degraded performance (fewer detections, lower confidence) when applying SAHI to models that already perform well on an original dataset without slicing.
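The "intelligently merging" step mentioned above behaves roughly like non-maximum suppression across slice boundaries: duplicate detections of the same object from overlapping slices are collapsed into one. A minimal, self-contained sketch of greedy IoU-based merging (not SAHI's actual `postprocess` module, which offers several strategies) makes the idea concrete:

```python
def iou(a, b):
    # Intersection-over-union of two [x_min, y_min, x_max, y_max] boxes.
    ix = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union else 0.0

def merge_predictions(preds, iou_threshold=0.5):
    # preds: list of (box, score) pairs from all slices, in full-image
    # coordinates. Keep the highest-scoring box among heavy overlaps.
    preds = sorted(preds, key=lambda p: p[1], reverse=True)
    kept = []
    for box, score in preds:
        if all(iou(box, k[0]) < iou_threshold for k in kept):
            kept.append((box, score))
    return kept
```

Two near-identical boxes from adjacent slices collapse to the higher-scoring one, while a distant box survives untouched.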
Install
- pip install sahi
- pip install sahi[yolov5]
- pip install sahi[ultralytics]
- pip install sahi[mmdet]
- pip install sahi[detectron2]
- pip install sahi[huggingface]
- pip install sahi[torchvision]
Imports
- AutoDetectionModel
from sahi import AutoDetectionModel
- get_sliced_prediction
from sahi.predict import get_sliced_prediction
- read_image_as_pil
from sahi.utils.cv import read_image_as_pil
- download_from_url
from sahi.utils.file import download_from_url
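Under the hood, `get_sliced_prediction` maps each slice's detections back into full-image coordinates before merging: a detection made inside a slice only knows slice-local coordinates, so its box must be shifted by the slice's top-left offset. The shift itself is just an offset addition (an illustrative sketch, not SAHI's code):

```python
def shift_box_to_full_image(box, slice_offset):
    """Illustrative: translate a slice-local box into full-image coordinates.

    box: [x_min, y_min, x_max, y_max] relative to the slice.
    slice_offset: (x, y) of the slice's top-left corner in the full image.
    """
    dx, dy = slice_offset
    return [box[0] + dx, box[1] + dy, box[2] + dx, box[3] + dy]
```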
Quickstart
import os
import torch
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.cv import read_image
from sahi.utils.file import download_from_url
# Download a sample image
image_url = 'https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/small-vehicles1.jpeg'
image_path = 'small-vehicles1.jpeg'
download_from_url(image_url, image_path)
# Obtain YOLOv8s weights (requires ultralytics: pip install ultralytics)
model_path = 'yolov8s.pt'
if not os.path.exists(model_path):
    # Referencing an official model name makes ultralytics download the
    # weights into the current directory if they are not already present.
    from ultralytics import YOLO
    YOLO(model_path)
# Initialize the detection model
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
detection_model = AutoDetectionModel.from_pretrained(
    model_type='ultralytics',  # or 'yolov5', 'mmdet', 'huggingface', 'torchvision', etc.
    model_path=model_path,  # path to your pretrained model weights
    confidence_threshold=0.3,
    device=device,
)
# Perform sliced inference
result = get_sliced_prediction(
    read_image(image_path),
    detection_model,
    slice_height=256,
    slice_width=256,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
# Print detection results
print(f"Detected {len(result.object_prediction_list)} objects.")
for i, prediction in enumerate(result.object_prediction_list):
    print(f"  Detection {i+1}: Class={prediction.category.name}, Confidence={prediction.score.value:.3f}")
# Export visuals (optional, requires opencv-python-headless or opencv-python)
output_dir = './sahi_output'
os.makedirs(output_dir, exist_ok=True)
result.export_visuals(export_dir=output_dir, file_name='prediction_visual')  # '.png' is appended automatically
print(f"Visualizations saved to {output_dir}/prediction_visual.png")
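A common follow-up to the loop above is summarizing detections per class. This sketch operates on plain (class_name, score) pairs, the same values the quickstart extracts as `p.category.name` and `p.score.value`; the function name and `min_score` parameter are illustrative, not part of SAHI's API:

```python
from collections import Counter

def summarize_detections(detections, min_score=0.3):
    # detections: list of (class_name, score) pairs, e.g. built from
    # result.object_prediction_list as (p.category.name, p.score.value).
    counts = Counter(name for name, score in detections if score >= min_score)
    return dict(counts)
```

For example, two cars above the score cutoff and a truck below it reduce to `{'car': 2}`.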