File size: 14,812 Bytes
8d4bf19 97a2cd7 8d4bf19 97a2cd7 525e1e1 2fa5504 525e1e1 bf1fde7 525e1e1 2fa5504 525e1e1 bf1fde7 525e1e1 2fa5504 525e1e1 bf1fde7 525e1e1 eb1abd7 525e1e1 eb1abd7 525e1e1 eb1abd7 525e1e1 eb1abd7 8d4bf19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 |
import streamlit as st
# import config
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import numpy as np
import math
from PIL import Image
# import wandb
from model import YOLOv3
import cv2
# Side length (pixels) of the square network input; images are letterboxed to it.
IMAGE_SIZE = 416

# Anchor (width, height) pairs as fractions of the image size: one row of three
# anchors per prediction scale, in the same order as the grid sizes in S below.
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],  # grid IMAGE_SIZE // 32
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],  # grid IMAGE_SIZE // 16
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],  # grid IMAGE_SIZE // 8
]

# Number of grid cells per side for each prediction scale (coarse to fine).
S = [IMAGE_SIZE // stride for stride in (32, 16, 8)]
# Preprocessing applied to every image before it is fed to the model.
infer_transforms = A.Compose(
    transforms=[
        # Rescale so the longer side equals IMAGE_SIZE, keeping the aspect ratio ...
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        # ... then pad with a constant border up to IMAGE_SIZE x IMAGE_SIZE.
        A.PadIfNeeded(
            min_height=IMAGE_SIZE,
            min_width=IMAGE_SIZE,
            border_mode=cv2.BORDER_CONSTANT,
        ),
        # Zero mean / unit std with max_pixel_value=255 amounts to dividing by 255.
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
        # HWC numpy image -> CHW torch tensor.
        ToTensorV2(),
    ],
)
def cells_to_bboxes(predictions, anchors, S, is_preds=True):
    """
    Convert per-cell YOLO outputs into boxes relative to the entire image.

    Each box centre is offset by its cell index and divided by ``S``, and the
    width/height are divided by ``S``, so x, y, w and h all become fractions
    of the image size (e.g. for plotting or non-max suppression).

    INPUT:
    predictions: tensor of size (N, num_anchors, S, S, num_classes + 5), laid out
        as [objectness, x, y, w, h, class scores...] when ``is_preds`` is True,
        or [objectness, x, y, w, h, class_index] for ground-truth targets.
        It is not modified.
    anchors: anchor (w, h) pairs reshapeable to (num_anchors, 2), in grid-cell
        units; only used when ``is_preds`` is True
    S: the number of cells the image is divided into along the width (and height)
    is_preds: whether the input is raw model predictions or the true bounding boxes
    OUTPUT:
    converted_bboxes: nested Python list of shape (N, num_anchors * S * S, 6);
        each box is [class_index, object_score, x, y, w, h]
    """
    batch_size = predictions.shape[0]
    num_anchors = predictions.shape[1]
    # Clone: slicing returns a view, and the in-place sigmoid/exp below would
    # otherwise overwrite the caller's prediction tensor.
    box_predictions = predictions[..., 1:5].clone()
    if is_preds:
        anchors = anchors.reshape(1, num_anchors, 1, 1, 2)
        box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
        box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]
    # cell_indices[..., i, j, 0] == j (the column index), used as the x offset;
    # swapping the two grid axes yields the row index for y.
    cell_indices = (
        torch.arange(S, device=predictions.device)
        .repeat(batch_size, num_anchors, S, 1)
        .unsqueeze(-1)
    )
    x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
    y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
    w_h = 1 / S * box_predictions[..., 2:4]
    # Cast the (integer) argmax class explicitly so the concatenation is float.
    converted_bboxes = torch.cat(
        (best_class.to(x.dtype), scores, x, y, w_h), dim=-1
    ).reshape(batch_size, num_anchors * S * S, 6)
    return converted_bboxes.tolist()
def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """
    Greedy, per-class Non Max Suppression.

    Video explanation of this function:
    https://youtu.be/YDkjWEN8jNA

    Parameters:
        bboxes (list): list of lists containing all bboxes, each specified as
            [class_pred, prob_score, *coords], where coords are (x1, y1, x2, y2)
            for "corners" or (x, y, w, h) for "midpoint"
        iou_threshold (float): a lower-scored box of the same class whose IoU
            with an already kept box is >= this value is discarded
        threshold (float): boxes with prob_score <= this value are dropped
            before NMS (independent of IoU)
        box_format (str): "midpoint" or "corners" used to specify bboxes

    Returns:
        list: the kept bboxes, ordered by descending prob_score

    Raises:
        TypeError: if ``bboxes`` is not a list
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(bboxes, list):
        raise TypeError(f"bboxes must be a list, got {type(bboxes).__name__}")
    # sorted() is stable, so equal scores keep their input order.
    candidates = sorted(
        (box for box in bboxes if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )
    kept = []
    while candidates:
        chosen_box = candidates.pop(0)
        # Built once per chosen box rather than once per comparison.
        chosen_coords = torch.tensor(chosen_box[2:])
        candidates = [
            box
            for box in candidates
            if box[0] != chosen_box[0]
            or intersection_over_union(
                chosen_coords,
                torch.tensor(box[2:]),
                box_format=box_format,
            )
            < iou_threshold
        ]
        kept.append(chosen_box)
    return kept
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint", GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """
    Video explanation of this function:
    https://youtu.be/XXYG5ZWtjj0

    Calculate intersection over union (IoU), or one of its GIoU / DIoU / CIoU
    variants, between prediction boxes and target boxes.

    Parameters:
        boxes_preds (tensor): Predictions of Bounding Boxes (..., 4)
        boxes_labels (tensor): Correct labels of Bounding Boxes (..., 4)
        box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)
        GIoU, DIoU, CIoU (bool): return the generalized / distance / complete IoU
            instead of plain IoU; if several are set, CIoU wins over DIoU over GIoU
        eps (float): small constant guarding the divisions in the variants

    Returns:
        tensor: IoU (or the selected variant) with shape (..., 1)

    Raises:
        ValueError: if box_format is neither "midpoint" nor "corners"
    """
    if box_format == "midpoint":
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
        w1 = boxes_preds[..., 2:3]
        h1 = boxes_preds[..., 3:4]
        w2 = boxes_labels[..., 2:3]
        h2 = boxes_labels[..., 3:4]
    elif box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
        # CIoU needs widths/heights; derive them so CIoU also works for corners.
        w1 = box1_x2 - box1_x1
        h1 = box1_y2 - box1_y1
        w2 = box2_x2 - box2_x1
        h2 = box2_y2 - box2_y1
    else:
        raise ValueError(f'box_format must be "midpoint" or "corners", got {box_format!r}')

    # Overlap rectangle; clamp(0) gives disjoint boxes zero intersection.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
    union = box1_area + box2_area - intersection

    if CIoU or DIoU or GIoU:
        iou = intersection / union
        cw = box1_x2.maximum(box2_x2) - box1_x1.minimum(box2_x1)  # convex (smallest enclosing box) width
        ch = box1_y2.maximum(box2_y2) - box1_y1.minimum(box2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((box2_x1 + box2_x2 - box1_x1 - box1_x2) ** 2 + (box2_y1 + box2_y2 - box1_y1 - box1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - intersection) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    # Plain IoU; the small constant avoids 0/0 for two degenerate boxes.
    return intersection / (union + 1e-6)
def resize_box(box, origin_dims, in_dims):
    """
    Map an (x1, y1, x2, y2) box from the square, padded network input back to
    the original image.

    Args:
        box: mutable sequence ``[x1, y1, x2, y2]`` in input-image pixels; it is
            modified in place.
        origin_dims: ``(height, width, ...)`` of the original image.
        in_dims: side length of the square network input.

    Returns:
        The same ``box`` object, rescaled to original-image pixels.
    """
    h_ori, w_ori = origin_dims[0], origin_dims[1]
    # Only the shorter side is padded: wide images get vertical padding,
    # tall images horizontal padding (amounts in input-image pixels).
    padding_height = max(w_ori - h_ori, 0) * in_dims / w_ori
    padding_width = max(h_ori - w_ori, 0) * in_dims / h_ori
    # Size of the picture inside the input once the padding is removed.
    h_new = in_dims - padding_height
    w_new = in_dims - padding_width
    # Subtract half the padding (the offset on the top/left side), then undo the resize.
    box[0] = (box[0] - padding_width // 2) * w_ori / w_new
    box[1] = (box[1] - padding_height // 2) * h_ori / h_new
    box[2] = (box[2] - padding_width // 2) * w_ori / w_new
    box[3] = (box[3] - padding_height // 2) * h_ori / h_new
    return box
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    """
    Rescale xyxy boxes from the letterboxed ``img1_shape`` back to ``img0_shape``.

    Args:
        img1_shape: (height, width) of the image the boxes currently refer to.
        boxes: tensor or array whose last axis starts with x1, y1, x2, y2;
            modified in place.
        img0_shape: (height, width, ...) of the target image.
        ratio_pad: optional precomputed ``((gain, ...), (pad_x, pad_y))``; when
            omitted, gain and centred padding are derived from the two shapes.

    Returns:
        ``boxes``, shifted, rescaled and clipped to ``img0_shape``.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad_x, pad_y = ratio_pad[1][0], ratio_pad[1][1]
    else:
        # gain = old / new; padding is split evenly between both sides.
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
        pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2
    boxes[..., [0, 2]] -= pad_x
    boxes[..., [1, 3]] -= pad_y
    boxes[..., :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes
def clip_boxes(boxes, shape):
    """
    Clamp xyxy boxes in place so they lie inside an image of size ``shape``.

    ``shape`` is indexed as (height, width, ...). Torch tensors are clamped
    column by column with ``clamp_``; anything else (numpy arrays) is clipped
    via fancy-indexed assignment. Returns None.
    """
    img_h, img_w = shape[0], shape[1]
    if not isinstance(boxes, torch.Tensor):
        # numpy: x columns against the width, y columns against the height
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, img_w)
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, img_h)
        return
    for column, upper in ((0, img_w), (1, img_h), (2, img_w), (3, img_h)):
        boxes[..., column].clamp_(0, upper)
def plot_image(image, boxes, image_ori=None):
    """
    Draw predicted bounding boxes, with class label and confidence, on ``image_ori``.

    Args:
        image: the (letterboxed) network input as an H x W (x C) array-like;
            only its height and width are used, to turn the fractional box
            coordinates into input-image pixels.
        boxes: iterable of ``[class_pred, confidence, x, y, width, height]``
            where the centre (x, y) and the size are fractions of ``image``.
        image_ori: original image array; boxes are rescaled to it and drawn on
            it in place. Despite the ``None`` default it is required.

    Returns:
        ``image_ori`` with the boxes drawn on it.

    Raises:
        ValueError: if ``image_ori`` is None or a box does not have 6 entries.
    """
    import pickle as pkl

    if image_ori is None:
        raise ValueError("plot_image requires image_ori to draw the boxes on")

    class_labels = [
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor"
    ]
    # Per-class drawing colours indexed by class id. pickle can execute code
    # while loading, so "pallete" must be a trusted file shipped with the app.
    with open("pallete", "rb") as palette_file:  # closed promptly, unlike a bare open()
        colors = pkl.load(palette_file)

    height, width = np.array(image).shape[:2]
    target_shape = image_ori.shape[:2]
    for box in boxes:
        if len(box) != 6:
            raise ValueError("box should contain class pred, confidence, x, y, width, height")
        class_pred, conf, x, y, w, h = box
        # Midpoint format (fractions) -> corners in input-image pixels, clamped to the image ...
        corners = torch.tensor([
            max(x - w / 2, 0.) * width,
            max(y - h / 2, 0.) * height,
            min(x + w / 2, 1.) * width,
            min(y + h / 2, 1.) * height,
        ])
        # ... then undo the letterbox so they land on the original image.
        x1, y1, x2, y2 = (int(v) for v in scale_boxes((height, width), corners, target_shape))
        class_id = int(class_pred)
        color = colors[class_id]
        cv2.rectangle(image_ori, (x1, y1), (x2, y2), color, 2)
        text = f"{class_labels[class_id]}: {conf:.2f}"
        # The label is anchored 10 px above the box's top-left corner.
        cv2.putText(image_ori, text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return image_ori
def infer(model, img, thresh, iou_thresh, anchors):
    """
    Detect objects in one image and draw the detections on a copy of it.

    Args:
        model: detection network; called on a (1, 3, IMAGE_SIZE, IMAGE_SIZE)
            batch and expected to return one 5-D prediction tensor per scale
            (presumably (N, num_anchors, S, S, 5 + num_classes), as consumed by
            cells_to_bboxes).
        img: PIL image or H x W (x C) array-like; it is not modified.
        thresh: objectness-score threshold forwarded to non_max_suppression.
        iou_thresh: IoU threshold forwarded to non_max_suppression.
        anchors: per-scale anchors in grid-cell units (e.g. ``scaled_anchors``),
            in the same order as the model outputs.

    Returns:
        numpy array: a copy of the input image with boxes and labels drawn on it.
    """
    model.eval()
    # Normalise PIL inputs of any mode (L, LA, P, RGBA, CMYK, ...) to 3-channel
    # RGB; slicing alone crashes on greyscale or palette images.
    rgb = img.convert("RGB") if hasattr(img, "convert") else img
    image = np.array(rgb)
    if image.ndim == 2:
        # Greyscale array input: replicate it into three channels.
        image = np.stack([image] * 3, axis=-1)
    image = image[:, :, :3]  # drop an alpha channel, if any
    image_copy = image.copy()  # boxes are drawn on this copy
    x = infer_transforms(image=image)["image"].unsqueeze(0)  # add the batch dimension
    # Feed the input on the device holding the model's weights; a CPU input
    # with a CUDA model would fail.
    model_device = next((p.device for p in model.parameters()), torch.device("cpu"))
    x = x.to(model_device)
    with torch.no_grad():
        out = model(x)

    # Gather the boxes of every scale, per image in the batch.
    bboxes = [[] for _ in range(x.shape[0])]
    for scale_out, scale_anchors in zip(out, anchors):
        grid_size = scale_out.shape[2]
        scale_boxes_per_image = cells_to_bboxes(
            scale_out, scale_anchors, S=grid_size, is_preds=True
        )
        for idx, boxes in enumerate(scale_boxes_per_image):
            bboxes[idx] += boxes

    annotated = img
    for i, image_boxes in enumerate(bboxes):
        nms_boxes = non_max_suppression(
            image_boxes, iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
        )
        annotated = plot_image(x[i].permute(1, 2, 0).detach().cpu(), nms_boxes, image_copy)
    return annotated
# ---------------------------------------------------------------------------
# Streamlit app
# ---------------------------------------------------------------------------
# Incremental-learning scenario "<base> -> <total>": the task-1 checkpoint
# detects the first `base` classes; both task-2 checkpoints (before and after
# fine-tuning) detect all NUM_CLASSES classes.
NUM_CLASSES = 20  # total classes; replaces the old `all`, which shadowed the builtin
BASE_CLASSES_BY_SCENE = {"19->20": 19, "15->20": 15, "10->20": 10}
# Objectness and NMS IoU thresholds passed to infer() for every checkpoint.
CONF_THRESHOLD = 0.7
NMS_IOU_THRESHOLD = 0.8

scene = st.radio(
    "Chọn bối cảnh",
    ('19->20', '15->20', '10->20'))
# Unknown values fall back to 10 base classes, matching the former else-branch.
base = BASE_CLASSES_BY_SCENE.get(scene, 10)
new = NUM_CLASSES - base

device = "cuda" if torch.cuda.is_available() else "cpu"
# ANCHORS are fractions of the image; multiplying by each scale's grid size
# expresses them in grid-cell units, as cells_to_bboxes expects.
scaled_anchors = (
    torch.tensor(ANCHORS)
    * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(device)


def _detect_with_checkpoint(checkpoint_path, num_classes, image):
    """Load a YOLOv3 checkpoint on ``device`` and return ``image`` annotated by infer()."""
    model = YOLOv3(num_classes=num_classes).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return infer(model, image, CONF_THRESHOLD, NMS_IOU_THRESHOLD, scaled_anchors)


uploaded_file = st.file_uploader("Chọn hình ảnh...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # infer() only reads the image (it works on NumPy copies), so one decoded
    # image is shared by all three checkpoints instead of re-opening the upload.
    image = Image.open(uploaded_file)
    # Task 1: model trained on the base classes only.
    image_1 = _detect_with_checkpoint(
        f"2007_base_{base}_{new}_mAP_{base}_{new}.pth.tar", base, image
    )
    # Task 2 without fine-tuning.
    image_2 = _detect_with_checkpoint(
        f"2007_task2_{base}_{new}_mAP_{base}_{new}.pth.tar", NUM_CLASSES, image
    )
    # Task 2 after fine-tuning.
    image_3 = _detect_with_checkpoint(
        f"2007_finetune_{base}_{new}_mAP_{base}_{new}.pth.tar", NUM_CLASSES, image
    )

    # cvtColor swaps the channel order and st.image(channels="BGR") swaps it
    # back, so the arrays are displayed in the order infer() drew them.
    image_1 = cv2.cvtColor(image_1, cv2.COLOR_BGR2RGB)
    image_2 = cv2.cvtColor(image_2, cv2.COLOR_BGR2RGB)
    image_3 = cv2.cvtColor(image_3, cv2.COLOR_BGR2RGB)

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.image(image, caption="Ảnh đầu vào")
    with col2:
        st.image(image_1, caption="Kết quả task 1", channels="BGR")
    with col3:
        st.image(image_2, caption="Kết quả task 2 (no finetune)", channels="BGR")
    with col4:
        st.image(image_3, caption="Kết quả task 2 (finetune)", channels="BGR")
|