Tejeshwar commited on
Commit
dae7c61
Β·
verified Β·
1 Parent(s): 7446045

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -38
app.py CHANGED
@@ -1,38 +1,48 @@
1
- import gradio as gr
2
- import torch
3
- import cv2
4
- import numpy as np
5
- from PIL import Image
6
- from pathlib import Path
7
- import subprocess
8
- import os
9
-
10
- # βœ… Install flash_attn from local wheel (must be uploaded to the space)
11
- whl_file = "flash_attn-2.7.3+cu11torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl"
12
- if os.path.exists(whl_file):
13
- subprocess.run(["pip", "install", whl_file])
14
- else:
15
- print("⚠️ .whl file for flash_attn not found. Please upload it to the space.")
16
-
17
- # Load YOLOv12 model (placeholder logic below)
18
- model_path = Path("best.pt")
19
- assert model_path.exists(), "best.pt model not found. Upload your trained YOLOv12 model."
20
-
21
- # Simulate a dummy model detection for placeholder purposes
22
- def detect_damage(image: Image.Image):
23
- img = np.array(image.convert("RGB"))
24
- gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
25
- std_dev = np.std(gray)
26
- is_damaged = std_dev > 25 # Replace with actual model logic
27
- return {"is_damaged": is_damaged}
28
-
29
- # Gradio Interface
30
- demo = gr.Interface(
31
- fn=detect_damage,
32
- inputs=gr.Image(type="pil"),
33
- outputs=gr.JSON(label="Detection Result"),
34
- title="🍎 Apple Damage Detector",
35
- description="Upload an image of an apple to detect whether it is damaged using a YOLOv12 model."
36
- )
37
-
38
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pathlib import Path
7
+ import subprocess
8
+ import os
9
+ from ultralytics import YOLO
10
+
11
+ # βœ… Install flash_attn from local wheel (must be uploaded to the space)
12
+ whl_file = "flash_attn-2.7.3+cu11torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl"
13
+ if os.path.exists(whl_file):
14
+ subprocess.run(["pip", "install", whl_file])
15
+ else:
16
+ print("⚠️ .whl file for flash_attn not found. Please upload it to the space.")
17
+
18
+ # βœ… Load YOLOv12 model
19
+ model_path = Path("runs/detect/train7/weights/best.pt")
20
+ assert model_path.exists(), "Model not found at runs/detect/train7/weights/best.pt"
21
+ model = YOLO(str(model_path))
22
+
23
+ # βœ… Detection function
24
+ def detect_damage(image: Image.Image):
25
+ results = model(image)
26
+
27
+ is_damaged = False
28
+ for result in results:
29
+ for box in result.boxes:
30
+ class_id = int(box.cls[0])
31
+ label = model.names[class_id]
32
+ print("Detected:", label)
33
+ if "apple_damaged" in label.lower():
34
+ is_damaged = True
35
+ break
36
+
37
+ return {"is_damaged": is_damaged}
38
+
39
+ # βœ… Gradio UI
40
+ demo = gr.Interface(
41
+ fn=detect_damage,
42
+ inputs=gr.Image(type="pil"),
43
+ outputs=gr.JSON(label="Detection Result"),
44
+ title="🍎 Apple Damage Detector",
45
+ description="Upload an image of an apple to detect whether it is damaged using your trained YOLOv12 model."
46
+ )
47
+
48
+ demo.launch()