Segment Anything Model (SAM)
The Segment Anything Model (SAM) from Meta AI is a foundation model for image segmentation: given a single click (or another prompt), it can cut out any object in any image. It is designed as a general-purpose segmentation model applicable to a variety of downstream tasks. The current stable PyPI version is 1.0; releases are infrequent and tied to significant model updates rather than a regular cadence.
Warnings
- gotcha Model Checkpoint Download Required. The `pip install segment-anything` command only installs the library code, not the large pre-trained model weights. Users MUST manually download a model checkpoint (e.g., `sam_vit_h_4b8939.pth`) from the official GitHub releases page.
- gotcha Device Management for Performance. SAM models load on the CPU by default. For significantly faster inference, especially with larger variants like ViT-H, explicitly move the model to a CUDA-enabled GPU when one is available.
- gotcha Image Color Channel Order. If using `OpenCV` (cv2) to load images, it reads them in BGR format by default. SAM models expect images in RGB format. Failing to convert will lead to incorrect or degraded segmentation results.
- deprecated API differences between Research Repo and PyPI Package. The initial research codebase (direct GitHub clone) had some helper functions and class structures that differ from the stable `segment-anything` PyPI package (v1.0+). Relying on old examples from the research repo might lead to `ImportError` or `AttributeError`.
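The color-channel gotcha above is easy to demonstrate without SAM itself: OpenCV's `cv2.imread` returns BGR, and reversing the channel axis (or calling `cv2.cvtColor`) restores RGB. A minimal sketch using NumPy only, so it runs even without OpenCV installed; the `cv2` equivalent is shown in a comment:

```python
import numpy as np

# Simulate a cv2.imread result: OpenCV returns images as BGR.
# Here every pixel has blue=10, green=20, red=30.
bgr = np.zeros((4, 4, 3), dtype=np.uint8)
bgr[..., 0] = 10  # blue channel
bgr[..., 1] = 20  # green channel
bgr[..., 2] = 30  # red channel

# Reverse the channel axis to get RGB; equivalent to
# cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) without the OpenCV dependency.
# ascontiguousarray avoids handing a negative-stride view to downstream code.
rgb = np.ascontiguousarray(bgr[..., ::-1])

print(rgb[0, 0])  # red, green, blue order
```

If you skip this conversion, SAM still produces masks, but they are computed on an image with red and blue swapped, which typically degrades quality.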
Install
pip install segment-anything
Imports
- SamAutomaticMaskGenerator
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
- SamPredictor
from segment_anything import sam_model_registry, SamPredictor
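`SamAutomaticMaskGenerator` segments an entire image without prompts and returns a list of dicts (keys include `segmentation`, `area`, `bbox`, `predicted_iou`, and `stability_score`). A hedged sketch of the call plus typical post-processing: the generator call itself is commented out because it needs a downloaded checkpoint, so dummy records stand in for its output here:

```python
import numpy as np

# With a real checkpoint on disk (assumption: ViT-H weights downloaded):
# from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
# sam = sam_model_registry['vit_h'](checkpoint='sam_vit_h_4b8939.pth')
# generator = SamAutomaticMaskGenerator(sam)
# masks = generator.generate(rgb_image)  # list of dicts, one per mask

# Dummy records mimicking the generator's output format:
masks = [
    {"segmentation": np.ones((8, 8), dtype=bool), "area": 64,
     "bbox": [0, 0, 8, 8], "predicted_iou": 0.91, "stability_score": 0.95},
    {"segmentation": np.zeros((8, 8), dtype=bool), "area": 12,
     "bbox": [2, 2, 4, 3], "predicted_iou": 0.88, "stability_score": 0.97},
]

# Typical post-processing: drop low-confidence masks, sort largest-first.
kept = [m for m in masks if m["predicted_iou"] >= 0.9]
kept.sort(key=lambda m: m["area"], reverse=True)
print([m["area"] for m in kept])
```

The filtering threshold (0.9) is an arbitrary choice for illustration; tune it per application.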
Quickstart
import numpy as np
import torch
import os
# NOTE: You must download a model checkpoint first (e.g., sam_vit_h_4b8939.pth)
# from https://github.com/facebookresearch/segment-anything/releases/tag/v1.0
# For this example, we'll assume a dummy path and model type.
SAM_CHECKPOINT_PATH = os.environ.get('SAM_CHECKPOINT', 'sam_vit_h_4b8939.pth')
MODEL_TYPE = os.environ.get('SAM_MODEL_TYPE', 'vit_h') # e.g., 'vit_h', 'vit_l', 'vit_b'
# Dummy image data (replace with actual image loading, e.g., using OpenCV)
# Assuming a 1024x1024 RGB image for demonstration
image = np.zeros((1024, 1024, 3), dtype=np.uint8)
# Simulate loading a real image:
# import cv2
# image_path = 'path/to/your/image.jpg'
# image = cv2.imread(image_path)
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Important: Convert BGR to RGB
# Check if checkpoint exists
if not os.path.exists(SAM_CHECKPOINT_PATH):
    print(f"Warning: Model checkpoint '{SAM_CHECKPOINT_PATH}' not found.\n"
          "Please download it from the official Segment Anything GitHub releases.")
    # Exit early; nothing below works without the weights.
    exit()
from segment_anything import sam_model_registry, SamPredictor
# Initialize SAM model
sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT_PATH)
# Set device: 'cuda' for GPU if available, else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam.to(device=device)
print(f"Using device: {device}")
# Create a predictor
predictor = SamPredictor(sam)
predictor.set_image(image)
# Example: Point prompt for a single object
input_point = np.array([[500, 375]]) # Coordinates [x, y]
input_label = np.array([1]) # 1 for foreground, 0 for background
# Predict masks
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)
print(f"Generated {len(masks)} masks.")
print(f"Scores: {scores}")
# print(f"First mask shape: {masks[0].shape}, dtype: {masks[0].dtype}")
# The 'masks' array contains boolean masks: True for foreground, False for background
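With `multimask_output=True`, SAM returns three candidate masks plus a quality score for each; a common follow-up is to keep only the highest-scoring one. A sketch using dummy arrays in place of a real `predictor.predict` result:

```python
import numpy as np

# Stand-ins for a real predict() output: 3 boolean masks + quality scores.
masks = np.stack([
    np.zeros((16, 16), dtype=bool),
    np.ones((16, 16), dtype=bool),
    np.zeros((16, 16), dtype=bool),
])
scores = np.array([0.62, 0.97, 0.81])

# Pick the mask SAM is most confident about.
best_idx = int(np.argmax(scores))
best_mask = masks[best_idx]
print(best_idx, int(best_mask.sum()))  # -> 1 256
```

Setting `multimask_output=False` instead returns a single mask, which is often preferable when the prompt is unambiguous.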