# Khunanya's picture
# Update app.py
# 13fb353 verified
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import numpy as np
# Load the SegFormer image processor and model; both come from the same
# checkpoint, so the identifier is spelled out once.
MODEL_ID = "nvidia/segformer-b0-finetuned-ade-512-512"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)
def segment(image):
    """Segment an image with SegFormer and return a color-coded class mask.

    Args:
        image: PIL image to segment, or ``None`` when no image was supplied
            (e.g. the UI input is empty).

    Returns:
        A PIL image with the same width and height as ``image`` where every
        pixel is colored according to its predicted class index, or ``None``
        when ``image`` is ``None``.
    """
    # Nothing to segment: return an empty result instead of crashing inside
    # the processor.
    if image is None:
        return None

    # NOTE(review): the checkpoint presumably expects 3-channel RGB input;
    # convert so RGBA / grayscale / palette images are handled on direct calls.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Per-class scores, (1, num_labels, h, w) at the model's output resolution.
    logits = outputs.logits

    # Resize the scores back to the input size. PIL reports (W, H), while
    # interpolate expects (H, W), hence the reversal.
    upsampled_logits = torch.nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )[0]

    # Highest-scoring class per pixel -> (H, W) array of class indices.
    predicted = upsampled_logits.argmax(0).numpy()

    # Build the segmentation mask image.
    colored_mask = Image.fromarray(segmentation_to_color(predicted))
    return colored_mask
# Convert a class-index mask to colors (simple, deterministic palette).
def segmentation_to_color(segmentation):
    """Map an array of class indices to RGB colors.

    The palette is deterministic: a class index always maps to the same color,
    regardless of which other classes are present and across calls. Colors are
    built by spreading the bits of the class index over the high bits of the
    R, G and B channels (the scheme used by the PASCAL VOC colormap), so small
    indices get well-separated colors; index 0 is black.

    Args:
        segmentation: integer array of non-negative class indices, any shape.

    Returns:
        ``uint8`` array of shape ``segmentation.shape + (3,)`` holding the RGB
        color of each element. An empty input yields an empty color array.
    """
    segmentation = np.asarray(segmentation)

    # np.max raises on empty input; there is nothing to color in that case.
    if segmentation.size == 0:
        return np.zeros(segmentation.shape + (3,), dtype=np.uint8)

    num_classes = int(segmentation.max()) + 1
    remaining = np.arange(num_classes, dtype=np.uint32)
    palette = np.zeros((num_classes, 3), dtype=np.uint8)

    # Consume the index three bits at a time (one bit per channel), filling
    # each channel from its most significant bit downwards.
    for shift in range(7, -1, -1):
        for channel in range(3):
            bit = (remaining >> channel) & 1
            palette[:, channel] |= (bit << shift).astype(np.uint8)
        remaining >>= 3

    # Fancy indexing looks up one palette row per mask element.
    return palette[segmentation]
# Gradio UI: one image in, the colored segmentation mask out.
demo = gr.Interface(
    fn=segment,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=gr.Image(type="pil", label="Segmentation Mask"),
    title="Semantic Segmentation with SegFormer",
    description="ใช้โมเดล NVIDIA SegFormer สำหรับ Semantic Segmentation",
)
demo.launch()