navjotk commited on
Commit
2054fb5
·
verified ·
1 Parent(s): ae02cd9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -11
app.py CHANGED
@@ -1,25 +1,69 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
- from PIL import Image
4
 
5
- # Load YOLOS pipeline
6
  detector = pipeline("object-detection", model="hustvl/yolos-small")
7
 
8
- def detect_objects(img):
9
- results = detector(img)
10
- boxes = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  for obj in results:
12
- label = f"{obj['label']} ({obj['score']:.2f})"
13
- box = obj['box']
14
- boxes.append((box["xmin"], box["ymin"], box["xmax"], box["ymax"], label))
15
- return img, boxes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  demo = gr.Interface(
18
- fn=detect_objects,
19
  inputs=gr.Image(type="pil"),
20
  outputs=gr.AnnotatedImage(),
21
  title="YOLOS Object Detection",
22
- description="Upload an image and detect objects using the YOLOS Transformer model.",
23
  )
24
 
25
  demo.launch()
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ from PIL import Image, ImageDraw, ImageFont
4
 
5
+ # Load YOLOS object detection model
6
  detector = pipeline("object-detection", model="hustvl/yolos-small")
7
 
8
+ # Confidence threshold (you can adjust this)
9
+ CONFIDENCE_THRESHOLD = 0.5
10
+
11
+ # Color palette for classes
12
+ COLORS = [
13
+ "red", "blue", "green", "orange", "purple", "yellow", "cyan", "magenta"
14
+ ]
15
+
16
+ def get_color_for_label(label):
17
+ return COLORS[hash(label) % len(COLORS)]
18
+
19
+ def detect_and_draw(image):
20
+ results = detector(image)
21
+
22
+ # Draw boxes on image
23
+ image = image.convert("RGB")
24
+ draw = ImageDraw.Draw(image)
25
+
26
+ try:
27
+ font = ImageFont.truetype("arial.ttf", 16)
28
+ except:
29
+ font = ImageFont.load_default()
30
+
31
+ filtered_results = []
32
+
33
  for obj in results:
34
+ score = obj["score"]
35
+ if score < CONFIDENCE_THRESHOLD:
36
+ continue
37
+
38
+ label = f"{obj['label']} ({score:.2f})"
39
+ box = obj["box"]
40
+ color = get_color_for_label(obj["label"])
41
+
42
+ # Draw box
43
+ draw.rectangle(
44
+ [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])],
45
+ outline=color,
46
+ width=3,
47
+ )
48
+ # Draw label
49
+ draw.text(
50
+ (box["xmin"] + 5, box["ymin"] + 5),
51
+ label,
52
+ fill=color,
53
+ font=font
54
+ )
55
+
56
+ # Also return data for Gradio AnnotatedImage
57
+ filtered_results.append((box["xmin"], box["ymin"], box["xmax"], box["ymax"], label))
58
+
59
+ return image, filtered_results
60
 
61
  demo = gr.Interface(
62
+ fn=detect_and_draw,
63
  inputs=gr.Image(type="pil"),
64
  outputs=gr.AnnotatedImage(),
65
  title="YOLOS Object Detection",
66
+ description=f"Upload an image. Boxes show objects detected by the YOLOS Transformer model (confidence > {CONFIDENCE_THRESHOLD}).",
67
  )
68
 
69
  demo.launch()