I've given it a spin locally by following the online readme. My local notebook sets everything up via:
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
sam = sam_model_registry["default"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
Then, I pass the predictor an image.
image = cv2.imread("../cat.png")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR, the predictor expects RGB
predictor.set_image(image)
On my 12-core Intel NUC this takes about 17 seconds, which could be doable when preparing data in batch, but is prohibitive for real-time use.
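For what it's worth, this is roughly how I timed it; a minimal sketch, nothing fancy:
import time

start = time.perf_counter()
predictor.set_image(image)  # computes the image embedding, which is the expensive step
print(f"set_image took {time.perf_counter() - start:.1f}s")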
Next, I mimic a mouse click by giving it coordinates.
input_point = np.array([[200, 175]])
input_label = np.array([1])
plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()
Notice the star near the neck of the cat? We can pretend a mouse-click happened there.

Then, we can ask the model to make a prediction.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)
Here are the predictions that it makes, each with its confidence score.
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()
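If you only care about a single mask, you could also just keep the one with the highest score. A small sketch of that, re-using the helpers from above:
best_idx = int(np.argmax(scores))   # index of the highest-scoring candidate mask
best_mask = masks[best_idx]         # boolean array of shape (H, W)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(best_mask, plt.gca())
plt.title(f"Best mask, score: {scores[best_idx]:.3f}", fontsize=18)
plt.axis('off')
plt.show()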

From a quick demo this indeed seems pretty impressive, but there are a few bottlenecks to getting this to work in Prodigy in the short term.
- Prodigy's image recipes don't allow for interaction with a backend at the moment. This is certainly something I'd love to revisit for Prodigy v2, but right now it means you'd have to spin up a prediction server with a custom front-end to get this to work inside of Prodigy (a rough sketch of what such a server might look like follows after this list).
- The compute that's needed here makes it prohibitively slow unless you have a good GPU. The site mentions it takes 0.15s on a GPU to prepare an image; without one, 17s seems a reasonable expectation.
- I might need to double-check the license, as well as the training data, to fully understand when you're able to apply this to a custom dataset. The GitHub repository mentions Apache 2.0, but I'd like to double-check the paper to make sure that the model weights have a permissive license as well.
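To make the first point a bit more concrete: a prediction server could be as small as the sketch below. It assumes FastAPI, re-uses the predictor and numpy import from above, and assumes predictor.set_image(...) has already been called for the image being annotated; the endpoint name and payload are made up for illustration.
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class ClickPrompt(BaseModel):
    # hypothetical payload: a single positive click in pixel coordinates
    x: int
    y: int

@app.post("/predict")
def predict_mask(prompt: ClickPrompt):
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[prompt.x, prompt.y]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )
    best = int(np.argmax(scores))
    # return the best mask as nested 0/1 lists so a custom front-end can draw it
    return {"score": float(scores[best]), "mask": masks[best].astype(int).tolist()}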
An idea
The website suggests that there are other workflows possible too. For example, you seem to be able to provide bounding boxes and have SAM turn those into a mask of the item you're interested in. If you already have images annotated with bounding boxes, I can see how this could help turn those into masks, and you could set up a two-step annotation scheme where the first step is drawing boxes and the second step is confirming the masks.
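Judging from their example notebook, a box prompt would look roughly like the snippet below; the box coordinates are hypothetical and I haven't benchmarked this path, so treat it as a sketch.
# assumes predictor.set_image(image) has already been called
input_box = np.array([70, 50, 420, 380])   # hypothetical existing box in XYXY pixel coordinates

masks, scores, _ = predictor.predict(
    box=input_box[None, :],        # a (1, 4) array for a single box prompt
    multimask_output=False,        # a box is usually unambiguous enough for a single mask
)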
The website also suggests that you're able to provide text, like below, to select the items in a photo.
I haven't been able to find this API in their codebase, though; if it were available, I can see how we might turn that into an easy recipe.