Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,13 +5,11 @@ from PIL import Image, ImageDraw, ImageFont
|
|
5 |
# Load YOLOS object detection model
|
6 |
detector = pipeline("object-detection", model="hustvl/yolos-small")
|
7 |
|
8 |
-
# Confidence threshold
|
9 |
CONFIDENCE_THRESHOLD = 0.5
|
10 |
|
11 |
-
# Color palette
|
12 |
-
COLORS = [
|
13 |
-
"red", "blue", "green", "orange", "purple", "yellow", "cyan", "magenta"
|
14 |
-
]
|
15 |
|
16 |
def get_color_for_label(label):
|
17 |
return COLORS[hash(label) % len(COLORS)]
|
@@ -19,17 +17,17 @@ def get_color_for_label(label):
|
|
19 |
def detect_and_draw(image):
|
20 |
results = detector(image)
|
21 |
|
22 |
-
#
|
23 |
image = image.convert("RGB")
|
24 |
draw = ImageDraw.Draw(image)
|
25 |
-
|
26 |
try:
|
27 |
font = ImageFont.truetype("arial.ttf", 16)
|
28 |
except:
|
29 |
font = ImageFont.load_default()
|
30 |
|
31 |
-
|
32 |
-
|
33 |
for obj in results:
|
34 |
score = obj["score"]
|
35 |
if score < CONFIDENCE_THRESHOLD:
|
@@ -45,6 +43,7 @@ def detect_and_draw(image):
|
|
45 |
outline=color,
|
46 |
width=3,
|
47 |
)
|
|
|
48 |
# Draw label
|
49 |
draw.text(
|
50 |
(box["xmin"] + 5, box["ymin"] + 5),
|
@@ -53,17 +52,19 @@ def detect_and_draw(image):
|
|
53 |
font=font
|
54 |
)
|
55 |
|
56 |
-
#
|
57 |
-
|
|
|
58 |
|
59 |
-
return image,
|
60 |
|
|
|
61 |
demo = gr.Interface(
|
62 |
fn=detect_and_draw,
|
63 |
inputs=gr.Image(type="pil"),
|
64 |
outputs=gr.AnnotatedImage(),
|
65 |
title="YOLOS Object Detection",
|
66 |
-
description=f"Upload an image
|
67 |
)
|
68 |
|
69 |
demo.launch()
|
|
|
5 |
# Load YOLOS object detection model
|
6 |
detector = pipeline("object-detection", model="hustvl/yolos-small")
|
7 |
|
8 |
+
# Confidence threshold
|
9 |
CONFIDENCE_THRESHOLD = 0.5
|
10 |
|
11 |
+
# Color palette
|
12 |
+
COLORS = ["red", "blue", "green", "orange", "purple", "yellow", "cyan", "magenta"]
|
|
|
|
|
13 |
|
14 |
def get_color_for_label(label):
|
15 |
return COLORS[hash(label) % len(COLORS)]
|
|
|
17 |
def detect_and_draw(image):
|
18 |
results = detector(image)
|
19 |
|
20 |
+
# Convert to RGB for drawing
|
21 |
image = image.convert("RGB")
|
22 |
draw = ImageDraw.Draw(image)
|
23 |
+
|
24 |
try:
|
25 |
font = ImageFont.truetype("arial.ttf", 16)
|
26 |
except:
|
27 |
font = ImageFont.load_default()
|
28 |
|
29 |
+
annotations = []
|
30 |
+
|
31 |
for obj in results:
|
32 |
score = obj["score"]
|
33 |
if score < CONFIDENCE_THRESHOLD:
|
|
|
43 |
outline=color,
|
44 |
width=3,
|
45 |
)
|
46 |
+
|
47 |
# Draw label
|
48 |
draw.text(
|
49 |
(box["xmin"] + 5, box["ymin"] + 5),
|
|
|
52 |
font=font
|
53 |
)
|
54 |
|
55 |
+
# AnnotatedImage expects (box_tuple, label)
|
56 |
+
box_coords = (box["xmin"], box["ymin"], box["xmax"], box["ymax"])
|
57 |
+
annotations.append((box_coords, label))
|
58 |
|
59 |
+
return image, annotations
|
60 |
|
61 |
+
# Gradio interface
|
62 |
demo = gr.Interface(
|
63 |
fn=detect_and_draw,
|
64 |
inputs=gr.Image(type="pil"),
|
65 |
outputs=gr.AnnotatedImage(),
|
66 |
title="YOLOS Object Detection",
|
67 |
+
description=f"Upload an image to detect objects using the YOLOS model. Only objects with confidence > {CONFIDENCE_THRESHOLD} are shown.",
|
68 |
)
|
69 |
|
70 |
demo.launch()
|