Code for the "How to Perform Image Segmentation using Transformers in Python" Tutorial


View on GitHub

Open In Colab

image_segmentation_transformers.py

# %% [markdown]
# # Set up environment

# %%
# install the Hugging Face transformers package; the leading "!" is an IPython
# shell escape, so this line only works when run as a notebook cell
!pip install transformers

# %%
from IPython.display import clear_output
# !pip3 install transformers
# clear whatever this cell has printed so far (e.g. the install log when the
# line above is uncommented)
clear_output()

# %%
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from transformers import pipeline, SegformerImageProcessor, SegformerForSemanticSegmentation
import requests
from PIL import Image
import urllib.parse as parse
import os

# %%
# a function to determine whether a string is a URL or not
def is_url(string):
    """Return True if *string* parses as a URL with a scheme, a host and a path.

    All three components must be non-empty, so ``"https://example.com/cat.jpg"``
    is a URL, while a bare host such as ``"https://example.com"`` (empty path)
    or a local path such as ``"images/cat.jpg"`` (no scheme/host) is not.

    Returns False instead of raising when the value cannot be parsed.
    """
    try:
        result = parse.urlparse(string)
    except (ValueError, TypeError, AttributeError):
        # urlparse raises ValueError for malformed URLs (e.g. an unbalanced
        # IPv6 bracket) and TypeError/AttributeError for some non-string
        # inputs; all of these simply mean "not a URL". Unlike a bare
        # ``except:``, this does not swallow KeyboardInterrupt or SystemExit.
        return False
    return all([result.scheme, result.netloc, result.path])

# a function to load an image
def load_image(image_path):
    """Helper function to load images from their URLs or paths.

    Args:
        image_path: Either a URL (as recognised by ``is_url``) or the path of
            an existing local file.

    Returns:
        The image opened with ``PIL.Image.open``.

    Raises:
        requests.HTTPError: If the URL answers with an HTTP error status.
        FileNotFoundError: If ``image_path`` is neither a URL nor an existing
            path.
    """
    if is_url(image_path):
        response = requests.get(image_path, stream=True)
        # fail with a clear HTTP error here instead of letting PIL try to
        # decode an error page as an image
        response.raise_for_status()
        return Image.open(response.raw)
    if os.path.exists(image_path):
        return Image.open(image_path)
    # the function used to fall through and return None, which only surfaced
    # later as an unrelated-looking error in the caller
    raise FileNotFoundError(
        f"Image not found (not a URL or an existing path): {image_path!r}"
    )

# %% [markdown]
# # Load Image

# %%
# URL of the sample image used by the cells below
img_path = "https://shorthaircatbreeds.com/wp-content/uploads/2020/06/Urban-cat-crossing-a-road-300x180.jpg"
# fetch and open it with the load_image helper defined above
image = load_image(img_path)

# %%
# bare expression: in a notebook this renders the loaded image as the cell output
image

# %%
# convert PIL Image to pytorch tensors
transform = transforms.ToTensor()
# force three channels first so the resulting tensor always has an RGB channel axis
image_tensor = image.convert("RGB")
image_tensor = transform(image_tensor)
# inspect the tensor shape (the channel axis comes first: it is moved last with
# permute(1, 2, 0) before the image is overlaid further down)
image_tensor.shape

# %% [markdown]
# # Helper functions

# %%
def color_palette():
  """Color palette to map each class to its corresponding color.

  Returns:
    A new list of ten ``[R, G, B]`` colors with 0-255 channel values.
  """
  return [[0, 128, 128],
          [255, 170, 0],
          [161, 19, 46],
          [118, 171, 47],
          [255, 255, 0],
          [84, 170, 127],
          [170, 84, 127],
          [33, 138, 200],
          [255, 84, 0],
          [255, 140, 208]]

# %%
def overlay_segments(image, seg_mask):
  """Return different segments predicted by the model overlaid on image.

  Args:
    image: Array-like or CPU tensor of shape (H, W, 3). Its values are
      multiplied by 255 before blending, i.e. they are expected in [0, 1].
    seg_mask: (H, W) label map given as a numpy array, a torch tensor or a
      nested list.

  Returns:
    A uint8 numpy array of shape (H, W, 3) holding an equal-weight blend of
    the image and a color mask. Each distinct label in ``seg_mask`` is colored
    by its rank among the labels present; when there are more distinct labels
    than palette entries the palette is reused cyclically.
  """
  # normalise the label map to numpy so that the numpy color mask is indexed
  # with a numpy boolean mask (also covers GPU / non-detached tensors)
  if torch.is_tensor(seg_mask):
    seg_mask = seg_mask.detach().cpu().numpy()
  else:
    seg_mask = np.asarray(seg_mask)

  H, W = seg_mask.shape
  image_mask = np.zeros((H, W, 3), dtype=np.uint8)
  colors = np.array(color_palette())

  # map each segment label to a color; wrap around the palette instead of
  # raising IndexError when the mask holds more labels than there are colors
  for i, label in enumerate(np.unique(seg_mask)):
    image_mask[seg_mask == label] = colors[i % len(colors)]

  # percentage of original image in the final overlaid image
  img_weight = 0.5

  # overlay input image and the generated segment mask
  img = img_weight * np.asarray(image) * 255 + (1 - img_weight) * image_mask

  return img.astype(np.uint8)

# %%
def replace_label(mask, label):
  """Return a copy of mask in which every entry equal to 255 is set to label."""
  # np.array makes a copy, so the caller's mask is left untouched
  relabeled = np.array(mask)
  hits = relabeled == 255
  relabeled[hits] = label
  return relabeled

# %% [markdown]
# # Image segmentation using Hugging Face Pipeline

# %%
# load the entire image segmentation pipeline
img_segmentation_pipeline = pipeline('image-segmentation', 
                                     model="nvidia/segformer-b5-finetuned-ade-640-640")

# %%
# run the pipeline on the PIL image; the cells below index the result like a
# list whose items each expose a 'mask' entry
output = img_segmentation_pipeline(image)
output

# %%
# show the mask of the first returned segment
output[0]['mask']

# %%
# show the mask of the third returned segment
# NOTE(review): assumes the pipeline returned at least three segments for this image
output[2]['mask']

# %%
# merge the per-segment masks returned by the pipeline into a single label map:
# in the i-th mask the 255-valued pixels are relabelled to i, then all masks are summed
W, H = image.size
segmentation_mask = np.zeros((H, W), dtype=np.uint8)

for i, segment in enumerate(output):
  segmentation_mask += replace_label(segment['mask'], i)

# %%
# overlay the predicted segmentation masks on the original image
# (permute(1, 2, 0) moves the first tensor axis last, so the image is passed channel-last)
segmented_img = overlay_segments(image_tensor.permute(1, 2, 0), segmentation_mask)

# convert to PIL Image
# (bare expression: a notebook renders it as the cell output)
Image.fromarray(segmented_img)

# %% [markdown]
# # Image segmentation using custom Hugging Face models

# %%
# load the feature extractor (to preprocess images) and the model (to get outputs)
# both come from the same checkpoint name that was given to the pipeline earlier
feature_extractor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b5-finetuned-ade-640-640")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b5-finetuned-ade-640-640")

# %%
def to_tensor(image):
  """Return the RGB version of a PIL Image as a pytorch tensor."""
  rgb_image = image.convert("RGB")
  return transforms.ToTensor()(rgb_image)

# a function that takes an image path and returns the segmented image
def get_segmented_image(model, feature_extractor, image_path):
  """Return the input image with the predicted segments overlaid on it.

  Args:
    model: Segmentation model; it is called as ``model(**inputs)`` and its
      result must expose a ``logits`` attribute.
    feature_extractor: Preprocessor; it is called as
      ``feature_extractor(images=image, return_tensors="pt")``.
    image_path: URL or local path of the image to segment.

  Returns:
    A PIL Image built from the overlay of the image and its predicted segments.
  """
  # load the image and force RGB once, so the preprocessor and the overlay see
  # the same 3-channel picture (grayscale/RGBA files used to reach the
  # preprocessor unconverted)
  image = load_image(image_path).convert("RGB")
  # preprocess input
  inputs = feature_extractor(images=image, return_tensors="pt")
  # convert to pytorch tensor
  image_tensor = to_tensor(image)
  # pass the processed input to the model; this is inference only, so skip
  # autograd bookkeeping to save memory and time
  with torch.no_grad():
    outputs = model(**inputs)
  print("outputs.logits.shape:", outputs.logits.shape)
  # interpolate output logits to the same shape as the input image
  upsampled_logits = F.interpolate(
      outputs.logits, # tensor to be interpolated
      size=image_tensor.shape[1:], # output size we want
      mode='bilinear', # do bilinear interpolation
      align_corners=False)

  # get the class with the highest logit for every pixel of the first batch item
  segmentation_mask = upsampled_logits.argmax(dim=1)[0]
  print(f"{segmentation_mask.shape=}")
  # get the segmented image
  segmented_img = overlay_segments(image_tensor.permute(1, 2, 0), segmentation_mask)
  # convert to PIL Image
  return Image.fromarray(segmented_img)

# %%
# segment the sample cat image with the explicitly loaded model; the returned
# PIL Image is the cell's last expression, so a notebook displays it
get_segmented_image(model, feature_extractor, "https://shorthaircatbreeds.com/wp-content/uploads/2020/06/Urban-cat-crossing-a-road-300x180.jpg")

# %%
# same call on a second image, this one hosted at images.cocodataset.org
get_segmented_image(model, feature_extractor, "http://images.cocodataset.org/test-stuff2017/000000000001.jpg")

# %%