👀 Segment Something - A search engine within an image

Not long ago, FAIR (Facebook AI Research) released the Segment Anything Model (SAM), a model from Meta AI that can "cut out" any object in any image with a single click. It is a promptable segmentation system with zero-shot generalization to unfamiliar objects and images, so it needs no additional training.

SAM is a powerful model that can be used for a wide variety of tasks. In this post, we will use it alongside OpenCLIP to extract entities from an image semantically, using a text prompt. By the end of this post, we will be able to pick out an entity in an image just by describing it in words.

The source code for this post can be found here.

Thought process

To make an image semantically searchable with natural language, quite a few tasks need to be completed. The whole process is similar to building a simple search engine: index the existing entities, compute the distance between the user query and every entity, and return the closest one.
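At its core, the "search" is a nearest-neighbour lookup. The sketch below is a minimal brute-force version in NumPy, assuming each entity already has an embedding vector. The helper `find_closest` and the toy 3-d vectors are hypothetical; they only illustrate the ranking step that LanceDB performs for us later in this post.

```python
import numpy as np

def find_closest(query: np.ndarray, entities: np.ndarray) -> int:
    """Return the index of the entity embedding most similar to the query.

    Both sides are L2-normalized, so a dot product gives the cosine
    similarity between the query and each entity.
    """
    q = query / np.linalg.norm(query)
    e = entities / np.linalg.norm(entities, axis=1, keepdims=True)
    scores = e @ q                 # one cosine score per entity
    return int(np.argmax(scores))  # highest score = closest entity

# Toy "embeddings" for three indexed entities (hypothetical values)
entities = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.7, 0.7, 0.1],
])
query = np.array([0.1, 0.9, 0.0])
print(find_closest(query, entities))  # → 1
```

A real index replaces the full scan with a vector database, but the idea of "embed everything, then return the nearest entity" stays the same.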

To better illustrate the problem, I crafted a simple diagram to show the whole process:

There are four main steps:

  1. Segment image to entities - We will be using the SAM model to extract the entities from the image.
  2. Convert the entities into embeddings - We will be using OpenCLIP to convert the entities into embeddings.
  3. Convert the text prompt into embeddings - We will be using OpenCLIP to convert the text prompt into embeddings.
  4. Find the closest part and highlight - We will be using LanceDB to find the entity closest to the text prompt, then apply a few OpenCV operations to highlight the entity's outline within the original image.

Solutions

1. Segment image to entities

To extract the entities from the image, we will use the Segment Anything Model (SAM) from FAIR. To set up SAM, we can use the following commands:

# Create a SAM photon
lep photon create -n sam -m py:github.com/leptonai/examples.git:advanced/segment-anything/sam.py
# Run SAM locally if you have a GPU
lep photon run -n sam --local
# Run SAM remotely on an A10 GPU
lep photon run -n sam --resource-shape gpu.a10

Once SAM is set up, we can use the following code to extract the entities from the image in the form of segmentations:

import cv2
from PIL import Image as PIL_Img
from leptonai.client import Client, local
from leptonai.photon.types import lepton_pickle, lepton_unpickle

LEPTON_API_TOKEN = "YOUR_LEPTON_API_TOKEN"
sam_client = Client("YOUR_WORKSPACE_ID", "SAM_DEPLOYMENT_NAME", token=LEPTON_API_TOKEN)
# or you could do sam_client = Client(local()) if running locally

def get_image_result_from_pickle(img_path, sam_client):
    # OpenCV loads images as BGR; convert to RGB before handing them to PIL
    input_img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    pickle_img = lepton_pickle(PIL_Img.fromarray(input_img))
    # Send the pickled image to the SAM photon; the response is itself pickled
    return sam_client.seg_from_pickle(image=pickle_img)

img_path = 'YOUR_IMAGE_PATH'
segmentations = lepton_unpickle(get_image_result_from_pickle(img_path, sam_client))
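The code later in this post reads the keys `segmentation`, `bbox`, `stability_score` and `predicted_iou` from each record, which matches the output format of SAM's automatic mask generator (a boolean HxW mask per entity). Assuming that format, a quick filtering pass can drop tiny or unstable masks before they are embedded. The helper `filter_segmentations` and its thresholds are hypothetical and would need tuning per image; area is computed from the mask itself rather than assuming an `area` key.

```python
import numpy as np

def filter_segmentations(segmentations, min_area=500, min_stability=0.9):
    """Keep masks that are large and stable enough, largest first.

    Assumes each record holds a boolean HxW array under 'segmentation'
    and a float under 'stability_score'. Thresholds are illustrative.
    """
    kept = [
        seg for seg in segmentations
        if seg['segmentation'].sum() >= min_area
        and seg['stability_score'] >= min_stability
    ]
    # Sort by pixel count so the largest entities come first
    return sorted(kept, key=lambda s: int(s['segmentation'].sum()), reverse=True)

def square_mask(start, stop, size=100):
    """Build a synthetic square mask standing in for real SAM output."""
    mask = np.zeros((size, size), dtype=bool)
    mask[start:stop, start:stop] = True
    return mask

records = [
    {'segmentation': square_mask(0, 5), 'stability_score': 0.98},    # 25 px
    {'segmentation': square_mask(10, 60), 'stability_score': 0.95},  # 2500 px
    {'segmentation': square_mask(50, 90), 'stability_score': 0.70},  # 1600 px
]
print([int(s['segmentation'].sum()) for s in filter_segmentations(records)])  # → [2500]
```

Fewer masks also means fewer CLIP calls in step 2, since every surviving segment is cropped and embedded separately.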

To fetch an image from a URL, we can write a small download helper:

import os
import uuid
import requests

def download_image(url):
    response = requests.get(url, timeout=30)
    response.raise_for_status()  # fail early on 4xx/5xx instead of saving an error page
    # Each image gets its own short random directory, which doubles as its ID
    img_dir = uuid.uuid4().hex[:6]
    os.makedirs(img_dir)
    file_path = os.path.join(img_dir, "index.jpg")
    with open(file_path, "wb") as f:
        f.write(response.content)
    return img_dir

url = 'https://i.ibb.co/Z2ZMn1W/27191691674806-pic.jpg'
img_uuid = download_image(url)

2. Convert the entities into embeddings

To convert the entities into embeddings, we will use a CLIP model served with OpenCLIP. To set up CLIP, we can use the following commands:

# Create a CLIP photon
lep photon create -n clip -m py:github.com/leptonai/examples.git:advanced/open-clip/open-clip.py
# Run CLIP locally
lep photon run -n clip --local
# Run CLIP remotely on a T4 GPU
lep photon run -n clip --resource-shape gpu.t4

Once CLIP is set up, we can use the following code to convert the entities into embeddings:

import cv2
from PIL import Image as PIL_Img
from leptonai.client import Client, local
from leptonai.photon.types import lepton_pickle, lepton_unpickle

LEPTON_API_TOKEN = "YOUR_LEPTON_API_TOKEN"
clip_client = Client("YOUR_WORKSPACE_ID", "CLIP_DEPLOYMENT_NAME", token=LEPTON_API_TOKEN)
# or you could do clip_client = Client(local()) if running locally

def get_image_embeddings_from_path(file_path, client):
    # OpenCV loads images as BGR; CLIP expects RGB, so convert before pickling
    input_img = cv2.cvtColor(cv2.imread(file_path), cv2.COLOR_BGR2RGB)
    pickle_img = lepton_pickle(PIL_Img.fromarray(input_img), compression=9)
    return client.embed_pickle_image(image=pickle_img)

img_path = 'YOUR_IMAGE_PATH'
embeddings = get_image_embeddings_from_path(img_path, clip_client)

3. Convert the text prompt into embeddings

To convert the text prompt into embeddings, we will use the same CLIP model. With the CLIP photon from the previous step, the text prompt can be embedded with:

clip_client.embed_text(query='YOUR_TEXT_PROMPT')

4. Find the closest part and highlight

For this step, we will use the embeddings from steps 2 and 3 to find the part closest to the text prompt, then use OpenCV to highlight the entity's outline within the original image. To find the closest part, we will use LanceDB. To set up LanceDB, we can use the following code:

import pandas as pd
import lancedb
uri = "data/sample-lancedb"
db = lancedb.connect(uri)
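LanceDB's vector search ranks results by L2 (Euclidean) distance by default, while CLIP embeddings are usually compared with cosine similarity. The two rankings agree whenever the vectors are L2-normalized, because for unit vectors the squared L2 distance equals 2 - 2 times their cosine similarity. This post does not say whether the CLIP photon already normalizes its outputs, so normalizing both the stored and the query vectors is a cheap safeguard. A minimal NumPy check of that identity, with random vectors standing in for real embeddings and `l2_normalize` as a hypothetical helper:

```python
import numpy as np

def l2_normalize(v: np.ndarray) -> np.ndarray:
    """Scale each vector (row) to unit length."""
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

rng = np.random.default_rng(0)
entities = l2_normalize(rng.normal(size=(5, 512)))  # stand-in image embeddings
query = l2_normalize(rng.normal(size=512))          # stand-in text embedding

cosine = entities @ query
l2_sq = np.sum((entities - query) ** 2, axis=1)

# For unit vectors, squared L2 distance = 2 - 2 * cosine similarity ...
assert np.allclose(l2_sq, 2 - 2 * cosine)
# ... so the nearest entity by L2 is also the most similar one by cosine
assert np.argmin(l2_sq) == np.argmax(cosine)
```

In the pipeline below, this would mean normalizing each crop's embedding before building the dataframe, and the text embedding before calling `search`.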

Once LanceDB is set up, we can build a function that takes an image UUID (randomly generated during the image download) and returns a LanceDB table that can be used to find the closest part:

def cut_image_to_embeddings_vdb(img_uuid):
    img_path = img_uuid + '/index.jpg'
    source_img = cv2.imread(img_path)
    # Get segmentations from SAM mentioned in step 1
    segmentations = lepton_unpickle(get_image_result_from_pickle(img_path, sam_client))

    for index, seg in enumerate(segmentations):
        # Crop the image with the segmentation
        cropped_img = crop_image_with_bbox(crop_image_by_seg(source_img, seg['segmentation']), seg['bbox'])
        c_img_path = img_uuid + '/{}.jpg'.format(index)
        cv2.imwrite(c_img_path, cropped_img)
        # Get embeddings from CLIP mentioned in step 2
        embeddings = get_image_embeddings_from_path(c_img_path, clip_client)
        seg['embeddings'] = embeddings
        seg['img_path'] = c_img_path
        seg['seg_shape'] = seg['segmentation'].shape
        seg['segmentation'] = seg['segmentation'].reshape(-1)

    # Convert the segmentations into a dataframe
    seg_df = pd.DataFrame(segmentations)
    seg_df = seg_df[['img_path', 'embeddings', 'bbox', 'stability_score', 'predicted_iou', 'segmentation','seg_shape']]
    seg_df = seg_df.rename(columns={"embeddings": "vector"})
    # Create a LanceDB table with the dataframe; "overwrite" lets the function be re-run on the same image
    tbl = db.create_table("table_{}".format(img_uuid), data=seg_df, mode="overwrite")
    return tbl
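`cut_image_to_embeddings_vdb` calls two helpers, `crop_image_by_seg` and `crop_image_with_bbox`, which are not defined in this post (they may live in the linked source code). A minimal sketch of what they could look like, assuming the mask is a boolean HxW array and the box follows SAM's `[x, y, w, h]` convention. Blacking out the background is one choice among several; a white or blurred background would also work and may change what CLIP "sees".

```python
import numpy as np

def crop_image_by_seg(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Keep only the pixels inside the mask; everything else becomes black."""
    result = np.zeros_like(image)
    result[mask] = image[mask]
    return result

def crop_image_with_bbox(image: np.ndarray, bbox) -> np.ndarray:
    """Cut out an [x, y, w, h] box (SAM's XYWH convention)."""
    x, y, w, h = (int(round(v)) for v in bbox)
    return image[y:y + h, x:x + w]

# Tiny 4x4 colour image with a 2x2 object in the top-left corner
image = np.full((4, 4, 3), 200, dtype=np.uint8)
mask = np.zeros((4, 4), dtype=bool)
mask[0:2, 0:2] = True
crop = crop_image_with_bbox(crop_image_by_seg(image, mask), [0, 0, 2, 2])
print(crop.shape)  # → (2, 2, 3)
```

Masking first and cropping second, as in the function above, keeps neighbouring objects inside the bounding box from leaking into the entity's embedding.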

Then we can write a function that takes the table built above, the image UUID, and a text prompt, and highlights the closest part:

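import numpy as np
from IPython.display import Image, display

def find_part(vector_table, img_id, user_query):
    # Embed the text prompt with the same CLIP model used for the image crops
    k_embedding = clip_client.embed_text(query=user_query)

    # Nearest-neighbour search: keep only the single closest segment
    target = vector_table.search(k_embedding).limit(1).to_df()
    best = target.iloc[0]
    # Restore the flattened mask to its 2-D shape as a 0/1 uint8 image for OpenCV
    segmentation_mask = best['segmentation'].reshape(best['seg_shape']).astype(np.uint8)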

    # Dilate the segmentation mask to expand the area
    dilated_mask = cv2.dilate(segmentation_mask, np.ones((10,10), np.uint8), iterations=1)

    # Create a mask of the surroundings by subtracting the original segmentation mask
    surroundings_mask = dilated_mask - segmentation_mask

    # Create a highlighted version of the original image
    path = '{}/index.jpg'.format(img_id)
    highlighted_image = cv2.imread(path)
    highlighted_image[surroundings_mask > 0] = [218, 237, 247]

    cv2.imwrite('{}/processed.jpg'.format(img_id),highlighted_image)

    # Display the image
    display(Image(filename='{}/processed.jpg'.format(img_id)))
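The highlight relies on a simple idea: dilating a mask grows it outward, so removing the original mask leaves only a thin ring around the object. The NumPy-only sketch below reproduces that idea on a toy mask without OpenCV. `dilate` is a hypothetical stand-in for `cv2.dilate` with an odd-sized square kernel of ones (out-of-bounds pixels treated as background), not a drop-in replacement, and `dilated & ~mask` plays the role of the subtraction above, which is safe there because the dilated mask always contains the original one.

```python
import numpy as np

def dilate(mask: np.ndarray, radius: int) -> np.ndarray:
    """Binary dilation with a (2*radius+1) x (2*radius+1) square kernel."""
    h, w = mask.shape
    padded = np.pad(mask, radius)  # False padding = background outside the image
    out = np.zeros_like(mask)
    # OR together every shifted copy of the mask covered by the kernel
    for dy in range(2 * radius + 1):
        for dx in range(2 * radius + 1):
            out |= padded[dy:dy + h, dx:dx + w]
    return out

mask = np.zeros((7, 7), dtype=bool)
mask[2:5, 2:5] = True           # 3x3 object in the middle
ring = dilate(mask, 1) & ~mask  # pixels just outside the object
print(int(ring.sum()))  # → 16
```

A larger kernel, such as the 10x10 one used in `find_part`, simply produces a thicker outline.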

Result

With all the steps above, we can now extract entities from an image with a text prompt like this:

url = 'https://i.ibb.co/Z2ZMn1W/27191691674806-pic.jpg'
img_uuid = download_image(url)
tbl = cut_image_to_embeddings_vdb(img_uuid)
find_part(tbl, img_uuid, 'standing cat')

The standing cat is then highlighted in the original image.

We have now built a simple natural-language search engine for the contents of an image. SAM and CLIP let us extract entities from an image and match them against a text prompt, and LanceDB lets us find the part closest to the prompt in a few lines of code and easily retrieve its original segmentation.

Yet that is not the end of the story. There are quite a few ways to further leverage SAM and CLIP. One interesting direction is controlling robots with natural language, as explored in the VoxPoser paper.