import tensorflow as tf
from tensorflow.keras.layers import Input, Multiply, Add
from tensorflow.keras.models import Model

from .stnn_layers import TTL, SoftmaxEmbeddingLayer


def _validate_config(config):
    """Check that ``config`` holds every entry ``build_stnn`` reads directly.

    The checks run in a fixed order (presence, then type, then sign, then
    parity of ``nx3``) so that the first failing category determines the
    exception that is raised.

    Raises:
        KeyError: If any required key is absent.
        TypeError: If 'nx1', 'nx2', 'nx3', 'K' or 'd' is not an ``int``
            (``bool`` is rejected as well).
        ValueError: If any of those integers is not positive, or if 'nx3'
            is odd.
    """
    required_keys = ('nx1', 'nx2', 'nx3', 'K', 'd',
                     'shape1', 'shape2', 'ranks', 'W')
    int_keys = ('nx1', 'nx2', 'nx3', 'K', 'd')

    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")

    for key in int_keys:
        value = config[key]
        # bool is a subclass of int, so it has to be excluded explicitly;
        # otherwise e.g. K=True would silently build a one-branch model.
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"{key} must be an integer.")

    for key in int_keys:
        if config[key] <= 0:
            raise ValueError(f"{key} must be positive.")

    # The model input uses nx3 // 2 as one of its dimensions.
    if config['nx3'] % 2 != 0:
        raise ValueError('Config error: nx3 must be divisible by 2.')


def build_stnn(config):
    """
    Constructs a Stacked Tensorial Neural Network (STNN) as a TensorFlow model
    based on the provided configuration dictionary.

    The model takes two inputs, in this order: a parameter array of shape
    (3,) and a tensor of shape (2 * nx2, nx3 // 2, 1) (batch axis excluded).
    The parameter array is fed to a SoftmaxEmbeddingLayer whose output is used
    to weight the outputs of K independent TTL layers, which are then summed.

    Args:
        config (dict): Configuration dictionary for the STNN model. Must
            contain the following entries:
            - 'K' (int): Number of tensor networks to be stacked
            - 'd' (int): Number of dense layers in the model's
              SoftmaxEmbeddingLayer.
            - 'nx1', 'nx2', 'nx3' (int): Dimensions of the finite-difference
              grid
            - 'shape1', 'shape2', 'ranks', 'W': checked for presence only;
              the whole dictionary is passed on to the TTL class, which may
              require further entries.

    Returns:
        tf.keras.Model: The constructed STNN model.

    Raises:
        KeyError: If any of the required keys listed above is missing.
        TypeError: If 'nx1', 'nx2', 'nx3', 'K' or 'd' is not an integer.
        ValueError: If 'nx1', 'nx2', 'nx3', 'K' or 'd' is not positive; also
            if config['nx3'] is not divisible by 2.
    """
    _validate_config(config)

    K = config['K']  # Number of tensor networks
    d = config['d']  # Number of dense layers in SoftmaxEmbeddingLayer

    input_shape = (2 * config['nx2'], config['nx3'] // 2, 1)
    input_tensor = Input(shape=input_shape)

    # Process parameter array (ell, a1, a2) and output weights for stacking
    # the tensor networks. Two singleton axes are inserted before the last
    # axis so each weight slice can broadcast against a network output
    # (presumably the layer emits one weight per network -- TODO confirm).
    preprocess_layer = SoftmaxEmbeddingLayer(K, d)
    params_input = Input(shape=(3,))
    stack_weights = preprocess_layer(params_input)[:, tf.newaxis, tf.newaxis, :]

    # Build the tensor networks using the custom keras layer class TTL
    tensor_networks = [TTL(config) for _ in range(K)]

    # Combine the tensor networks based on the weights outputted by
    # 'preprocess_layer'
    weighted_outputs = []
    for i, tensor_network in enumerate(tensor_networks):
        processed_output = tensor_network(input_tensor)
        weighted_output = Multiply()([processed_output, stack_weights[..., i]])
        weighted_outputs.append(weighted_output)

    if K == 1:
        # Some Keras versions refuse to call a merge layer on a
        # single-element list; with one network the sum is just that
        # network's weighted output.
        final_output = weighted_outputs[0]
    else:
        final_output = Add()(weighted_outputs)

    return Model(inputs=[params_input, input_tensor], outputs=final_output)