import numpy as np
import gradio as gr
from PIL import Image
import torch
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
model_checkpoint = "apple/deeplabv3-mobilevit-small"
feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint)
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()
palette = np.array(
    [
        [  0,   0,   0], [192,   0,   0], [  0, 192,   0], [192, 192,   0],
        [  0,   0, 192], [192,   0, 192], [  0, 192, 192], [192, 192, 192],
        [128,   0,   0], [255,   0,   0], [128, 192,   0], [255, 192,   0],
        [128,   0, 192], [255,   0, 192], [128, 192, 192], [255, 192, 192],
        [  0, 128,   0], [192, 128,   0], [  0, 255,   0], [192, 255,   0],
        [  0, 128, 192]
    ],
    dtype=np.uint8)
labels = [
"background",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor",
]
# Draw the labels as colored chips. Dark colors use white text, light colors use black text.
inverted = [ 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20 ]
labels_colored = []
for i in range(len(labels)):
    r, g, b = palette[i]
    label = labels[i]
    color = "white" if i in inverted else "black"
    text = "<span style='background-color: rgb(%d, %d, %d); color: %s; padding: 2px 4px;'>%s</span>" % (r, g, b, color, label)
    labels_colored.append(text)
labels_text = " ".join(labels_colored)
title = "Semantic Segmentation with MobileViT and DeepLabV3"
description = """
The input image is resized and center cropped to 512Γ512 pixels. The segmentation output is 32Γ32 pixels.<br>
This model has been trained on <a href="http://host.robots.ox.ac.uk/pascal/VOC/">Pascal VOC</a>.
The classes are:
""" + labels_text + "</p>"
article = """
<div style='margin:20px auto;'>
<p>Sources:</p>
<p><a href="https://arxiv.org/abs/2110.02178">MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer</a></p>
<p>Original pretrained weights from <a href="https://github.com/apple/ml-cvnets">this GitHub repo</a></p>
<p>Example images from <a href="https://huggingface.co/datasets/mishig/sample_images">this dataset</a></p>
</div>
"""
examples = [
["cat-3.jpg"],
["construction-site.jpg"],
["dog-cat.jpg"],
["football-match.jpg"],
]
def predict(image):
    with torch.no_grad():
        inputs = feature_extractor(image, return_tensors="pt")
        outputs = model(**inputs)

    # Get the preprocessed image back. The pixel values don't need to be
    # unnormalized for this particular model, but the feature extractor flips
    # the channel order to BGR, so flip it back to RGB here.
    resized = (inputs["pixel_values"].numpy().squeeze().transpose(1, 2, 0)[..., ::-1] * 255).astype(np.uint8)

    # Class predictions for each pixel.
    classes = outputs.logits.argmax(1).squeeze().numpy().astype(np.uint8)
    # Super slow method but it works... should probably improve this.
    colored = np.zeros((classes.shape[0], classes.shape[1], 3), dtype=np.uint8)
    for y in range(classes.shape[0]):
        for x in range(classes.shape[1]):
            colored[y, x] = palette[classes[y, x]]
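    # A vectorized alternative (a sketch, assumed equivalent to the loop above):
    # NumPy fancy indexing looks up the palette color for every pixel at once.
    #colored = palette[classes]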
    # Resize predictions to input size (not original size).
    colored = Image.fromarray(colored)
    colored = colored.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST)

    # Keep everything that is not background.
    mask = (classes != 0) * 255
    mask = Image.fromarray(mask.astype(np.uint8)).convert("RGB")
    mask = mask.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST)
    # Blend with the input image.
    resized = Image.fromarray(resized)
    highlighted = Image.blend(resized, mask, 0.4)

    #colored = colored.resize((256, 256), resample=Image.Resampling.BICUBIC)
    #highlighted = highlighted.resize((256, 256), resample=Image.Resampling.BICUBIC)

    return colored, highlighted
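# Quick way to test predict() outside the Gradio UI (a sketch; assumes one of the
# example images listed below, e.g. "cat-3.jpg", sits next to this script):
#   img = Image.open("cat-3.jpg").convert("RGB")
#   classes_img, overlay_img = predict(img)
#   classes_img.save("classes.png")
#   overlay_img.save("overlay.png")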
gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(label="Upload image"),
    outputs=[gr.outputs.Image(label="Classes"), gr.outputs.Image(label="Overlay")],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()