Segment Anything Model (SAM) Integration with Prodigy

Hi,

I am working on an image segmentation project and I am interested in integrating SAM (https://segment-anything.com/) into Prodigy to speed up the labeling process. Specifically, I am trying to segment objects into multiple categories.

I am interested in using SAM because I think it will help me save time and improve accuracy by suggesting relevant segments using self-supervised learning methods. However, I am not sure how to integrate SAM with Prodigy, and I would appreciate any guidance or tutorials that could help me get started.

Any guidance will be very much appreciated.

Many Thanks and Best Regards,
Bilal


Hi Bilal,

This serves as a great reminder for me to check this out. It's not just LLMs that are making a big splash; this model also seems like a big deal.

I'll have a look at it to gauge how hard it is to make a prototype. I'll let you know if it's prohibitively hard or if it's doable to set up locally.

Will let you know!

I've given it a spin locally by following the online readme. My local notebook sets everything up via:

import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt

from segment_anything import sam_model_registry, SamPredictor

# Plotting helpers, taken from the SAM example notebook.
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    
    
# Load the ViT-H weights downloaded from the SAM repository and wrap them in a predictor.
sam = sam_model_registry["default"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

Then, I pass the predictor an image.

image = cv2.imread("../cat.png")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR, but set_image expects RGB by default
predictor.set_image(image)

On my 12-core Intel NUC this takes about 17 seconds. That could be doable when preparing the data in a batch job, but it is prohibitive in real time.
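
One way around that could be to precompute the image embeddings in a batch job and only run the lightweight prompt decoder interactively. The sketch below is just an idea, not a tested workflow: it assumes a folder of images, uses the predictor's get_image_embedding() method to grab the computed features, and saves them next to each image. Restoring them later would need some extra bookkeeping (you'd also have to remember the original image size).

import glob
import torch

# Optional: move the model to a GPU if one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam.to(device=device)
predictor = SamPredictor(sam)

# Hypothetical batch job: cache one embedding tensor per image on disk.
for path in glob.glob("../images/*.png"):
    image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    predictor.set_image(image)                   # the expensive step (~17s on my CPU)
    embedding = predictor.get_image_embedding()  # image features computed by the ViT backbone
    torch.save(embedding.cpu(), path + ".emb.pt")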

Next, I mimic a mouse click by giving it coordinates.

input_point = np.array([[200, 175]])
input_label = np.array([1])

plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()  

Notice the star near the neck of the cat? We can pretend a mouse-click happened there.

(Screenshot: the cat image with the point prompt drawn as a star near its neck.)

Then, we can ask the model to make a prediction.

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

Here's how to plot the predictions; the screenshot below shows the mask with the highest confidence score.

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()  

(Screenshot: the highest-scoring mask overlaid on the cat image, with its confidence score in the title.)
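
If you only want the top candidate, you can pick the mask with the highest score and turn it into polygon points. The helper below is a sketch of how that could feed a points-based image task; the label is just an example, and the exact task format is worth double-checking against the Prodigy docs. It reuses cv2, np, masks and scores from the snippets above.

def mask_to_points(mask):
    # Hypothetical helper: take the largest contour of a boolean mask
    # and return its outline as a list of [x, y] points.
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    largest = max(contours, key=cv2.contourArea)
    return largest.squeeze(1).tolist()

best_mask = masks[np.argmax(scores)]  # the mask with the highest predicted score
task = {
    "image": "../cat.png",
    "spans": [{"label": "CAT", "points": mask_to_points(best_mask)}],  # example label
}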

From a quick demo this indeed seems pretty impressive, but there are a few bottlenecks to getting this to work in Prodigy in the short term.

  1. Prodigy's image recipes don't allow for interaction with a backend at the moment. This is certainly something I'd love to revisit for Prodigy v2, but right now that means you'd have to spin up a prediction server with a custom front-end to get this to work inside of Prodigy (a rough sketch of such a server follows after this list).
  2. The compute that's needed here is prohibitive unless you have a good GPU. The site mentions it takes about 0.15s on a GPU to prep an image, but without one, something like my 17s seems a reasonable estimate.
  3. I might need to double-check the license, as well as the training data, to fully understand when you're able to apply this to a custom dataset. The GitHub repository mentions Apache 2.0, but I'd like to double-check the paper to make sure that the model weights have a permissive license as well.
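
To make point 1 a bit more concrete: a rough sketch of such a prediction server is below. It's a minimal FastAPI app that wraps the predictor and the mask_to_points helper from earlier; the route name, the payload fields and the response format are all assumptions on my part, not an official Prodigy or SAM API.

import cv2
import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class PointPrompt(BaseModel):
    image_path: str  # hypothetical: a path that the server can read
    x: int
    y: int

@app.post("/predict")
def predict_mask(prompt: PointPrompt):
    # Embed the image (the slow step) and run the point prompt through SAM.
    image = cv2.cvtColor(cv2.imread(prompt.image_path), cv2.COLOR_BGR2RGB)
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[prompt.x, prompt.y]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )
    best_mask = masks[np.argmax(scores)]
    return {"score": float(scores.max()), "points": mask_to_points(best_mask)}

A custom front-end could then call this endpoint on every click and draw the returned polygon, but as mentioned, that's not something the stock image recipes support out of the box today.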

An idea

The website suggests that other workflows are possible too. For example, you seem to be able to give SAM a bounding box and have it turn that into a mask of the item you're interested in. If you already have images annotated with bounding boxes, I can see how this could help turn those into masks, and you could do a two-step annotation scheme where the first step is drawing boxes and the second step is confirming the mask (see the sketch just below).
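
For the box-based idea, the predictor also accepts a box prompt in XYXY format, so the second step could look roughly like the sketch below (the coordinates are made up, and it reuses the image and plotting helpers from earlier).

# Hypothetical bounding box from an earlier annotation pass, in (x0, y0, x1, y1) format.
input_box = np.array([100, 100, 420, 350])

masks, scores, _ = predictor.predict(
    box=input_box,
    multimask_output=False,  # one mask per box is usually enough here
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

The second annotation step could then just show the resulting mask for a quick accept/reject.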

The website also suggests that you're able to provide a text prompt to select items in a photo.

I haven't been able to find this API in their codebase, though. If it were available, I can see how we might turn that into an easy recipe.


Ah, and there is the confirmation in the FAQ section:


Text prompts are not released yet.