syedfaisalabrar commited on
Commit
4895b5f
·
verified ·
1 Parent(s): 479a272

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -23
app.py CHANGED
@@ -49,7 +49,7 @@ def preprocessing(image):
49
  image = ImageEnhance.Brightness(image).enhance(0.8) # Reduce brightness
50
 
51
  # Convert to tensor without resizing
52
- image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0 # Shape: [C, H, W]
53
 
54
  return image_tensor
55
 
@@ -64,21 +64,13 @@ def imageRotation(image):
64
 
65
  def detect_document(image):
66
  """Detects front and back of the document using YOLO."""
67
- image = np.array(image)
68
  results = modelY(image, conf=0.85)
69
 
70
  detected_classes = set()
71
  labels = []
72
  bounding_boxes = []
73
-
74
- if isinstance(image, np.ndarray):
75
- if image.dtype != np.uint8:
76
- image = (image * 255).clip(0, 255).astype(np.uint8) # Convert float to uint8
77
-
78
- # Ensure correct shape (H, W, C)
79
- if image.shape[0] == 1 and image.shape[1] == 1:
80
- image = np.squeeze(image)
81
-
82
  for result in results:
83
  for box in result.boxes:
84
  x1, y1, x2, y2 = map(int, box.xyxy[0])
@@ -89,8 +81,9 @@ def detect_document(image):
89
  detected_classes.add(class_name)
90
  label = f"{class_name} {conf:.2f}"
91
  labels.append(label)
92
- bounding_boxes.append((x1, y1, x2, y2, class_name, conf)) # Store bounding box with class and confidence
93
 
 
94
  cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
95
  cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
96
 
@@ -99,17 +92,21 @@ def detect_document(image):
99
  if missing_classes:
100
  labels.append(f"Missing: {', '.join(missing_classes)}")
101
 
102
- return Image.fromarray(image), labels, bounding_boxes
103
 
104
 
105
  def crop_image(image, bounding_boxes):
106
- """Crops detected bounding boxes from the image."""
 
107
  cropped_images = {}
108
- image = np.array(image)
109
 
110
  for (x1, y1, x2, y2, class_name, conf) in bounding_boxes:
 
 
111
  cropped = image[y1:y2, x1:x2]
112
- cropped_images[class_name] = Image.fromarray(cropped)
 
 
113
 
114
  return cropped_images
115
 
@@ -136,30 +133,34 @@ def ensure_numpy(image):
136
  # Convert grayscale to 3-channel image
137
  image = np.stack([image] * 3, axis=-1)
138
 
139
- return image
 
140
 
141
  def predict(image):
142
  """Pipeline: Preprocess -> Detect -> Crop -> Vision AI API."""
143
- processed_image = preprocessing(image)
144
- rotated_image = ensure_numpy(processed_image)
145
  detected_image, labels, bounding_boxes = detect_document(rotated_image)
146
 
 
 
 
147
  cropped_images = crop_image(rotated_image, bounding_boxes)
148
 
149
  # Call Vision AI separately for front and back if detected
150
- front_result, back_result = None, None
151
  if "front" in cropped_images:
152
  front_result = vision_ai_api(cropped_images["front"], "front")
153
  if "back" in cropped_images:
154
  back_result = vision_ai_api(cropped_images["back"], "back")
155
 
156
-
157
  api_results = {
158
  "front": front_result,
159
  "back": back_result
160
  }
161
- single_image = cropped_images.get("front") or cropped_images.get("back") or detected_image
162
- return single_image, labels, api_results
 
163
 
164
 
165
  iface = gr.Interface(
 
49
  image = ImageEnhance.Brightness(image).enhance(0.8) # Reduce brightness
50
 
51
  # Convert to tensor without resizing
52
+ # image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0 # Shape: [C, H, W]
53
 
54
  return image_tensor
55
 
 
64
 
65
  def detect_document(image):
66
  """Detects front and back of the document using YOLO."""
67
+ image = ensure_numpy(image) # Ensure valid format
68
  results = modelY(image, conf=0.85)
69
 
70
  detected_classes = set()
71
  labels = []
72
  bounding_boxes = []
73
+
 
 
 
 
 
 
 
 
74
  for result in results:
75
  for box in result.boxes:
76
  x1, y1, x2, y2 = map(int, box.xyxy[0])
 
81
  detected_classes.add(class_name)
82
  label = f"{class_name} {conf:.2f}"
83
  labels.append(label)
84
+ bounding_boxes.append((x1, y1, x2, y2, class_name, conf))
85
 
86
+ # Draw bounding box
87
  cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
88
  cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
89
 
 
92
  if missing_classes:
93
  labels.append(f"Missing: {', '.join(missing_classes)}")
94
 
95
+ return Image.fromarray(image.astype(np.uint8)), labels, bounding_boxes
96
 
97
 
98
  def crop_image(image, bounding_boxes):
99
+ """Crops detected bounding boxes from the image safely."""
100
+ image = ensure_numpy(image) # Ensure image is NumPy format
101
  cropped_images = {}
 
102
 
103
  for (x1, y1, x2, y2, class_name, conf) in bounding_boxes:
104
+ # Ensure the bounding box is within image bounds
105
+ x1, y1, x2, y2 = max(0, x1), max(0, y1), min(image.shape[1], x2), min(image.shape[0], y2)
106
  cropped = image[y1:y2, x1:x2]
107
+
108
+ if cropped.size > 0: # Check if valid
109
+ cropped_images[class_name] = Image.fromarray(cropped)
110
 
111
  return cropped_images
112
 
 
133
  # Convert grayscale to 3-channel image
134
  image = np.stack([image] * 3, axis=-1)
135
 
136
+ # return image
137
+ return image.astype(np.uint8)
138
 
139
  def predict(image):
140
  """Pipeline: Preprocess -> Detect -> Crop -> Vision AI API."""
141
+ processed_image = preprocessing(image) # Enhanced PIL image
142
+ rotated_image = ensure_numpy(processed_image) # Convert to NumPy
143
  detected_image, labels, bounding_boxes = detect_document(rotated_image)
144
 
145
+ if not bounding_boxes:
146
+ return detected_image, labels, {"error": "No document detected!"}
147
+
148
  cropped_images = crop_image(rotated_image, bounding_boxes)
149
 
150
  # Call Vision AI separately for front and back if detected
151
+ front_result = back_result = None
152
  if "front" in cropped_images:
153
  front_result = vision_ai_api(cropped_images["front"], "front")
154
  if "back" in cropped_images:
155
  back_result = vision_ai_api(cropped_images["back"], "back")
156
 
 
157
  api_results = {
158
  "front": front_result,
159
  "back": back_result
160
  }
161
+
162
+ return detected_image, labels, api_results
163
+
164
 
165
 
166
  iface = gr.Interface(