import torch.nn as nn
import torch
import numpy as np
from typing import List, Tuple, Union
import torch.nn.functional as F

from ..utils import box_utils
from collections import namedtuple

# Describes a feature tap inside a sub-module of the base net, as consumed by
# SSD.forward: `s0` is the index of the layer in `base_net`, `name` is the
# attribute on that layer holding a sliceable sequence of sub-layers, and `s1`
# is the position in that sequence after which the feature map is taken.
GraphPath = namedtuple("GraphPath", ['s0', 'name', 's1'])


class SSD(nn.Module):
    """Single-shot detector assembled from a base net, extras and headers."""

    def __init__(self, num_classes: int, base_net: nn.ModuleList,
                 source_layer_indexes: List[Union[int, Tuple[int, nn.Module], GraphPath]],
                 extras: nn.ModuleList, classification_headers: nn.ModuleList,
                 regression_headers: nn.ModuleList, is_test=False, config=None, device=None):
        """Compose a SSD model using the given components.

        Args:
            num_classes: size of the last dimension of the confidence output.
            base_net: backbone layers; `forward` runs them in order and slices
                them at the positions given by `source_layer_indexes`.
            source_layer_indexes: where to tap the base net for header inputs.
                Each entry is either a plain index, an `(index, module)` tuple
                whose module is applied to the tapped feature map before the
                header, or a `GraphPath`.
            extras: layers applied after the base net; the output of each one
                is fed to its own pair of headers.
            classification_headers: one module per tapped feature map,
                indexed in tap order.
            regression_headers: one module per tapped feature map, indexed in
                tap order.
            is_test: when true, `forward` returns softmax confidences and
                decoded corner-form boxes instead of raw outputs. `config`
                must then provide `priors`, `center_variance` and
                `size_variance`.
            config: configuration object; only read in test mode.
            device: stored as `self.device`. Within this class it is only used
                to move `config.priors` in test mode. Defaults to `cuda:0`
                when CUDA is available, otherwise `cpu`.
        """
        super(SSD, self).__init__()

        self.num_classes = num_classes
        self.base_net = base_net
        self.source_layer_indexes = source_layer_indexes
        self.extras = extras
        self.classification_headers = classification_headers
        self.regression_headers = regression_headers
        self.is_test = is_test
        self.config = config

        # Register the add-on modules carried inside source_layer_indexes so
        # that they show up in parameters()/state_dict(). GraphPath is itself
        # a tuple subclass and carries no module, so it is excluded.
        self.source_layer_add_ons = nn.ModuleList(
            [t[1] for t in source_layer_indexes
             if isinstance(t, tuple) and not isinstance(t, GraphPath)]
        )
        if device:
            self.device = device
        else:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if is_test:
            self.priors = config.priors.to(self.device)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the detector on `x`.

        Returns:
            `(confidences, locations)` concatenated along dim 1 over all
            tapped feature maps. In test mode the confidences are passed
            through a softmax over dim 2 and the locations are replaced by
            boxes decoded against `self.priors` and converted to corner form.
        """
        confidences = []
        locations = []
        start_layer_index = 0
        header_index = 0
        for source in self.source_layer_indexes:
            # GraphPath must be tested first: it is a tuple subclass.
            if isinstance(source, GraphPath):
                path = source
                end_layer_index = source.s0
                added_layer = None
            elif isinstance(source, tuple):
                path = None
                end_layer_index = source[0]
                added_layer = source[1]
            else:
                path = None
                end_layer_index = source
                added_layer = None

            for layer in self.base_net[start_layer_index:end_layer_index]:
                x = layer(x)

            # The add-on only shapes the header input; `x` itself continues
            # through the base net unchanged.
            y = added_layer(x) if added_layer is not None else x

            if path is not None:
                # Tap the feature map part-way through the named sub-sequence
                # of layer `s0`, then finish that layer and step past it.
                sub = getattr(self.base_net[end_layer_index], path.name)
                for layer in sub[:path.s1]:
                    x = layer(x)
                y = x
                for layer in sub[path.s1:]:
                    x = layer(x)
                end_layer_index += 1

            start_layer_index = end_layer_index
            confidence, location = self.compute_header(header_index, y)
            header_index += 1
            confidences.append(confidence)
            locations.append(location)

        # start_layer_index equals the last end_layer_index after the loop and
        # stays 0 when there are no source layers, so this is always defined.
        for layer in self.base_net[start_layer_index:]:
            x = layer(x)

        for layer in self.extras:
            x = layer(x)
            confidence, location = self.compute_header(header_index, x)
            header_index += 1
            confidences.append(confidence)
            locations.append(location)

        confidences = torch.cat(confidences, 1)
        locations = torch.cat(locations, 1)

        if self.is_test:
            confidences = F.softmax(confidences, dim=2)
            boxes = box_utils.convert_locations_to_boxes(
                locations, self.priors, self.config.center_variance, self.config.size_variance
            )
            boxes = box_utils.center_form_to_corner_form(boxes)
            return confidences, boxes
        else:
            return confidences, locations

    def compute_header(self, i, x):
        """Apply the i-th classification and regression headers to `x`.

        Each header output is treated as 4-D (presumably N, C, H, W - confirm
        against the header modules): dim 1 is moved last and the result is
        flattened to `(batch, -1, num_classes)` for confidences and
        `(batch, -1, 4)` for locations.
        """
        confidence = self.classification_headers[i](x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)

        location = self.regression_headers[i](x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)

        return confidence, location

    def init_from_base_net(self, model):
        """Load base-net weights from `model` and Xavier-init everything else.

        The state dict is loaded strictly into `self.base_net`; add-ons,
        extras and both header lists are re-initialised.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # The map_location lambda keeps storages where deserialization put
        # them rather than moving them to the device they were saved from.
        self.base_net.load_state_dict(
            torch.load(model, map_location=lambda storage, loc: storage), strict=True
        )
        self.source_layer_add_ons.apply(_xavier_init_)
        self.extras.apply(_xavier_init_)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def init_from_pretrained_ssd(self, model):
        """Load a full SSD checkpoint but keep freshly initialised headers.

        Entries whose key starts with `classification_headers` or
        `regression_headers` are dropped from the checkpoint before it is
        merged into the current state dict; both header lists are then
        Xavier-initialised.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model, map_location=lambda storage, loc: storage)
        header_prefixes = ("classification_headers", "regression_headers")
        state_dict = {k: v for k, v in state_dict.items()
                      if not k.startswith(header_prefixes)}
        model_dict = self.state_dict()
        model_dict.update(state_dict)
        self.load_state_dict(model_dict)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def init(self):
        """Xavier-initialise the base net, add-ons, extras and headers."""
        self.base_net.apply(_xavier_init_)
        self.source_layer_add_ons.apply(_xavier_init_)
        self.extras.apply(_xavier_init_)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def load(self, model):
        """Load the whole model's state dict from the file `model`."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))

    def save(self, model_path):
        """Write the model's state dict to `model_path`."""
        torch.save(self.state_dict(), model_path)


class MatchPrior(object):
    """Callable that turns ground-truth boxes and labels into training targets.

    The work is delegated to `box_utils`: priors are assigned to ground-truth
    boxes with `assign_priors`, and the assigned boxes are then encoded
    against the center-form priors with `convert_boxes_to_locations`.
    """

    def __init__(self, center_form_priors, center_variance, size_variance, iou_threshold):
        self.center_form_priors = center_form_priors
        # Corner-form copy is computed once; it is what assign_priors receives.
        self.corner_form_priors = box_utils.center_form_to_corner_form(center_form_priors)
        self.center_variance = center_variance
        self.size_variance = size_variance
        self.iou_threshold = iou_threshold

    def __call__(self, gt_boxes, gt_labels):
        """Return `(locations, labels)` for the given ground truth.

        NumPy inputs are converted to tensors first; other inputs are passed
        through unchanged.
        """
        if isinstance(gt_boxes, np.ndarray):
            gt_boxes = torch.from_numpy(gt_boxes)
        if isinstance(gt_labels, np.ndarray):
            gt_labels = torch.from_numpy(gt_labels)
        boxes, labels = box_utils.assign_priors(gt_boxes, gt_labels,
                                                self.corner_form_priors, self.iou_threshold)
        boxes = box_utils.corner_form_to_center_form(boxes)
        locations = box_utils.convert_boxes_to_locations(
            boxes, self.center_form_priors, self.center_variance, self.size_variance
        )
        return locations, labels


def _xavier_init_(m: nn.Module):
    """Xavier-uniform initialise the weight of `m` if it is a Conv2d.

    Other module types, and the Conv2d bias, are left untouched.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)