# TheOneReborn's picture
# fix: typo in method
# 8103efe
# raw
# history blame contribute delete
# 2.45 kB
from pathlib import Path

import gradio
import matplotlib
import numpy
from fastai.vision.all import load_learner, PILImage, PILMask
from matplotlib import cm
from PIL import Image
# Folder (relative to the working directory) that holds the exported learner.
MODEL_PATH = Path('models')
# Folder (relative to the working directory) whose entries are offered as
# example inputs in the Gradio interface.
TEST_IMAGES_PATH = Path('test')
def preprocess_mask(file_name):
    """Load the mask that matches *file_name* as an opaque grayscale mask.

    The mask is looked up by base name inside the car-segmentation masks
    directory. Pixels whose alpha is zero become black; every other pixel
    keeps its colour with alpha forced to 255. The result is converted to
    single-channel ``'L'`` mode before being wrapped.

    Args:
        file_name: Path-like object; only its ``.name`` attribute is used.

    Returns:
        The mask wrapped with ``PILMask.create``.
    """
    # The previous literal read '/kaggle/inumpyut/...', apparently the result
    # of an np -> numpy search-and-replace hitting the word "input".
    masks_dir = Path('/kaggle/input/car-segmentation/car-segmentation/masks')

    # A single convert('RGBA') covers palette ('P') images and every other
    # mode; the context manager releases the file handle once the pixel data
    # has been copied out.
    with Image.open(masks_dir / file_name.name) as source:
        mask = source.convert('RGBA')

    # Drop transparency: fully transparent pixels turn black, the rest keep
    # their colour but become fully opaque.
    mask.putdata([
        (r, g, b, 255) if a > 0 else (0, 0, 0, 255)
        for r, g, b, a in mask.getdata()
    ])
    return PILMask.create(mask.convert('L'))
# Load the exported learner once at import time so every call to
# segment_image reuses the same model instance.
# NOTE(review): presumably the exported learner refers to preprocess_mask by
# name, which would be why that function is defined before this line - confirm.
LEARNER = load_learner(MODEL_PATH.joinpath('car-segmentation_v1.pkl'))
def segment_image(image):
    """Overlay the model's segmentation mask on top of *image*.

    The image is resized to 400x400 for the model, the predicted mask is
    scaled back to the original size, coloured with the 'jet' colormap and
    blended with the input (70% image, 30% mask).

    Args:
        image: PIL image to segment, or ``None`` when no image was supplied.

    Returns:
        A ``uint8`` NumPy array of shape (height, width, 3) holding the
        blended overlay, or ``None`` when *image* is ``None``.
    """
    if image is None:
        return None

    # The blend below adds an (H, W, 3) colour mask to the image, so make
    # sure grayscale or RGBA inputs are three-channel as well.
    rgb_image = image.convert('RGB')
    original_size = rgb_image.size  # (width, height)

    # The model is fed a fixed 400x400 input.
    model_input = PILImage.create(rgb_image.resize((400, 400)))
    prediction, _, _ = LEARNER.predict(model_input)
    class_map = numpy.asarray(prediction, dtype=numpy.uint8)

    # Nearest-neighbour resampling keeps mask values discrete while scaling
    # the mask back to the size of the original image.
    class_map = numpy.array(
        Image.fromarray(class_map).resize(original_size, Image.NEAREST)
    )

    # Normalise to [0, 1] for the colormap. When nothing was detected the
    # maximum is 0; skip the division to avoid 0/0 -> NaN.
    peak = class_map.max()
    if peak > 0:
        normalized = class_map / peak
    else:
        normalized = numpy.zeros(class_map.shape, dtype=numpy.float64)

    # Public colormap registry; the colormap returns RGBA floats in [0, 1],
    # of which only the RGB channels are kept.
    colored_mask = matplotlib.colormaps['jet'](normalized)[:, :, :3]

    image_array = numpy.array(rgb_image).astype(numpy.float32) / 255.0
    overlay = (image_array * 0.7) + (colored_mask * 0.3)
    return (overlay * 255).astype(numpy.uint8)
# Build the web UI: one PIL image in, one NumPy image (the overlay) out.
demo = gradio.Interface(
    segment_image,
    inputs=gradio.Image(type='pil'),
    outputs=gradio.Image(type='numpy'),
    # iterdir() yields entries in arbitrary order and may include
    # sub-directories; sort for a stable example gallery and keep only
    # regular files.
    examples=sorted(
        str(path) for path in TEST_IMAGES_PATH.iterdir() if path.is_file()
    ),
)
demo.launch()