# Provenance (repository page header): caleb2, "initial commit", d68c650
import tensorflow as tf
from tensorflow.keras.layers import Input, Multiply, Add
from tensorflow.keras.models import Model
from .stnn_layers import TTL, SoftmaxEmbeddingLayer
def _validate_stnn_config(config):
    """
    Checks that `config` holds every entry build_stnn reads before any layer is created.

    Args:
        config (dict): Configuration dictionary passed to build_stnn.

    Raises:
        KeyError: If any of 'nx1', 'nx2', 'nx3', 'K', 'd', 'shape1', 'shape2', 'ranks', 'W'
            is missing.
        TypeError: If any of 'nx1', 'nx2', 'nx3', 'K', 'd' is not an int (bools are rejected).
        ValueError: If any of 'nx1', 'nx2', 'nx3', 'K', 'd' is not positive, or if 'nx3' is odd.
    """
    required_keys = ['nx1', 'nx2', 'nx3', 'K', 'd', 'shape1', 'shape2', 'ranks', 'W']
    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")

    integer_keys = ('nx1', 'nx2', 'nx3', 'K', 'd')
    # Every type check runs before any sign check, so a wrongly-typed entry is always
    # reported as TypeError even when an earlier key is non-positive.
    for key in integer_keys:
        value = config[key]
        # bool is a subclass of int; without this guard K=True would pass as K=1.
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"{key} must be an integer.")
    for key in integer_keys:
        if config[key] <= 0:
            raise ValueError(f"{key} must be positive.")
    if config['nx3'] % 2 != 0:
        raise ValueError('Config error: nx3 must be divisible by 2.')


def build_stnn(config):
    """
    Constructs a Stacked Tensorial Neural Network (STNN) as a TensorFlow model based on
    the provided configuration dictionary.

    The model takes two inputs, in this order: a parameter vector of shape (3,) and a
    tensor of shape (2 * nx2, nx3 // 2, 1). K independent TTL layers are applied to the
    tensor input; the i-th TTL output is multiplied by the i-th slice (along the last axis)
    of the SoftmaxEmbeddingLayer output computed from the parameter vector, and the K
    weighted outputs are summed.

    Args:
        config (dict): Configuration dictionary for the STNN model. Must contain the following entries:
            - 'K' (int): Number of tensor networks to be stacked
            - 'd' (int): Number of dense layers in the model's SoftmaxEmbeddingLayer.
            - 'nx1', 'nx2', 'nx3' (int): Dimensions of the finite-difference grid
            - 'shape1', 'shape2', 'ranks', 'W': Entries forwarded to the TTL class
              (only their presence is checked here).
            - All other required entries for the TTL class, not already listed above.

    Returns:
        tf.keras.Model: The constructed STNN model with inputs [params_input, input_tensor].

    Raises:
        KeyError: If a required key is missing from the config dictionary.
        TypeError: If 'nx1', 'nx2', 'nx3', 'K' or 'd' is not an int (bools are rejected).
        ValueError: If 'nx1', 'nx2', 'nx3', 'K' or 'd' is not positive;
            also if config['nx3'] is not divisible by 2.
    """
    _validate_stnn_config(config)

    K = config['K']  # Number of tensor networks
    d = config['d']  # Number of dense layers in SoftmaxEmbeddingLayer

    input_shape = (2 * config['nx2'], config['nx3'] // 2, 1)
    input_tensor = Input(shape=input_shape)

    # Process parameter array (ell, a1, a2) and output weights for stacking the tensor networks
    preprocess_layer = SoftmaxEmbeddingLayer(K, d)
    params_input = Input(shape=(3,))
    # Two singleton axes are inserted before the last axis so that each per-network slice
    # can broadcast against a TTL output.
    # NOTE(review): assumes the embedding output is (batch, K) — confirm in stnn_layers.
    stack_weights = preprocess_layer(params_input)[:, tf.newaxis, tf.newaxis, :]

    # Build the tensor networks using the custom keras layer class TTL
    networks = [TTL(config) for _ in range(K)]

    # Combine the tensor networks based on the weights outputted by 'preprocess_layer'
    weighted_outputs = []
    for i, network in enumerate(networks):
        processed_output = network(input_tensor)
        weighted_output = Multiply()([processed_output, stack_weights[..., i]])
        weighted_outputs.append(weighted_output)
    final_output = Add()(weighted_outputs)

    return Model(inputs=[params_input, input_tensor], outputs=final_output)