Segment Anything Model (SAM)

1.0 · active · verified Sun Apr 12

The Segment Anything Model (SAM) from Meta AI is a new foundation model for image segmentation, capable of cutting out any object in any image with a single click. It is designed to be a general-purpose segmentation model, applicable to various downstream tasks. The current stable PyPI version is 1.0, with updates generally tied to significant advancements rather than frequent releases.

Warnings

Install
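A sketch of installation, assuming a working pip; the package is published on PyPI (version 1.0, per above) and can also be installed directly from the official GitHub repository:

```shell
# From PyPI:
pip install segment-anything
# Or straight from the repository:
pip install git+https://github.com/facebookresearch/segment-anything.git
```

PyTorch and NumPy are required at runtime; install them per their own instructions if not already present.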

Imports

Quickstart

This quickstart shows how to initialize the Segment Anything Model (SAM) and run point-based inference with `SamPredictor`. Note that a model checkpoint must be downloaded first and the compute device set appropriately. For automatic, whole-image mask generation, use `SamAutomaticMaskGenerator` instead.

import numpy as np
import torch
import os

# NOTE: You must download a model checkpoint first (e.g., sam_vit_h_4b8939.pth)
# from the links in the official repository README:
# https://github.com/facebookresearch/segment-anything
# For this example, we'll assume a dummy path and model type.
SAM_CHECKPOINT_PATH = os.environ.get('SAM_CHECKPOINT', 'sam_vit_h_4b8939.pth')
MODEL_TYPE = os.environ.get('SAM_MODEL_TYPE', 'vit_h') # e.g., 'vit_h', 'vit_l', 'vit_b'

# Dummy image data (replace with actual image loading, e.g., using OpenCV)
# Assuming a 1024x1024 RGB image for demonstration
image = np.zeros((1024, 1024, 3), dtype=np.uint8)
# Simulate loading a real image:
# import cv2
# image_path = 'path/to/your/image.jpg'
# image = cv2.imread(image_path)
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Important: Convert BGR to RGB

# Check if checkpoint exists
if not os.path.exists(SAM_CHECKPOINT_PATH):
    # Exit early: the rest of the example requires a real checkpoint.
    raise SystemExit(
        f"Model checkpoint '{SAM_CHECKPOINT_PATH}' not found. "
        "Please download it from the official Segment Anything GitHub releases."
    )

from segment_anything import sam_model_registry, SamPredictor

# Initialize SAM model
sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT_PATH)

# Set device: 'cuda' for GPU if available, else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam.to(device=device)
print(f"Using device: {device}")

# Create a predictor
predictor = SamPredictor(sam)
predictor.set_image(image)

# Example: Point prompt for a single object
input_point = np.array([[500, 375]]) # Coordinates [x, y]
input_label = np.array([1])      # 1 for foreground, 0 for background

# Predict masks
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

print(f"Generated {len(masks)} masks.")
print(f"Scores: {scores}")
# print(f"First mask shape: {masks[0].shape}, dtype: {masks[0].dtype}")
# The 'masks' array contains boolean masks: True for foreground, False for background
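With `multimask_output=True`, `predict` returns three candidate masks plus a quality score for each; a common pattern is to keep only the highest-scoring candidate. A minimal sketch of that selection step, using dummy arrays in place of real predictor output:

```python
import numpy as np

# Dummy stand-ins for predictor.predict() outputs: 3 candidate boolean
# masks and their quality scores (replace with the real return values).
masks = np.zeros((3, 1024, 1024), dtype=bool)
scores = np.array([0.70, 0.95, 0.80])

# Keep the candidate with the highest predicted quality.
best_idx = int(np.argmax(scores))
best_mask = masks[best_idx]
print(best_idx)  # -> 1
print(best_mask.shape)  # -> (1024, 1024)
```

The same indexing works directly on the arrays returned by `predictor.predict`.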
