import gradio
import numpy

from pathlib import Path
from PIL import Image

from fastai.vision.all import load_learner, PILImage, PILMask


MODEL_PATH = Path('.') / 'models'
TEST_IMAGES_PATH = Path('.') / 'test'


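# NOTE: `preprocess_mask` is not called anywhere in this app. It is (presumably) the
# labelling function the exported learner was trained with, so it must be defined in
# this module for `load_learner` to unpickle the model; the hard-coded Kaggle path is
# only relevant at training time.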
def preprocess_mask(file_name):
    """Ensures masks are in grayscale format and removes transparency."""
    mask_path = Path(
        '/kaggle/input/car-segmentation/car-segmentation/masks') / file_name.name
    mask = Image.open(mask_path)

    # Ensure an alpha channel is present (palette-based images included) so
    # transparency can be detected and removed
    if mask.mode != 'RGBA':
        mask = mask.convert('RGBA')

    mask_data = mask.getdata()

    # Replace fully transparent pixels with black (or another valid label)
    new_mask_data = [
        # Ensure full opacity in new mask
        (r, g, b, 255) if a > 0 else (0, 0, 0, 255)
        for r, g, b, a in mask_data
    ]

    mask.putdata(new_mask_data)

    # Convert to grayscale after handling transparency
    return PILMask.create(mask.convert('L'))


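# Load the exported fastai learner once at import time so every prediction reuses it.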
LEARNER = load_learner(MODEL_PATH / 'car-segmentation_v1.pkl')


def segment_image(image):
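    """Runs the segmentation model on an input image and returns the mask as a uint8 array."""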
    image = PILImage.create(image)
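    # Learner.predict returns (decoded prediction, decoded tensor, probabilities);
    # for segmentation the first element is a mask of per-pixel class indices.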
    prediction, _, _ = LEARNER.predict(image)

    prediction_array = numpy.array(prediction, dtype=numpy.uint8)

    # Normalize class indices to 0-255 for proper visualization; guard against an
    # all-background mask, whose maximum would otherwise be zero
    max_index = max(int(prediction_array.max()), 1)
    prediction_array = (prediction_array / max_index) * 255

    return prediction_array.astype(numpy.uint8)


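# Gradio UI: upload an image or pick one of the bundled test images; the predicted
# mask comes back as a single-channel uint8 array and is rendered as a grayscale image.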
demo = gradio.Interface(
    segment_image,
    inputs=gradio.Image(type='pil'),
    outputs=gradio.Image(type='numpy'),
    examples=[str(image) for image in TEST_IMAGES_PATH.iterdir()]
)
demo.launch()