File size: 1,646 Bytes
d68c650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import unittest
import numpy as np
import copy
from stnn.nn.stnn import build_stnn


class TestBuildSTNN(unittest.TestCase):
	"""Input-validation tests for ``build_stnn``.

	Every case works on its own deep copy of the baseline configuration built in
	``setUp``. A failing case therefore cannot leave a deleted key or a bad value
	behind that would make later cases pass or fail for the wrong reason.
	"""

	def setUp(self):
		# Baseline configuration. Presumably valid for build_stnn (no test here
		# builds it successfully) -- TODO confirm.
		self.config = {
			'K': 1,
			'nx1': 8,
			'nx2': 8,
			'nx3': 8,
			'd': 8,
			'W': 3,
			'shape1': [1, 2, 3],
			'shape2': [2, 2, 2],
			'ranks': [1, 2, 2, 1],
		}
		# Pristine snapshot that every per-case config is copied from.
		self.saved_config = copy.deepcopy(self.config)
		self._required_keys = ['nx1', 'nx2', 'nx3', 'K', 'd', 'shape1', 'shape2', 'ranks', 'W']
		# Not exercised by the tests below; kept for reference.
		self._optional_keys = ['use_regularization', 'regularization_strength']

	def _fresh_config(self):
		"""Return an independent deep copy of the baseline configuration."""
		return copy.deepcopy(self.saved_config)

	def test_missing_keys(self):
		"""Omitting any required key raises KeyError."""
		for key in self._required_keys:
			with self.subTest(key=key):
				config = self._fresh_config()
				del config[key]
				with self.assertRaises(KeyError):
					build_stnn(config)

	def test_invalid_values(self):
		"""Non-integer values raise TypeError; negative or ill-sized ones raise ValueError."""
		for key in ['K', 'd', 'W', 'nx1', 'nx2', 'nx3']:
			for value in [1.5, 'a', None, np.nan]:
				with self.subTest(key=key, value=value):
					config = self._fresh_config()
					config[key] = value
					with self.assertRaises(TypeError):
						build_stnn(config)
			with self.subTest(key=key, value=-1):
				config = self._fresh_config()
				config[key] = -1
				with self.assertRaises(ValueError):
					build_stnn(config)

		# nx3 = 7 is not divisible by 2, which build_stnn is expected to reject.
		# Override 'nx3' explicitly rather than relying on the leftover loop variable.
		with self.subTest(key='nx3', value=7):
			config = self._fresh_config()
			config['nx3'] = 7
			with self.assertRaises(ValueError):
				build_stnn(config)

	def test_positive_values(self):
		"""Zero raises ValueError for keys that must be strictly positive."""
		for key in ['K', 'nx1', 'nx2', 'nx3', 'd']:
			with self.subTest(key=key):
				config = self._fresh_config()
				config[key] = 0
				with self.assertRaises(ValueError):
					build_stnn(config)

def _run_tests():
	"""Run this module's test cases through the unittest command-line runner."""
	unittest.main()


# Only run the suite when this file is executed directly, not when it is imported.
if __name__ == '__main__':
	_run_tests()