File size: 5,453 Bytes
d68c650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import tempfile
import unittest

import h5py
import numpy as np

from stnn.data.preprocessing import get_data_from_file, load_data, load_training_data


class TestGetDataFromFile(unittest.TestCase):
	"""Shape and argument-validation tests for the ``stnn.data.preprocessing`` loaders.

	Each test runs against small HDF5 fixtures that ``setUp`` writes to temporary
	files: two complete files and one that lacks the 'ibf' and 'obf' datasets.
	The fixture values are random; only shapes and raised exceptions are checked.
	"""

	def _make_h5_file(self, include_boundary_data = True):
		"""Write a random HDF5 fixture to a new temporary file.

		Args:
			include_boundary_data: When False, the 'ibf' and 'obf' datasets are
				left out of the file.

		Returns:
			The (already closed) ``NamedTemporaryFile`` object; its ``name``
			attribute is the path of the fixture on disk.
		"""
		handle = tempfile.NamedTemporaryFile(delete = False)
		# Release our own handle before h5py reopens the same path by name.
		handle.close()
		# delete = False means nothing removes the file automatically, so register
		# the removal here; cleanups also run when a later setUp step fails.
		self.addCleanup(os.remove, handle.name)

		shapes = {
			'ell': (self.Nsamples,),
			'a1': (self.Nsamples,),
			'a2': (self.Nsamples,),
			'rho': (self.Nsamples, self.nx1, self.nx2),
		}
		if include_boundary_data:
			boundary_shape = (self.Nsamples, self.nx2, self.nx3 // 2)
			shapes['ibf'] = boundary_shape
			shapes['obf'] = boundary_shape

		with h5py.File(handle.name, 'w') as f:
			for name, shape in shapes.items():
				f.create_dataset(name, data = np.random.rand(*shape))
		return handle

	def setUp(self):
		"""Create two complete fixture files and one without 'ibf'/'obf'."""
		self.nx1, self.nx2, self.nx3 = 30, 20, 16
		self.Nsamples = 10
		self.temp_file = self._make_h5_file()
		self.temp_file1 = self._make_h5_file()
		self.bad_file = self._make_h5_file(include_boundary_data = False)

	def test_missing_datasets(self):
		"""A file without the 'ibf' and 'obf' datasets is rejected with ValueError."""
		with self.assertRaises(ValueError):
			get_data_from_file(self.bad_file.name, self.nx2, self.nx3)

	def test_data_extraction_shapes(self):
		"""The first five returned arrays have the expected shapes."""
		result = get_data_from_file(self.temp_file.name, self.nx2, self.nx3)
		expected_shapes = [
			(self.Nsamples,),
			(self.Nsamples,),
			(self.Nsamples,),
			(self.Nsamples, 2 * self.nx2, self.nx3 // 2),
			(self.Nsamples, self.nx1, self.nx2),
		]
		for index, expected in enumerate(expected_shapes):
			with self.subTest(index = index):
				self.assertEqual(result[index].shape, expected)

	def test_nrange_parameter(self):
		"""``Nrange = (start, stop)`` limits every returned array to ``stop - start`` samples."""
		Nrange = (2, 5)
		result = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange = Nrange)
		expected_size = Nrange[1] - Nrange[0]
		self.assertEqual(result[0].shape, (expected_size,))
		# Check the sample axis of all five arrays, not only the first one.
		for index in range(5):
			with self.subTest(index = index):
				self.assertEqual(result[index].shape[0], expected_size)

	def test_list_input(self):
		"""Passing a list of file names to the single-file loader raises TypeError."""
		file_list = [self.temp_file.name, self.temp_file1.name]
		Nrange_list = [(0, -1), (0, -1)]
		with self.assertRaises(TypeError):
			# noinspection PyTypeChecker
			_ = get_data_from_file(file_list, self.nx2, self.nx3, Nrange = Nrange_list)

	def test_invalid_Nrange(self):
		"""Malformed ``Nrange`` values raise TypeError."""
		Nrange_list = [(0, -1), (0, -1)]
		with self.assertRaises(TypeError):
			_ = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange = Nrange_list)

		# NOTE(review): the original list also contained ``(1)``, which is the int 1
		# rather than a 1-tuple and therefore duplicated the ``1`` case.
		for Nrange in [(0, 1, 1), 1, (1.5, 3), (3, 1.5), (1.5, 1.5), 'x']:
			with self.subTest(Nrange = Nrange):
				with self.assertRaises(TypeError):
					_ = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange = Nrange)

			# Build the list variant outside assertRaises: list(1) raises TypeError
			# by itself, which would satisfy the assertion without the loader ever
			# being called.
			try:
				Nrange_as_list = list(Nrange)
			except TypeError:
				continue
			with self.subTest(Nrange = Nrange_as_list):
				with self.assertRaises(TypeError):
					_ = get_data_from_file(
						self.temp_file.name, self.nx2, self.nx3, Nrange = Nrange_as_list
					)

	def test_good_data_load(self):
		"""``load_data`` and ``load_training_data`` return arrays of the expected shapes."""
		files = [self.temp_file.name, self.temp_file1.name]
		Nrange_list = [(0, None), (0, self.Nsamples)]
		ell1, ell2, a1, a2 = 0.1, 2.0, 1.0, 5.0
		Ntotal = 2 * self.Nsamples
		bf_shape = (2 * self.nx2, self.nx3 // 2)
		rho_shape = (self.nx1, self.nx2)

		params, bf, rho = load_data(
			files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = Nrange_list
		)
		self.assertEqual(params.shape, (Ntotal, 3))
		self.assertEqual(bf.shape, (Ntotal,) + bf_shape)
		self.assertEqual(rho.shape, (Ntotal,) + rho_shape)

		test_size = 0.3
		(params_train, bf_train, rho_train,
		 params_test, bf_test, rho_test) = load_training_data(
			files, self.nx2, self.nx3, ell1, ell2, a1, a2,
			test_size = test_size, Nrange_list = Nrange_list
		)
		Ntest = int(test_size * 2 * self.Nsamples)
		Ntrain = Ntotal - Ntest
		self.assertEqual(params_train.shape, (Ntrain, 3))
		self.assertEqual(bf_train.shape, (Ntrain,) + bf_shape)
		self.assertEqual(rho_train.shape, (Ntrain,) + rho_shape)
		self.assertEqual(params_test.shape, (Ntest, 3))
		self.assertEqual(bf_test.shape, (Ntest,) + bf_shape)
		self.assertEqual(rho_test.shape, (Ntest,) + rho_shape)

	def test_bad_data_load(self):
		"""Bad ``Nrange_list`` values raise TypeError; bad ``test_size`` values raise ValueError."""
		files = [self.temp_file.name, self.temp_file1.name]
		Nrange_list = (0, -1)
		ell1, ell2, a1, a2 = 0.1, 2.0, 1.0, 5.0
		# A single (start, stop) pair, as a tuple or as a flat list, is not a valid
		# per-file list of ranges.
		for bad_ranges in (Nrange_list, list(Nrange_list)):
			with self.subTest(Nrange_list = bad_ranges):
				with self.assertRaises(TypeError):
					_ = load_data(
						files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = bad_ranges
					)

		Nrange_list = [(0, -1), (0, -1)]
		for test_size in [-1, 0.0, 1.5]:
			with self.subTest(test_size = test_size):
				with self.assertRaises(ValueError):
					_ = load_training_data(
						files, self.nx2, self.nx3, ell1, ell2, a1, a2,
						test_size = test_size, Nrange_list = Nrange_list
					)


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
	unittest.main()