satvs commited on
Commit
1288dc4
·
1 Parent(s): 8c7ef1f

Optimize submission

Browse files
Files changed (1) hide show
  1. tasks/image.py +42 -24
tasks/image.py CHANGED
@@ -18,6 +18,7 @@ from pathlib import Path
18
  from ultralytics import YOLO
19
  from torch import device
20
  from torch.cuda import is_available
 
21
 
22
  router = APIRouter()
23
 
@@ -115,35 +116,52 @@ async def evaluate_image(request: ImageEvaluationRequest):
115
  model = YOLO(Path(model_path, model_name), task="detect")
116
  device_name = device("cuda" if is_available() else "cpu")
117
 
 
 
 
 
 
 
 
118
  predictions = []
119
  true_labels = []
120
  pred_boxes = []
121
  true_boxes_list = [] # List of lists, each inner list contains boxes for one image
122
-
123
  logging.info(f"Inference start on device: {device_name}")
124
- for example in test_dataset:
125
- # Parse true annotation (YOLO format: class_id x_center y_center width height)
126
- annotation = example.get("annotations", "").strip()
127
- has_smoke = len(annotation) > 0
128
- true_labels.append(int(has_smoke))
129
-
130
- # Make prediction
131
- results = model.predict(example["image"], device=device_name, conf=THRESHOLD, verbose=False, half=True, imgsz=IMGSIZE)[0]
132
- pred_has_smoke = len(results) > 0
133
- predictions.append(int(pred_has_smoke))
134
-
135
- # If there's a true box, parse it and add box prediction
136
- if has_smoke:
137
- # Parse all true boxes from the annotation
138
- image_true_boxes = parse_boxes(annotation)
139
- true_boxes_list.append(image_true_boxes)
140
-
141
- # Append only one bounding box if at least one fire is detected
142
- # Note that multiple boxes could be appended
143
- if results.boxes.cls.numel()!=0:
144
- pred_boxes.append(results.boxes[0].xywhn.tolist()[0])
145
- else:
146
- pred_boxes.append([0,0,0,0])
 
 
 
 
 
 
 
 
 
 
147
 
148
  #--------------------------------------------------------------------------------------------
149
  # YOUR MODEL INFERENCE STOPS HERE
 
18
  from ultralytics import YOLO
19
  from torch import device
20
  from torch.cuda import is_available
21
+ from torch import no_grad
22
 
23
  router = APIRouter()
24
 
 
116
  model = YOLO(Path(model_path, model_name), task="detect")
117
  device_name = device("cuda" if is_available() else "cpu")
118
 
119
+ # Preprocess annotations before the loop
120
+ preprocessed_annotations = [parse_boxes(example.get("annotations", "").strip()) for example in test_dataset]
121
+
122
+ batch_size = 16 # Define a batch size
123
+ batch_images = []
124
+ batch_annotations = []
125
+
126
  predictions = []
127
  true_labels = []
128
  pred_boxes = []
129
  true_boxes_list = [] # List of lists, each inner list contains boxes for one image
130
+
131
  logging.info(f"Inference start on device: {device_name}")
132
+
133
+ # Use torch.no_grad() to disable gradient tracking during inference
134
+ with no_grad():
135
+ for idx, example in enumerate(test_dataset):
136
+ batch_images.append(example["image"])
137
+ batch_annotations.append(preprocessed_annotations[idx])
138
+
139
+ # When the batch size is met, or it's the last image, perform inference
140
+ if (len(batch_images) == batch_size or idx == len(test_dataset) - 1):
141
+ # Make a prediction for the current batch
142
+ results = model.predict(batch_images, device=device_name, conf=THRESHOLD, verbose=False, half=True, imgsz=IMGSIZE)[0]
143
+
144
+ for batch_idx, result in enumerate(results):
145
+ annotation = batch_annotations[batch_idx]
146
+ has_smoke = len(annotation) > 0
147
+ true_labels.append(int(has_smoke))
148
+
149
+ pred_has_smoke = len(result) > 0
150
+ predictions.append(int(pred_has_smoke))
151
+
152
+ if has_smoke:
153
+ true_boxes_list.append(annotation)
154
+
155
+ # Handle prediction boxes for each image in the batch
156
+ if result.boxes.cls.numel() != 0:
157
+ pred_boxes.append(result.boxes[0].xywhn.tolist()[0])
158
+ else:
159
+ pred_boxes.append([0, 0, 0, 0])
160
+
161
+ # Clear the batch after processing
162
+ batch_images.clear()
163
+ batch_annotations.clear()
164
+
165
 
166
  #--------------------------------------------------------------------------------------------
167
  # YOUR MODEL INFERENCE STOPS HERE