SSD-Detection / vision /ssd /predictor.py
Kaori1707's picture
update application
5eff22e
import torch
from ..utils import box_utils
from .data_preprocessing import PredictionTransform
from ..utils.misc import Timer
class Predictor:
def __init__(self, net, size, mean=0.0, std=1.0, nms_method=None,
iou_threshold=0.45, filter_threshold=0.01, candidate_size=200, sigma=0.5, device=None):
self.net = net
self.transform = PredictionTransform(size, mean, std)
self.iou_threshold = iou_threshold
self.filter_threshold = filter_threshold
self.candidate_size = candidate_size
self.nms_method = nms_method
self.sigma = sigma
if device:
self.device = device
else:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.net.to(self.device)
self.net.eval()
self.timer = Timer()
def predict(self, image, top_k=-1, prob_threshold=None):
cpu_device = torch.device("cpu")
height, width, _ = image.shape
image = self.transform(image)
#print(image)
images = image.unsqueeze(0)
images = images.to(self.device)
with torch.no_grad():
self.timer.start()
scores, boxes = self.net.forward(images)
print("Inference time: ", self.timer.end())
boxes = boxes[0]
scores = scores[0]
if not prob_threshold:
prob_threshold = self.filter_threshold
# this version of nms is slower on GPU, so we move data to CPU.
boxes = boxes.to(cpu_device)
scores = scores.to(cpu_device)
picked_box_probs = []
picked_labels = []
for class_index in range(1, scores.size(1)):
probs = scores[:, class_index]
mask = probs > prob_threshold
probs = probs[mask]
if probs.size(0) == 0:
continue
subset_boxes = boxes[mask, :]
box_probs = torch.cat([subset_boxes, probs.reshape(-1, 1)], dim=1)
box_probs = box_utils.nms(box_probs, self.nms_method,
score_threshold=prob_threshold,
iou_threshold=self.iou_threshold,
sigma=self.sigma,
top_k=top_k,
candidate_size=self.candidate_size)
picked_box_probs.append(box_probs)
picked_labels.extend([class_index] * box_probs.size(0))
if not picked_box_probs:
return torch.tensor([]), torch.tensor([]), torch.tensor([])
picked_box_probs = torch.cat(picked_box_probs)
picked_box_probs[:, 0] *= width
picked_box_probs[:, 1] *= height
picked_box_probs[:, 2] *= width
picked_box_probs[:, 3] *= height
return picked_box_probs[:, :4], torch.tensor(picked_labels), picked_box_probs[:, 4]