# Author: TheOneReborn
# Last commit: 2b8f512 — "fix: normalize image for output"
import gradio
import numpy
from pathlib import Path
from PIL import Image
from fastai.vision.all import load_learner, PILImage, PILMask
# Both directories are resolved relative to the current working directory;
# Path('.') / 'x' collapses to Path('x'), so the '.' prefix is unnecessary.
MODEL_PATH = Path('models')  # holds the exported fastai learner file
TEST_IMAGES_PATH = Path('test')  # images offered as examples in the UI
def preprocess_mask(
    file_name,
    masks_dir=Path('/kaggle/input/car-segmentation/car-segmentation/masks'),
):
    """Ensures masks are in grayscale format and removes transparency.

    Loads the mask whose file name matches ``file_name`` from ``masks_dir``,
    paints every fully transparent pixel black, makes all pixels opaque, and
    converts the result to single-channel grayscale ('L').

    Args:
        file_name: Path of an input image; only its ``.name`` is used to locate
            the mask with the same file name inside ``masks_dir``.
        masks_dir: Directory containing the masks. Defaults to the Kaggle path
            that was previously hard-coded here.

    Returns:
        A fastai ``PILMask`` built from the grayscale mask.
    """
    mask_path = Path(masks_dir) / file_name.name
    # The context manager closes the underlying file handle; convert() loads
    # the pixel data first, so the array stays valid after the file closes.
    with Image.open(mask_path) as mask:
        # Route every mode (palette 'P' included) through RGBA so that any
        # transparency is represented explicitly in the alpha channel.
        rgba = numpy.array(mask.convert('RGBA'))
    # Replace fully transparent pixels with black (or another valid label),
    # then ensure full opacity everywhere - same result as the per-pixel loop.
    rgba[rgba[..., 3] == 0, :3] = 0
    rgba[..., 3] = 255
    # Convert to grayscale after handling transparency
    return PILMask.create(Image.fromarray(rgba).convert('L'))
# Load the exported learner once at import time; segment_image reuses this
# global for every request.
# NOTE(review): load_learner deserialises a pickled file (per fastai docs), so
# only ship model files from a trusted source. preprocess_mask is presumably
# referenced by the pickled learner, which is why it is defined above this
# line - confirm before reordering.
LEARNER = load_learner(MODEL_PATH.joinpath('car-segmentation_v1.pkl'))
def _scale_class_indices(class_map):
    """Stretch integer class indices so the largest one maps to 255 (uint8)."""
    peak = class_map.max()
    if peak == 0:
        # Only class 0 was predicted: dividing by the max would be 0/0, giving
        # NaNs that turn into garbage when cast to uint8. Return all zeros.
        return numpy.zeros_like(class_map, dtype=numpy.uint8)
    # Same arithmetic as before: normalise to [0, 1], then scale to 0-255.
    return ((class_map / peak) * 255).astype(numpy.uint8)


def segment_image(image):
    """Segment ``image`` with LEARNER and return a grayscale visualisation.

    Args:
        image: Any input ``PILImage.create`` accepts (Gradio passes a PIL image).

    Returns:
        A uint8 numpy array of the predicted class map, with class indices
        stretched to 0-255 for display (all zeros if only class 0 is predicted).
    """
    image = PILImage.create(image)
    prediction, _, _ = LEARNER.predict(image)
    prediction_array = numpy.array(prediction, dtype=numpy.uint8)
    # Normalize class indices to 0-255 for proper visualization
    return _scale_class_indices(prediction_array)
# Example inputs for the UI: regular, non-hidden files in the test directory.
# Path.iterdir() yields entries in arbitrary order and also returns
# sub-directories, so filter and sort to get a stable, valid example list.
example_images = sorted(
    str(path)
    for path in TEST_IMAGES_PATH.iterdir()
    if path.is_file() and not path.name.startswith('.')
)

demo = gradio.Interface(
    segment_image,
    inputs=gradio.Image(type='pil'),
    outputs=gradio.Image(type='numpy'),
    examples=example_images,
)
demo.launch()