# Page-header residue from the hosting site's file view (likely Hugging Face):
# uploaded by DeepLearning101 in "Upload 16 files", commit b6c45cb (verified), 1.67 kB.
# Models
# from .conv_tasnet import ConvTasNet
# from .dccrnet import DCCRNet
# from .dcunet import DCUNet
# from .dprnn_tasnet import DPRNNTasNet
# from .sudormrf import SuDORMRFImprovedNet, SuDORMRFNet
from .dptnet import DPTNet
# from .lstm_tasnet import LSTMTasNet
# from .demask import DeMask
# Sharing-related
# from .publisher import save_publishable, upload_publishable
# Names exported by ``from <this module> import *``. Only DPTNet is active.
# The other models (ConvTasNet, DPRNNTasNet, SuDORMRFImprovedNet, SuDORMRFNet,
# LSTMTasNet, DeMask, DCUNet, DCCRNet) and the sharing helpers
# (save_publishable, upload_publishable) are not exported because their
# imports at the top of this file are commented out.
__all__ = ["DPTNet"]
def register_model(custom_model):
    """Register a custom model, gettable with `models.get`.

    The class is stored in this module's global namespace under its
    ``__name__``. ``get`` resolves names case-insensitively, so a new name
    is rejected when it matches ANY existing module-level name ignoring
    case. Otherwise two entries would collide on the same lowercase key
    and ``get`` would silently return only one of them.

    Args:
        custom_model: Custom model to register (an object with a
            ``__name__`` attribute, normally a model class).

    Raises:
        ValueError: If a module-level name equal to
            ``custom_model.__name__`` (ignoring case) already exists.
    """
    name = custom_model.__name__
    # Lowercase every existing name so that e.g. "MYNET" clashes with "MyNet",
    # mirroring the case-insensitive lookup performed by ``get``.
    taken = {existing.lower() for existing in globals()}
    if name.lower() in taken:
        raise ValueError(f"Model {name} already exists. Choose another name.")
    globals()[name] = custom_model
def get(identifier):
    """Look up a model by name, ignoring case.

    The name is matched against every module-level name of this module:
    the imported models plus anything added through ``register_model``.

    Args:
        identifier (str): the model name, e.g. ``"dptnet"`` or ``"DPTNet"``.

    Returns:
        The object bound to the matching name, normally a
        :class:`torch.nn.Module` subclass.

    Raises:
        ValueError: if ``identifier`` is not a string, or no module-level
            name matches it.
    """
    if not isinstance(identifier, str):
        raise ValueError(f"Could not interpret model name : {str(identifier)}")

    # Map lowercase name -> object. When several globals differ only in case,
    # the one inserted later overwrites the earlier entry.
    lowered = dict((name.lower(), obj) for name, obj in globals().items())
    found = lowered.get(identifier.lower())
    if found is None:
        raise ValueError(f"Could not interpret model name : {str(identifier)}")
    return found