Spaces:
Runtime error
Runtime error
File size: 2,907 Bytes
5eff22e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import torch
from ..utils import box_utils
from .data_preprocessing import PredictionTransform
from ..utils.misc import Timer
class Predictor:
    """Run a detection network on one image and post-process its output.

    The predictor owns the network, the input transform and the
    non-maximum-suppression (NMS) settings. ``predict`` feeds a single image
    through the network, filters the per-class scores by a probability
    threshold, applies ``box_utils.nms`` per class and rescales the surviving
    boxes by the image width and height.
    """

    def __init__(self, net, size, mean=0.0, std=1.0, nms_method=None,
                 iou_threshold=0.45, filter_threshold=0.01, candidate_size=200,
                 sigma=0.5, device=None):
        """Store the configuration and put the network into eval mode.

        Args:
            net: Network exposing ``to``, ``eval`` and ``forward``; ``forward``
                must return a ``(scores, boxes)`` pair.
            size: Forwarded to ``PredictionTransform``.
            mean: Forwarded to ``PredictionTransform``.
            std: Forwarded to ``PredictionTransform``.
            nms_method: Forwarded to ``box_utils.nms`` as its method argument.
            iou_threshold: Forwarded to ``box_utils.nms``.
            filter_threshold: Default score threshold used by ``predict`` when
                the caller does not pass ``prob_threshold``.
            candidate_size: Forwarded to ``box_utils.nms``.
            sigma: Forwarded to ``box_utils.nms``.
            device: Device to run the network on. When falsy, ``cuda:0`` is
                used if CUDA is available, otherwise ``cpu``.
        """
        self.net = net
        self.transform = PredictionTransform(size, mean, std)
        self.iou_threshold = iou_threshold
        self.filter_threshold = filter_threshold
        self.candidate_size = candidate_size
        self.nms_method = nms_method
        self.sigma = sigma
        self.device = device or torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )

        self.net.to(self.device)
        self.net.eval()

        self.timer = Timer()

    def predict(self, image, top_k=-1, prob_threshold=None):
        """Detect objects in a single image.

        Args:
            image: Array-like whose ``shape`` unpacks as
                ``(height, width, channels)``; it is passed to the transform
                built in ``__init__``.
            top_k: Forwarded to ``box_utils.nms``.
            prob_threshold: Minimum score a candidate must exceed. ``None``
                selects ``self.filter_threshold``; an explicit ``0.0`` is
                honoured as given.

        Returns:
            A ``(boxes, labels, probs)`` tuple of CPU tensors: ``boxes`` has
            shape ``(N, 4)`` with columns 0 and 2 scaled by the image width and
            columns 1 and 3 by the image height, ``labels`` holds the integer
            class index of each box and ``probs`` its score. When nothing
            passes the threshold, ``N`` is 0 and the tensors are empty but
            keep the same ranks and dtypes as a non-empty result.
        """
        cpu_device = torch.device("cpu")
        height, width, _ = image.shape

        # The network expects a batch, so add a leading batch dimension.
        images = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            self.timer.start()
            scores, boxes = self.net.forward(images)
            print("Inference time: ", self.timer.end())

        # Only test for None: a caller-supplied threshold of 0.0 is valid and
        # must not be replaced by the default.
        if prob_threshold is None:
            prob_threshold = self.filter_threshold

        # This version of nms is slower on GPU, so we move data to CPU.
        # Index 0 selects the single image of the batch.
        boxes = boxes[0].to(cpu_device)
        scores = scores[0].to(cpu_device)

        picked_box_probs = []
        picked_labels = []
        # Class index 0 is skipped (presumably the background class -- confirm
        # against the network definition).
        for class_index in range(1, scores.size(1)):
            class_probs = scores[:, class_index]
            keep = class_probs > prob_threshold
            class_probs = class_probs[keep]
            if class_probs.size(0) == 0:
                continue

            # Each candidate row is the box columns followed by its score.
            candidates = torch.cat(
                [boxes[keep, :], class_probs.reshape(-1, 1)], dim=1
            )
            kept = box_utils.nms(candidates, self.nms_method,
                                 score_threshold=prob_threshold,
                                 iou_threshold=self.iou_threshold,
                                 sigma=self.sigma,
                                 top_k=top_k,
                                 candidate_size=self.candidate_size)
            picked_box_probs.append(kept)
            picked_labels.extend([class_index] * kept.size(0))

        if not picked_box_probs:
            # Keep the shapes and dtypes of a regular result so callers can
            # index the empty tensors the same way.
            return (boxes.new_empty((0, 4)),
                    torch.empty(0, dtype=torch.long),
                    scores.new_empty(0))

        picked_box_probs = torch.cat(picked_box_probs)
        # Scale the four box columns by (width, height, width, height).
        scale = picked_box_probs.new_tensor([width, height, width, height])
        picked_box_probs[:, :4] *= scale

        return (picked_box_probs[:, :4],
                torch.tensor(picked_labels),
                picked_box_probs[:, 4])
|