import torch
from torch import nn
import numpy as np
from PIL import Image
import gradio as gr
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation


def ade_palette():
    """ADE20K palette that maps each class index to RGB values."""
    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
            [102, 255, 0], [92, 0, 255]]
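
# A note on the palette above (hedged): these are the standard 150 ADE20K
# colours (index 0, [120, 120, 120], is conventionally the "wall" class).
# The sidewalk-finetuned checkpoint used below emits far fewer classes, so
# only the first entries are ever consulted, e.g.:
#
#   assert len(ade_palette()) == 150
#   ade_palette()[0]  # -> [120, 120, 120]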


def resize_image(image, new_size, sdxl_resize=None):
    """
    Resize the given image while preserving its aspect ratio.

    Args:
        image (PIL.Image): The image to be resized.
        new_size (int): The target size for the chosen side of the image.
        sdxl_resize (bool, optional): If truthy, fix the larger dimension to
            new_size instead of the smaller one. Defaults to None.

    Returns:
        PIL.Image: The resized image, with both sides snapped down to
        multiples of 64.
    """
    original_width, original_height = image.size
    if sdxl_resize:
        value = max(original_height, original_width)
    else:
        value = min(original_height, original_width)
    # Fix the selected side at new_size and scale the other side to
    # preserve the aspect ratio.
    if value == original_height:
        aspect_ratio = original_width / original_height
        new_height = new_size
        new_width = int(new_height * aspect_ratio)
    else:
        aspect_ratio = original_height / original_width
        new_width = new_size
        new_height = int(new_width * aspect_ratio)
    resized_image = image.resize((new_width, new_height))
    # Snap both dimensions down to multiples of 64.
    w, h = resized_image.size
    w, h = map(lambda x: x - x % 64, (w, h))
    resized_image = resized_image.resize((w, h))
    return resized_image
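
# Worked example (illustrative numbers, not from the original app): for a
# 1000x750 input and new_size=512, the smaller side (750, the height) is
# fixed at 512 and the width becomes int(512 * 1000 / 750) == 682; snapping
# both sides down to multiples of 64 then yields a 640x512 output, which
# keeps the dimensions friendly for SegFormer's patch-based encoder.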


# Load the feature extractor and model once at import time so every Gradio
# call reuses them instead of re-initialising per request. (AutoImageProcessor
# is the newer name for AutoFeatureExtractor in recent transformers releases;
# the older class is kept here to match the accompanying requirements file.)
extractor = AutoFeatureExtractor.from_pretrained(
    "mohit-mahavar/segformer-b0-finetuned-segments-sidewalk-july-24")
model = SegformerForSemanticSegmentation.from_pretrained(
    "mohit-mahavar/segformer-b0-finetuned-segments-sidewalk-july-24")


def run(img):
    # Downscale large inputs, preferring to fix the smaller side when it is
    # already big enough and falling back to fixing the larger side
    # (sdxl_resize) otherwise. Images below 512px on both sides pass through.
    if min(img.size) >= 768:
        img = resize_image(img, 768)
    elif max(img.size) >= 1024:
        img = resize_image(img, 1024, sdxl_resize=True)
    elif min(img.size) >= 512:
        img = resize_image(img, 512)
    elif max(img.size) >= 768:
        img = resize_image(img, 768, sdxl_resize=True)
    elif max(img.size) >= 512:
        img = resize_image(img, 512, sdxl_resize=True)
    pixel_values = extractor(img, return_tensors="pt").pixel_values.to("cpu")
    with torch.no_grad():
        outputs = model(pixel_values)
    # First, upsample the low-resolution logits back to the input size.
    logits = nn.functional.interpolate(outputs.logits.detach().cpu(),
                                       size=img.size[::-1],  # (height, width)
                                       mode='bilinear',
                                       align_corners=False)
    # Second, apply argmax on the class dimension to get per-pixel labels.
    seg = logits.argmax(dim=1)[0]
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
    palette = np.array(ade_palette())
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    # Blend the original image with the colour mask. The palette is RGB and
    # PIL expects RGB, so no BGR channel flip is needed (the flip in the
    # original snippet was a leftover from an OpenCV-based example).
    img = np.array(img) * 0.5 + color_seg * 0.5
    img = img.astype(np.uint8)
    img = Image.fromarray(img)
    return img
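
# Optional helper (a hedged sketch, not part of the original app): SegFormer
# checkpoints expose an id2label mapping in their config, so the classes that
# appear in a prediction could be listed by name. `seg` is assumed to be the
# integer label map computed inside run().
def labels_in(seg):
    """Return the sorted class names present in a label-map tensor."""
    return sorted({model.config.id2label[int(i)] for i in seg.unique()})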


# Create the Gradio interface.
iface = gr.Interface(
    fn=run,
    inputs=gr.Image(label="Input image", type="pil"),
    examples=["1.jpg", "2.jpg"],
    outputs=gr.Image(label="Output image with predicted segmentation mask", type="pil"),
    title="Image Segmentation with Segments-Sidewalk-SegFormer-B0",
    description="Upload an image, and this app will perform semantic segmentation and display the result.",
)

# Launch the app.
iface.launch(debug=True)