File size: 2,551 Bytes
d68c650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import tensorflow as tf
from tensorflow.keras.layers import Input, Multiply, Add
from tensorflow.keras.models import Model

from .stnn_layers import TTL, SoftmaxEmbeddingLayer


def build_stnn(config):
	"""
	Constructs a Stacked Tensorial Neural Network (STNN) as a TensorFlow model based on
	the provided configuration dictionary.

	The model takes two inputs, ``[params_input, input_tensor]``:
	a length-3 parameter vector (ell, a1, a2), and a grid tensor of shape
	``(2 * nx2, nx3 // 2, 1)``. A SoftmaxEmbeddingLayer maps the parameter vector
	to one stacking weight per tensor network. The model output is the sum of the
	K TTL outputs, each multiplied by its weight.

	Args:
		config (dict): Configuration dictionary for the STNN model. Must contain the following entries:
						- 'K' (int): Number of tensor networks to be stacked.
						- 'd' (int): Number of dense layers in the model's SoftmaxEmbeddingLayer.
						- 'nx1', 'nx2', 'nx3' (int): Dimensions of the finite-difference grid.
						  Only 'nx2' and 'nx3' are used directly here; 'nx1' is validated,
						  and the whole config is passed on to every TTL.
						- 'shape1', 'shape2', 'ranks', 'W': Entries required by the TTL class.

	Returns:
		tf.keras.Model: The constructed STNN model.

	Raises:
		KeyError: If any required entry is missing from config.
		TypeError: If 'nx1', 'nx2', 'nx3', 'K' or 'd' is not an int (bools are rejected).
		ValueError: If 'nx1', 'nx2', 'nx3', 'K' or 'd' is not positive, or if
					config['nx3'] is not divisible by 2.
	"""
	required_keys = ['nx1', 'nx2', 'nx3', 'K', 'd', 'shape1', 'shape2', 'ranks', 'W']
	missing_keys = [key for key in required_keys if key not in config]
	if missing_keys:
		raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")

	int_keys = ['nx1', 'nx2', 'nx3', 'K', 'd']
	for key in int_keys:
		# bool is a subclass of int; reject it explicitly so that e.g. K=True is
		# not silently accepted as K=1.
		if isinstance(config[key], bool) or not isinstance(config[key], int):
			raise TypeError(f"{key} must be an integer.")

	for key in int_keys:
		if config[key] <= 0:
			raise ValueError(f"{key} must be positive.")

	if config['nx3'] % 2 != 0:
		raise ValueError('Config error: nx3 must be divisible by 2.')

	K = config['K']  # Number of tensor networks
	d = config['d']  # Number of dense layers in SoftmaxEmbeddingLayer
	input_shape = (2 * config['nx2'], config['nx3'] // 2, 1)
	input_tensor = Input(shape=input_shape)

	# Map the parameter array (ell, a1, a2) to one stacking weight per tensor network.
	# Two singleton axes are inserted so each per-network slice can broadcast against
	# that network's output.
	# NOTE(review): assumes the embedding output is (batch, K); TODO confirm.
	preprocess_layer = SoftmaxEmbeddingLayer(K, d)
	params_input = Input(shape=(3,))
	stack_weights = preprocess_layer(params_input)[:, tf.newaxis, tf.newaxis, :]

	# Build the tensor networks using the custom keras layer class TTL.
	tensor_networks = [TTL(config) for _ in range(K)]

	# Scale each tensor network's output by its weight from 'preprocess_layer'.
	weighted_outputs = [
		Multiply()([ttl(input_tensor), stack_weights[..., i]])
		for i, ttl in enumerate(tensor_networks)
	]

	# Keras merge layers such as Add may refuse a single-element input list
	# (version dependent), so skip the merge entirely when K == 1.
	if len(weighted_outputs) == 1:
		final_output = weighted_outputs[0]
	else:
		final_output = Add()(weighted_outputs)

	return Model(inputs=[params_input, input_tensor], outputs=final_output)