navjotk commited on
Commit
23e2414
·
verified ·
1 Parent(s): 2054fb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -5,13 +5,11 @@ from PIL import Image, ImageDraw, ImageFont
5
  # Load YOLOS object detection model
6
  detector = pipeline("object-detection", model="hustvl/yolos-small")
7
 
8
- # Confidence threshold (you can adjust this)
9
  CONFIDENCE_THRESHOLD = 0.5
10
 
11
- # Color palette for classes
12
- COLORS = [
13
- "red", "blue", "green", "orange", "purple", "yellow", "cyan", "magenta"
14
- ]
15
 
16
  def get_color_for_label(label):
17
  return COLORS[hash(label) % len(COLORS)]
@@ -19,17 +17,17 @@ def get_color_for_label(label):
19
  def detect_and_draw(image):
20
  results = detector(image)
21
 
22
- # Draw boxes on image
23
  image = image.convert("RGB")
24
  draw = ImageDraw.Draw(image)
25
-
26
  try:
27
  font = ImageFont.truetype("arial.ttf", 16)
28
  except:
29
  font = ImageFont.load_default()
30
 
31
- filtered_results = []
32
-
33
  for obj in results:
34
  score = obj["score"]
35
  if score < CONFIDENCE_THRESHOLD:
@@ -45,6 +43,7 @@ def detect_and_draw(image):
45
  outline=color,
46
  width=3,
47
  )
 
48
  # Draw label
49
  draw.text(
50
  (box["xmin"] + 5, box["ymin"] + 5),
@@ -53,17 +52,19 @@ def detect_and_draw(image):
53
  font=font
54
  )
55
 
56
- # Also return data for Gradio AnnotatedImage
57
- filtered_results.append((box["xmin"], box["ymin"], box["xmax"], box["ymax"], label))
 
58
 
59
- return image, filtered_results
60
 
 
61
  demo = gr.Interface(
62
  fn=detect_and_draw,
63
  inputs=gr.Image(type="pil"),
64
  outputs=gr.AnnotatedImage(),
65
  title="YOLOS Object Detection",
66
- description=f"Upload an image. Boxes show objects detected by the YOLOS Transformer model (confidence > {CONFIDENCE_THRESHOLD}).",
67
  )
68
 
69
  demo.launch()
 
5
  # Load YOLOS object detection model
6
  detector = pipeline("object-detection", model="hustvl/yolos-small")
7
 
8
+ # Confidence threshold
9
  CONFIDENCE_THRESHOLD = 0.5
10
 
11
+ # Color palette
12
+ COLORS = ["red", "blue", "green", "orange", "purple", "yellow", "cyan", "magenta"]
 
 
13
 
14
  def get_color_for_label(label):
15
  return COLORS[hash(label) % len(COLORS)]
 
17
  def detect_and_draw(image):
18
  results = detector(image)
19
 
20
+ # Convert to RGB for drawing
21
  image = image.convert("RGB")
22
  draw = ImageDraw.Draw(image)
23
+
24
  try:
25
  font = ImageFont.truetype("arial.ttf", 16)
26
  except:
27
  font = ImageFont.load_default()
28
 
29
+ annotations = []
30
+
31
  for obj in results:
32
  score = obj["score"]
33
  if score < CONFIDENCE_THRESHOLD:
 
43
  outline=color,
44
  width=3,
45
  )
46
+
47
  # Draw label
48
  draw.text(
49
  (box["xmin"] + 5, box["ymin"] + 5),
 
52
  font=font
53
  )
54
 
55
+ # AnnotatedImage expects (box_tuple, label)
56
+ box_coords = (box["xmin"], box["ymin"], box["xmax"], box["ymax"])
57
+ annotations.append((box_coords, label))
58
 
59
+ return image, annotations
60
 
61
+ # Gradio interface
62
  demo = gr.Interface(
63
  fn=detect_and_draw,
64
  inputs=gr.Image(type="pil"),
65
  outputs=gr.AnnotatedImage(),
66
  title="YOLOS Object Detection",
67
+ description=f"Upload an image to detect objects using the YOLOS model. Only objects with confidence > {CONFIDENCE_THRESHOLD} are shown.",
68
  )
69
 
70
  demo.launch()