# NeMo/tests/core/test_serialization.py
# camenduru's picture
# thanks to NVIDIA ❤
# 7934b29
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from omegaconf import DictConfig
from nemo.collections.asr.models import EncDecCTCModel
from nemo.collections.asr.modules import SpectrogramAugmentation
from nemo.core.classes.common import Serialization
def get_class_path(cls):
    """Return the dotted path ``<module>.<name>`` for the given class."""
    module_name = cls.__module__
    class_name = cls.__name__
    return "{}.{}".format(module_name, class_name)
class MockSerializationImpl(Serialization):
    """Minimal ``Serialization`` subclass used as an instantiation target in the tests below.

    Keeps the config it was constructed with in ``cfg`` and stores the name of
    the concrete class that was instantiated in ``value``.
    """

    def __init__(self, cfg: DictConfig):
        self.cfg = cfg
        # Name of the runtime class (not necessarily this one), so tests can
        # tell which class was actually built.
        concrete_name = type(self).__name__
        self.value = concrete_name
class MockSerializationImplV2(MockSerializationImpl):
    """Subclass of ``MockSerializationImpl`` that adds no behavior of its own."""
class TestSerialization:
    """Unit tests for ``from_config_dict`` / ``to_config_dict`` on ``Serialization`` and its subclasses."""

    @pytest.mark.unit
    def test_from_config_dict_with_cls(self):
        """Check that a config carrying a ``cls`` class path can be instantiated.

        Calling ``Serialization.from_config_dict`` directly is enough to create the object.
        """
        augmentation_params = {'rect_freq': 50, 'rect_masks': 5, 'rect_time': 120}
        cfg = DictConfig(
            {'cls': 'nemo.collections.asr.modules.SpectrogramAugmentation', 'params': augmentation_params}
        )

        instance = Serialization.from_config_dict(config=cfg)

        assert isinstance(instance, SpectrogramAugmentation)

    @pytest.mark.unit
    def test_from_config_dict_without_cls(self):
        """Check that a config with no top-level ``cls`` class path can be instantiated.

        IMPORTANT: in this case the correct class type must be the one calling
        ``from_config_dict``. This should work for Models.
        """
        preprocessor_cfg = DictConfig(
            {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': {}}
        )

        # Single block listed under the encoder's 'jasper' key.
        jasper_block = {
            'filters': 1024,
            'repeat': 1,
            'kernel': [1],
            'stride': [1],
            'dilation': [1],
            'dropout': 0.0,
            'residual': False,
            'separable': True,
            'se': True,
            'se_context_size': -1,
        }
        encoder_cfg = DictConfig(
            {
                'cls': 'nemo.collections.asr.modules.ConvASREncoder',
                'params': {'feat_in': 64, 'activation': 'relu', 'conv_mask': True, 'jasper': [jasper_block]},
            }
        )

        # 28 entries: space, the letters a-z, and the apostrophe.
        vocabulary = [
            ' ',
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
            'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
            "'",
        ]
        decoder_cfg = DictConfig(
            {
                'cls': 'nemo.collections.asr.modules.ConvASRDecoder',
                'params': {'feat_in': 1024, 'num_classes': 28, 'vocabulary': vocabulary},
            }
        )

        model_cfg = DictConfig({'preprocessor': preprocessor_cfg, 'encoder': encoder_cfg, 'decoder': decoder_cfg})

        model = EncDecCTCModel.from_config_dict(config=model_cfg)

        assert isinstance(model, EncDecCTCModel)

    @pytest.mark.unit
    def test_config_updated(self):
        """Check that ``to_config_dict`` returns a config that differs from a ``cls``/``params`` input.

        The returned config must not contain ``params`` or ``cls`` and must contain ``_target_``.
        """
        augmentation_params = {'rect_freq': 50, 'rect_masks': 5, 'rect_time': 120}
        cfg = DictConfig(
            {'cls': 'nemo.collections.asr.modules.SpectrogramAugmentation', 'params': augmentation_params}
        )

        instance = Serialization.from_config_dict(config=cfg)
        exported_cfg = instance.to_config_dict()

        assert cfg != exported_cfg
        assert 'params' not in exported_cfg
        assert 'cls' not in exported_cfg
        assert '_target_' in exported_cfg

    @pytest.mark.unit
    def test_base_class_instantiation(self):
        """Instantiate through the base class when the config targets the V2 implementation."""
        # Target class is the V2 impl; the calling class is Serialization (base class).
        target_path = get_class_path(MockSerializationImplV2)
        cfg = DictConfig({'target': target_path})

        instance = Serialization.from_config_dict(config=cfg)
        exported_cfg = instance.to_config_dict()

        assert cfg == exported_cfg
        assert isinstance(instance, MockSerializationImplV2)
        assert instance.value == "MockSerializationImplV2"

    @pytest.mark.unit
    def test_self_class_instantiation(self):
        """Instantiate through the V1 implementation when the config targets that same class."""
        # Target class is the V1 impl; the calling class is V1 as well (same class).
        target_path = get_class_path(MockSerializationImpl)
        cfg = DictConfig({'target': target_path})

        # Serialization is the base class providing from_config_dict.
        instance = MockSerializationImpl.from_config_dict(config=cfg)
        exported_cfg = instance.to_config_dict()

        assert cfg == exported_cfg
        assert isinstance(instance, MockSerializationImpl)
        assert instance.value == "MockSerializationImpl"

    @pytest.mark.unit
    def test_sub_class_instantiation(self):
        """Instantiate through the V2 subclass when the config targets the V1 implementation.

        The resulting object is expected to be an instance of the calling (V2) class.
        """
        # Target class is the V1 impl; the calling class is V2 (sub class).
        target_path = get_class_path(MockSerializationImpl)
        cfg = DictConfig({'target': target_path})

        # Serialization is the base class providing from_config_dict.
        instance = MockSerializationImplV2.from_config_dict(config=cfg)
        exported_cfg = instance.to_config_dict()

        assert cfg == exported_cfg
        assert isinstance(instance, MockSerializationImplV2)
        assert instance.value == "MockSerializationImplV2"