Spaces:
Sleeping
Sleeping
import unittest | |
import numpy as np | |
import copy | |
from stnn.nn.stnn import build_stnn | |
class TestBuildSTNN(unittest.TestCase):
    """Input-validation tests for ``build_stnn``.

    Each test starts from a known-good baseline configuration
    (``self.saved_config``) and perturbs exactly one entry, then asserts that
    ``build_stnn`` rejects the result with the expected exception type.
    Every case works on its own deep copy of the baseline, so a failing
    sub-test (or any mutation ``build_stnn`` makes to its argument) cannot
    leak into the cases that follow.
    """

    def setUp(self):
        # Baseline configuration that the tests perturb one key at a time.
        self.config = {
            'K': 1,
            'nx1': 8,
            'nx2': 8,
            'nx3': 8,
            'd': 8,
            'W': 3,
            'shape1': [1, 2, 3],
            'shape2': [2, 2, 2],
            'ranks': [1, 2, 2, 1],
        }
        # Pristine copy used as the source for every per-case config.
        self.saved_config = copy.deepcopy(self.config)
        self._required_keys = [
            'nx1', 'nx2', 'nx3', 'K', 'd', 'shape1', 'shape2', 'ranks', 'W',
        ]
        self._optional_keys = ['use_regularization', 'regularization_strength']

    def _config_with(self, key, value):
        """Return a fresh deep copy of the baseline config with ``key`` set to ``value``."""
        config = copy.deepcopy(self.saved_config)
        config[key] = value
        return config

    def _config_without(self, key):
        """Return a fresh deep copy of the baseline config with ``key`` removed."""
        config = copy.deepcopy(self.saved_config)
        del config[key]
        return config

    def test_missing_keys(self):
        """Omitting any required key makes ``build_stnn`` raise KeyError."""
        for key in self._required_keys:
            with self.subTest(key=key):
                with self.assertRaises(KeyError):
                    build_stnn(self._config_without(key))

    def test_invalid_values(self):
        """Non-integer values raise TypeError; -1 and an odd ``nx3`` raise ValueError."""
        for key in ['K', 'd', 'W', 'nx1', 'nx2', 'nx3']:
            for value in [1.5, 'a', None, np.nan]:
                with self.subTest(key=key, value=value):
                    with self.assertRaises(TypeError):
                        build_stnn(self._config_with(key, value))
            with self.subTest(key=key, value=-1):
                with self.assertRaises(ValueError):
                    build_stnn(self._config_with(key, -1))
        # nx3 = 7 is not divisible by 2, which build_stnn is expected to reject.
        # Only nx3 is changed here; every other key keeps its baseline value.
        with self.subTest(key='nx3', value=7):
            with self.assertRaises(ValueError):
                build_stnn(self._config_with('nx3', 7))

    def test_positive_values(self):
        """A value of 0 for any of these size parameters raises ValueError."""
        for key in ['K', 'nx1', 'nx2', 'nx3', 'd']:
            with self.subTest(key=key):
                with self.assertRaises(ValueError):
                    build_stnn(self._config_with(key, 0))
if __name__ == '__main__':
    # Allow running this module directly as a script, in addition to
    # discovery by a test runner.
    unittest.main()