TheOneReborn's picture
fix: remove deprecated method use
0342b70
raw
history blame
2.23 kB
import gradio
import numpy
from matplotlib import cm
from pathlib import Path
from PIL import Image
from fastai.vision.all import load_learner, PILImage, PILMask
# Both locations are resolved relative to the current working directory:
# the exported learner lives under ./models and the example images offered
# in the UI under ./test.
MODEL_PATH = Path('models')
TEST_IMAGES_PATH = Path('test')
def preprocess_mask(file_name, masks_dir='/kaggle/input/car-segmentation/car-segmentation/masks'):
    """Load the mask matching ``file_name`` as a grayscale mask without transparency.

    The mask file is looked up by the base name of ``file_name`` inside
    ``masks_dir``. Fully transparent pixels are forced to black, the alpha
    channel is dropped, and the result is converted to single-channel 'L'
    before being wrapped in a fastai ``PILMask``.

    Args:
        file_name: Path (or str) of the source image. Only its final path
            component is used, so the mask must share the image's file name.
        masks_dir: Directory that holds the mask files. Defaults to the
            Kaggle dataset location.

    Returns:
        PILMask: the grayscale mask.
    """
    mask_path = Path(masks_dir) / Path(file_name).name
    # Use a context manager so the underlying file handle is released.
    # convert('RGBA') also covers palette ('P') images with transparency.
    with Image.open(mask_path) as mask:
        rgba = numpy.asarray(mask.convert('RGBA'))
    # Zero the colour of every pixel whose alpha is 0. The alpha channel itself
    # can be discarded because the conversion to 'L' below ignores it anyway.
    rgb = numpy.where(rgba[..., 3:] > 0, rgba[..., :3], 0).astype(numpy.uint8)
    return PILMask.create(Image.fromarray(rgb).convert('L'))
# The exported fastai learner used for inference. load_learner deserializes a
# pickle, so this file must come from a trusted source.
# NOTE(review): the pickle presumably references preprocess_mask (the
# training-time label function), which would explain why that function is
# defined before this line; confirm before reordering the module.
LEARNER = load_learner(MODEL_PATH.joinpath('car-segmentation_v1.pkl'))
def _colorize_mask(class_mask):
    """Map an integer class mask to an (H, W, 3) float RGB array in [0, 1] using 'jet'."""
    peak = class_mask.max()
    # Class ids are scaled by the largest id present in *this* mask.
    # NOTE(review): this means a given class can get different colours across
    # images. An all-background mask (peak 0) would otherwise be 0/0 = NaN,
    # which the colormap draws with its "bad" colour instead of the background
    # colour, so divide by at least 1.
    normalized = class_mask / max(peak, 1)
    return cm.colormaps['jet'](normalized)[:, :, :3]  # drop the alpha channel


def segment_image(image):
    """Segment ``image`` with LEARNER and return a colour overlay of the prediction.

    Args:
        image: Input image as delivered by the Gradio ``Image(type='pil')``
            component (anything ``PILImage.create`` accepts).

    Returns:
        numpy.ndarray: (H, W, 3) uint8 array at the input's original size,
        the input weighted 0.7 blended with a 'jet'-coloured class mask
        weighted 0.3.
    """
    image = PILImage.create(image)
    prediction, _, _ = LEARNER.predict(image)

    # The prediction may not match the input resolution, so resize it back.
    # NEAREST keeps the class ids intact (no interpolation between labels).
    class_mask = numpy.asarray(prediction, dtype=numpy.uint8)
    class_mask = numpy.asarray(
        Image.fromarray(class_mask).resize(image.size, Image.NEAREST)
    )

    colored_mask = _colorize_mask(class_mask)

    # convert('RGB') guarantees three channels, so grayscale or RGBA uploads
    # still broadcast against the 3-channel colour mask.
    image_array = numpy.asarray(image.convert('RGB'), dtype=numpy.float32) / 255.0

    overlay = (image_array * 0.7) + (colored_mask * 0.3)
    # Both inputs lie in [0, 1], so scaling by 255 stays within uint8 range.
    return (overlay * 255).astype(numpy.uint8)
# Gradio UI: one PIL image in, the blended overlay (numpy array) out.
# Examples are the regular, non-hidden files in TEST_IMAGES_PATH. They are
# sorted so the gallery order does not depend on the filesystem's listing
# order, and sub-directories or dotfiles (e.g. .gitkeep) are skipped because
# they are not valid image examples.
demo = gradio.Interface(
    segment_image,
    inputs=gradio.Image(type='pil'),
    outputs=gradio.Image(type='numpy'),
    examples=sorted(
        str(image)
        for image in TEST_IMAGES_PATH.iterdir()
        if image.is_file() and not image.name.startswith('.')
    ),
)
demo.launch()