File size: 1,150 Bytes
d68c650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import tempfile
import unittest

import numpy as np

from stnn.utils.stats import get_stats


class TestGetStats(unittest.TestCase):
	"""Unit tests for ``get_stats``.

	Each test calls ``get_stats(rho, rho_pred, filename)`` and inspects the
	``.npz`` archive expected at ``filename``. The archive is read back with
	``np.load`` and is expected to expose the keys ``'max_loss'``,
	``'avg_loss'`` and ``'N'``.
	"""

	def setUp(self):
		"""Build identical target/prediction arrays and a scratch output path."""
		# Target and prediction are equal, so the loss statistics checked in
		# test_correctness are expected to be zero.
		self.rho = np.array([[1, 2], [3, 4]])
		self.rho_pred = np.array([[1, 2], [3, 4]])

		# Write into a private temporary directory instead of the current
		# working directory: a fixed name in the CWD could clobber (and then
		# delete, in tearDown) a user's file and collides between parallel runs.
		self._tmpdir = tempfile.TemporaryDirectory()
		# addCleanup runs even if a later setUp step or the test itself fails.
		self.addCleanup(self._tmpdir.cleanup)
		self.filename = os.path.join(self._tmpdir.name, 'test_stats.npz')

	def test_correctness(self):
		"""Identical inputs give zero losses and ``N`` equal to the first dimension."""
		get_stats(self.rho, self.rho_pred, self.filename)
		with np.load(self.filename) as data:
			self.assertAlmostEqual(data['max_loss'], 0.0, places=5)
			# Tolerance-based comparison, consistent with the max_loss check;
			# exact float equality is needlessly brittle.
			self.assertAlmostEqual(data['avg_loss'], 0.0, places=5)
			self.assertEqual(data['N'], self.rho.shape[0])

	def test_file_creation(self):
		"""``get_stats`` creates a file at the requested path."""
		get_stats(self.rho, self.rho_pred, self.filename)
		self.assertTrue(os.path.exists(self.filename))

	def test_file_content(self):
		"""The written archive contains all expected statistic keys."""
		get_stats(self.rho, self.rho_pred, self.filename)
		with np.load(self.filename) as data:
			for key in ('max_loss', 'avg_loss', 'N'):
				self.assertIn(key, data)

	def test_invalid_input(self):
		"""Inputs of different rank are expected to raise ``ValueError``."""
		# NOTE(review): no filename is passed here, so this assumes get_stats
		# has a default for that argument and rejects mismatched shapes
		# explicitly (NumPy broadcasting alone would accept (2,) vs (2, 2)).
		# Confirm against the get_stats implementation.
		with self.assertRaises(ValueError):
			get_stats(np.array([1, 2]), np.array([[1, 2], [3, 4]]))

	def tearDown(self):
		"""Remove the output file if a test created it."""
		# The temporary directory registered in setUp is removed afterwards by
		# its cleanup hook; deleting the file here keeps teardown explicit.
		if os.path.exists(self.filename):
			os.remove(self.filename)

# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
	unittest.main()