File size: 2,127 Bytes
e85fecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.



https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055#0583

"""

import torch
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

from ...core import register
from .utils import IntermediateLayerGetter


@register()
class TimmModel(torch.nn.Module):
    """Backbone wrapper around a model built with ``timm.create_model``.

    The timm model is wrapped in ``IntermediateLayerGetter`` together with the
    requested ``return_layers``; ``forward`` returns whatever that wrapper
    produces.

    Attributes set by the constructor:
        strides: ``feature_info.reduction()`` entry of each requested layer.
        channels: ``feature_info.channels()`` entry of each requested layer.
        return_idx: Index of each requested layer in
            ``feature_info.module_name()``.
        return_layers: The requested layer names, exactly as passed in.

    ``strides``, ``channels`` and ``return_idx`` follow the order of
    ``return_layers``.
    """

    def __init__(
        self, name, return_layers, pretrained=False, exportable=True, features_only=True, **kwargs
    ) -> None:
        """Create the timm model and record metadata for the requested layers.

        Args:
            name: Model name forwarded to ``timm.create_model``.
            return_layers: Names of the layers whose outputs are wanted. Every
                entry must appear in ``model.feature_info.module_name()``.
            pretrained: Forwarded to ``timm.create_model``.
            exportable: Forwarded to ``timm.create_model``.
            features_only: Forwarded to ``timm.create_model``.
            **kwargs: Extra keyword arguments forwarded to ``timm.create_model``.

        Raises:
            ValueError: If ``return_layers`` contains a name that is not listed
                in ``model.feature_info.module_name()``.
        """
        super().__init__()

        # Imported here rather than at module level, so importing this module
        # does not require timm to be installed.
        import timm

        model = timm.create_model(
            name,
            pretrained=pretrained,
            exportable=exportable,
            features_only=features_only,
            **kwargs,
        )

        # NOTE(review): everything below relies on ``model.feature_info``
        # exposing module_name()/reduction()/channels(); presumably that needs
        # the default ``features_only=True`` -- confirm before passing False.
        feature_info = model.feature_info
        module_names = feature_info.module_name()

        # Explicit raise instead of ``assert``: assertions are removed when
        # Python runs with -O, which would let a bad config slip through here.
        if not set(return_layers).issubset(module_names):
            raise ValueError(f"return_layers should be a subset of {module_names}")

        self.model = IntermediateLayerGetter(model, return_layers)

        reductions = feature_info.reduction()
        channels = feature_info.channels()
        return_idx = [module_names.index(layer) for layer in return_layers]

        self.strides = [reductions[i] for i in return_idx]
        self.channels = [channels[i] for i in return_idx]
        self.return_idx = return_idx
        self.return_layers = return_layers

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the wrapped model and return its output unchanged.

        Presumably one tensor per requested layer -- confirm against
        ``IntermediateLayerGetter``.
        """
        return self.model(x)


if __name__ == "__main__":
    # Smoke test: build a ResNet-34 backbone, push one random image-sized
    # tensor through it and print the shape of every returned output.
    backbone = TimmModel(name="resnet34", return_layers=["layer2", "layer3"])
    dummy_input = torch.rand(1, 3, 640, 640)
    feature_maps = backbone(dummy_input)

    for feature_map in feature_maps:
        print(feature_map.shape)

    # The string below is a no-op expression kept as an example of the
    # config entry that selects this backbone.
    """

    model:

        type: TimmModel

        name: resnet34

        return_layers: ['layer2', 'layer4']

    """