Tejeshwar commited on
Commit
21922e9
Β·
verified Β·
1 Parent(s): 044738a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -20
app.py CHANGED
@@ -1,48 +1,37 @@
1
  import gradio as gr
2
- import torch
3
- import cv2
4
  import numpy as np
5
  from PIL import Image
6
- from pathlib import Path
7
- import subprocess
8
- import os
9
  from ultralytics import YOLO
 
10
 
11
- # βœ… Install flash_attn from local wheel (must be uploaded to the space)
12
- whl_file = "flash_attn-2.7.3+cu11torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl"
13
- if os.path.exists(whl_file):
14
- subprocess.run(["pip", "install", whl_file])
15
- else:
16
- print("⚠️ .whl file for flash_attn not found. Please upload it to the space.")
17
-
18
- # βœ… Load YOLOv12 model
19
- model_path = Path("runs/detect/train7/weights/best.pt")
20
- assert model_path.exists(), "Model not found at runs/detect/train7/weights/best.pt"
21
  model = YOLO(str(model_path))
22
 
23
- # βœ… Detection function
24
  def detect_damage(image: Image.Image):
25
  results = model(image)
26
-
27
  is_damaged = False
28
  for result in results:
29
  for box in result.boxes:
30
  class_id = int(box.cls[0])
31
  label = model.names[class_id]
32
- print("Detected:", label)
33
  if "apple_damaged" in label.lower():
34
  is_damaged = True
35
  break
36
 
37
  return {"is_damaged": is_damaged}
38
 
39
- # βœ… Gradio UI
40
  demo = gr.Interface(
41
  fn=detect_damage,
42
  inputs=gr.Image(type="pil"),
43
  outputs=gr.JSON(label="Detection Result"),
44
  title="🍎 Apple Damage Detector",
45
- description="Upload an image of an apple to detect whether it is damaged using your trained YOLOv12 model."
46
  )
47
 
48
  demo.launch()
 
1
  import gradio as gr
 
 
2
  import numpy as np
3
  from PIL import Image
 
 
 
4
  from ultralytics import YOLO
5
+ from pathlib import Path
6
 
7
+ # βœ… Load YOLOv12 model (make sure this path matches your repo)
8
+ model_path = Path("best.pt") # Or change to "weights/best.pt" if uploaded in subfolder
9
+ assert model_path.exists(), f"Model not found at {model_path}"
 
 
 
 
 
 
 
10
  model = YOLO(str(model_path))
11
 
12
+ # βœ… Damage detection using YOLOv12
13
  def detect_damage(image: Image.Image):
14
  results = model(image)
15
+
16
  is_damaged = False
17
  for result in results:
18
  for box in result.boxes:
19
  class_id = int(box.cls[0])
20
  label = model.names[class_id]
21
+ print("Detected label:", label)
22
  if "apple_damaged" in label.lower():
23
  is_damaged = True
24
  break
25
 
26
  return {"is_damaged": is_damaged}
27
 
28
+ # βœ… Gradio interface
29
  demo = gr.Interface(
30
  fn=detect_damage,
31
  inputs=gr.Image(type="pil"),
32
  outputs=gr.JSON(label="Detection Result"),
33
  title="🍎 Apple Damage Detector",
34
+ description="Upload an apple image to detect whether it is damaged using a YOLOv12 model."
35
  )
36
 
37
  demo.launch()