dalybuilds commited on
Commit
b5fe3de
·
verified ·
1 Parent(s): f86dd3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -25
app.py CHANGED
@@ -1,47 +1,39 @@
1
  import gradio as gr
2
- from transformers import DetrFeatureExtractor, DetrForObjectDetection, DetrImageProcessor # Use DetrImageProcessor
3
  import torch
4
- from PIL import Image
5
 
6
- # Load the model and processor (suppressing warnings and using DetrImageProcessor)
7
  try:
8
- feature_extractor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50") # Correct image processor
9
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", ignore_mismatched_sizes=True) # Suppress warnings
10
- except Exception as e: # Catch potential errors during model loading
11
- print(f"Error loading model: {e}") # So you can see in the logs what’s happening
12
- raise e
13
-
14
 
15
  def predict(image):
16
-
17
  inputs = feature_extractor(images=image, return_tensors="pt")
18
  outputs = model(**inputs)
19
 
20
-
21
  target_sizes = torch.tensor([image.size[::-1]])
22
- results = feature_extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.7)[0] # Simplify this line
23
 
24
- potholes = []
 
25
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
26
- box = [round(i, 2) for i in box.tolist()]
27
- potholes.append({
28
- "box": box,
29
- "score": round(score.item(), 3),
30
- "label": model.config.id2label[label.item()]
31
- })
32
-
33
- return potholes
34
-
35
-
36
 
 
37
 
 
38
  iface = gr.Interface(
39
  fn=predict,
40
  inputs=gr.Image(type="pil"),
41
- outputs=gr.JSON(label="Detected Potholes"), # Correct for Gradio 3.0+
42
  title="Pothole Detection POC",
43
  description="Upload an image to detect potholes."
44
  )
45
 
46
-
47
  iface.launch()
 
1
  import gradio as gr
2
+ from transformers import DetrImageProcessor, DetrForObjectDetection
3
  import torch
4
+ from PIL import Image, ImageDraw
5
 
6
+ # Model loading (same as before - with error handling)
7
  try:
8
+ feature_extractor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
9
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", ignore_mismatched_sizes=True)
10
+ except Exception as e: # Error handling during model loading
11
+ print(f"Error loading model: {e}") # Log the error so you can see in HF logs
12
+ raise e # Re-raise for Space to report it
 
13
 
14
  def predict(image):
 
15
  inputs = feature_extractor(images=image, return_tensors="pt")
16
  outputs = model(**inputs)
17
 
 
18
  target_sizes = torch.tensor([image.size[::-1]])
19
+ results = feature_extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.7)[0]
20
 
21
+ # Draw bounding boxes on the image
22
+ draw = ImageDraw.Draw(image) # Create a drawing object
23
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
24
+ box = [round(i) for i in box.tolist()] # Convert to integers for drawing
25
+ draw.rectangle(box, outline="red", width=2) # Outline
26
+ draw.text((box[0], box[1]), model.config.id2label[label.item()], fill="red") # Add a label
 
 
 
 
 
 
 
27
 
28
+ return image # Return the image with the bounding boxes drawn
29
 
30
+ # Gradio Interface (updated output type)
31
  iface = gr.Interface(
32
  fn=predict,
33
  inputs=gr.Image(type="pil"),
34
+ outputs=gr.Image(type="pil", label="Detected Potholes (Image)"), # Updated
35
  title="Pothole Detection POC",
36
  description="Upload an image to detect potholes."
37
  )
38
 
 
39
  iface.launch()