Spaces:
Running
on
Zero
Running
on
Zero
"""
Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py

Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
from collections import OrderedDict
from typing import Dict, List

import torch.nn as nn
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Only the children up to and including the last requested one are kept;
    any later children are not registered and never run.

    Args:
        model: model whose direct children are chained in registration order.
        return_layers: names of direct children whose outputs are collected.
            Entries are converted with ``str``, so integer indices of an
            ``nn.Sequential`` are accepted. Any iterable is accepted; it is
            consumed exactly once.

    Raises:
        ValueError: if any requested name is not a direct child of ``model``.
    """

    _version = 3

    def __init__(self, model: nn.Module, return_layers: List[str]) -> None:
        # Materialize once, because the argument may be a one-shot iterator.
        # Normalize to str before validating, because named_children() yields
        # string names.
        requested = [str(k) for k in return_layers]
        available = [name for name, _ in model.named_children()]
        if not set(requested).issubset(available):
            raise ValueError(
                "return_layers are not present in model. {}".format(available)
            )

        # Copy children in order until every requested name has been seen.
        # Later children cannot contribute to the output, so they are dropped.
        pending = set(requested)
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            pending.discard(name)
            if not pending:
                break

        super().__init__(layers)
        # Store a private, normalized copy so that later mutation of the
        # caller's container cannot change forward().
        self.return_layers = requested

    def forward(self, x):
        """Feed ``x`` through the kept children, chaining each output.

        Returns:
            A list holding the output of every child named in
            ``return_layers``. The list follows the order of the children in
            the model, not the order of ``return_layers``.
        """
        outputs = []
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                outputs.append(x)
        return outputs