File size: 1,669 Bytes
b6c45cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Models
# from .conv_tasnet import ConvTasNet
# from .dccrnet import DCCRNet
# from .dcunet import DCUNet
# from .dprnn_tasnet import DPRNNTasNet
# from .sudormrf import SuDORMRFImprovedNet, SuDORMRFNet
from .dptnet import DPTNet
# from .lstm_tasnet import LSTMTasNet
# from .demask import DeMask

# Sharing-related
# from .publisher import save_publishable, upload_publishable

# Names exported by `from <package> import *`. Only "DPTNet" is active; the
# commented-out entries match the imports that are commented out above.
__all__ = [
    # "ConvTasNet",
    # "DPRNNTasNet",
    # "SuDORMRFImprovedNet",
    # "SuDORMRFNet",
    "DPTNet",
    # "LSTMTasNet",
    # "DeMask",
    # "DCUNet",
    # "DCCRNet",
    # "save_publishable",
    # "upload_publishable",
]


def register_model(custom_model):
    """Register a custom model, gettable with `models.get`.

    The model is stored in this module's globals under its ``__name__``.

    Args:
        custom_model: Custom model to register. Must expose ``__name__``.

    Raises:
        ValueError: If a module-level name equal to the model's name,
            ignoring case, already exists.
    """
    name = custom_model.__name__
    # `get` resolves names case-insensitively, so collisions have to be
    # detected case-insensitively too. Otherwise a model called "dptnet"
    # could be registered next to "DPTNet" and make lookups ambiguous.
    taken = {existing.lower() for existing in globals()}
    if name.lower() in taken:
        raise ValueError(f"Model {name} already exists. Choose another name.")
    globals()[name] = custom_model


def get(identifier):
    """Return the module-level object whose name matches a string.

    The match is case-insensitive and is made against every name in this
    module's globals, which includes models added via `register_model`.

    Args:
        identifier (str): the model name.

    Returns:
        The object bound to the matching name (expected to be a
        :class:`torch.nn.Module` subclass).

    Raises:
        ValueError: If `identifier` is not a string, or if no module-level
            name matches it.
    """
    # Guard clause: only string identifiers can be looked up.
    if not isinstance(identifier, str):
        raise ValueError(f"Could not interpret model name : {str(identifier)}")

    # Build a lower-cased view of the module namespace; on a case clash the
    # entry inserted last wins.
    lowered = {}
    for name, obj in globals().items():
        lowered[name.lower()] = obj

    found = lowered.get(identifier.lower())
    if found is None:
        raise ValueError(f"Could not interpret model name : {str(identifier)}")
    return found