File size: 1,669 Bytes
b6c45cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
# Models
# from .conv_tasnet import ConvTasNet
# from .dccrnet import DCCRNet
# from .dcunet import DCUNet
# from .dprnn_tasnet import DPRNNTasNet
# from .sudormrf import SuDORMRFImprovedNet, SuDORMRFNet
from .dptnet import DPTNet
# from .lstm_tasnet import LSTMTasNet
# from .demask import DeMask
# Sharing-related
# from .publisher import save_publishable, upload_publishable
# Names exported by ``from <package> import *``. Only DPTNet is currently
# imported above; the remaining models and the publishing helpers have their
# imports commented out, so they are intentionally left out of the export
# list. Re-add each one here together with its import when re-enabling it:
# ConvTasNet, DPRNNTasNet, SuDORMRFImprovedNet, SuDORMRFNet, LSTMTasNet,
# DeMask, DCUNet, DCCRNet, save_publishable, upload_publishable.
__all__ = ["DPTNet"]
def register_model(custom_model):
    """Register a custom model, gettable with `models.get`.

    The model is stored in this module's namespace under its ``__name__``.
    Because :func:`get` resolves names case-insensitively, registration is
    refused when any existing module-level name matches ``__name__``
    ignoring case; otherwise two entries would share one lookup key and one
    of them would be silently shadowed.

    Args:
        custom_model: Custom model to register. Must expose a ``__name__``
            attribute (e.g. a class).

    Raises:
        ValueError: If a module-level name equal to ``custom_model.__name__``
            (compared case-insensitively) already exists.
    """
    name = custom_model.__name__
    # Compare against the lowercased form of every global so the check
    # matches the case-insensitive lookup done by `get` (this also covers the
    # exact-name and all-lowercase cases the old check handled).
    taken = {key.lower() for key in globals()}
    if name.lower() in taken:
        raise ValueError(f"Model {name} already exists. Choose another name.")
    globals()[name] = custom_model
def get(identifier):
    """Return a model class from its name (case-insensitive).

    Only classes bound at module level (imported or added through
    ``register_model``) are candidates. Other module globals, such as this
    module's helper functions or dunder attributes, are never returned.

    Args:
        identifier (str): the model name.

    Returns:
        :class:`torch.nn.Module`

    Raises:
        ValueError: If ``identifier`` is not a string or does not name a
            known model class.
    """
    if not isinstance(identifier, str):
        raise ValueError(f"Could not interpret model name : {str(identifier)}")
    # Index classes only, keyed by lowercased name for case-insensitive
    # lookup. Filtering out non-classes also keeps e.g. a lowercase submodule
    # binding from colliding with a model class of the same name.
    models = {
        name.lower(): obj
        for name, obj in globals().items()
        if isinstance(obj, type)
    }
    cls = models.get(identifier.lower())
    if cls is None:
        raise ValueError(f"Could not interpret model name : {str(identifier)}")
    return cls
|