# D-FINE: src/nn/backbone/timm_model.py
# Uploaded by developer0hye ("Upload 76 files"), commit e85fecb.
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055#0583
"""
import torch
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from ...core import register
from .utils import IntermediateLayerGetter
@register()
class TimmModel(torch.nn.Module):
    """Backbone wrapper exposing intermediate feature maps of a timm model.

    The network is built with ``timm.create_model`` and wrapped in
    ``IntermediateLayerGetter`` together with ``return_layers``; ``forward``
    returns whatever that getter produces.

    Args:
        name: timm model name passed to ``timm.create_model`` (e.g. ``"resnet34"``).
        return_layers: feature module names to return; each must appear in
            ``model.feature_info.module_name()``.
        pretrained: forwarded to ``timm.create_model``.
        exportable: forwarded to ``timm.create_model``.
        features_only: forwarded to ``timm.create_model``.
        **kwargs: any further ``timm.create_model`` keyword arguments.

    Attributes:
        strides: ``feature_info.reduction()`` entry for each item of
            ``return_layers``, in ``return_layers`` order.
        channels: ``feature_info.channels()`` entry for each item of
            ``return_layers``, in ``return_layers`` order.
        return_idx: position of each item of ``return_layers`` within
            ``feature_info.module_name()``.
        return_layers: the ``return_layers`` argument as given.

    Raises:
        AssertionError: if ``return_layers`` is not a subset of the model's
            feature module names.
    """

    def __init__(
        self, name, return_layers, pretrained=False, exportable=True, features_only=True, **kwargs
    ) -> None:
        super().__init__()
        # Imported inside __init__, so importing this module does not itself require timm.
        import timm

        model = timm.create_model(
            name,
            pretrained=pretrained,
            exportable=exportable,
            features_only=features_only,
            **kwargs,
        )

        feature_info = model.feature_info
        module_names = feature_info.module_name()

        # Explicit raise rather than `assert`: asserts are stripped under
        # `python -O`, which would let an unknown layer name slip through to the
        # `.index()` lookup below. AssertionError is kept so callers see the
        # same exception type as before.
        if not set(return_layers).issubset(module_names):
            raise AssertionError(f"return_layers should be a subset of {module_names}")

        self.model = IntermediateLayerGetter(model, return_layers)

        # NOTE(review): strides/channels are listed in `return_layers` order;
        # confirm IntermediateLayerGetter yields outputs in that same order when
        # callers pass layers out of network order.
        return_idx = [module_names.index(layer) for layer in return_layers]
        reductions = feature_info.reduction()
        all_channels = feature_info.channels()
        self.strides = [reductions[i] for i in return_idx]
        self.channels = [all_channels[i] for i in return_idx]
        self.return_idx = return_idx
        self.return_layers = return_layers

    def forward(self, x: torch.Tensor):
        """Run the wrapped backbone on ``x``.

        Args:
            x: input image batch (presumably ``(N, 3, H, W)`` — TODO confirm).

        Returns:
            The outputs of ``IntermediateLayerGetter`` for the configured
            ``return_layers``, unmodified.
        """
        return self.model(x)
if __name__ == "__main__":
    # Smoke test: build a ResNet-34 backbone that returns two stages and
    # print the shape of every returned feature map.
    backbone = TimmModel(name="resnet34", return_layers=["layer2", "layer3"])
    dummy_batch = torch.rand(1, 3, 640, 640)
    for feature_map in backbone(dummy_batch):
        print(feature_map.shape)

    # Example YAML config entry for this backbone:
    #
    #   model:
    #     type: TimmModel
    #     name: resnet34
    #     return_layers: ['layer2', 'layer4']