# stnn/nn/stnn_layers.py
import os
import logging

# Suppress TensorFlow C++ logging; this must be set before TensorFlow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Reshape, Flatten
import t3f

logging.getLogger('tensorflow').setLevel(logging.FATAL)

class SoftmaxEmbeddingLayer(tf.keras.layers.Layer):
    """
    Parameter embedding layer that generates the weights used for stacking the tensor
    networks. It takes the parameter array, lambda = (ell, a1, a2), as input and outputs
    K numbers that sum to 1.

    Attributes:
        output_dim (int): The dimension K of the output.
        expansion_dim (int): The dimension used for expanding the input in intermediate
            layers.
        d (int): The number of dense expansion layers.
    """
def __init__(self, output_dim, d, expansion_dim = 30, **kwargs):
super(SoftmaxEmbeddingLayer, self).__init__(**kwargs)
self.reduction_layer = None
self.expansion_layers = None
self.output_dim = output_dim
self.expansion_dim = expansion_dim
self.d = d # Number of dense layers
def build(self, input_shape):
# Expansion layers to increase dimensionality
self.expansion_layers = [tf.keras.layers.Dense(self.expansion_dim, activation = 'relu') for _ in range(self.d)]
# Reduction layer to bring dimensionality back to the desired output dimension
self.reduction_layer = tf.keras.layers.Dense(self.output_dim)
def call(self, inputs):
expanded = inputs
for layer in self.expansion_layers:
expanded = layer(expanded)
return tf.nn.softmax(self.reduction_layer(expanded))
    def get_config(self):
        config = super().get_config()
        config.update({'output_dim': self.output_dim, 'd': self.d,
                       'expansion_dim': self.expansion_dim})
        return config
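
# Usage sketch (hypothetical values, kept as a comment so the module stays side-effect
# free): a batch of parameter vectors lambda = (ell, a1, a2) is mapped to K = 4 stacking
# weights that sum to 1.
#
#     emb = SoftmaxEmbeddingLayer(output_dim = 4, d = 2)
#     lam = tf.constant([[1.0, 0.5, -0.5]])  # shape (1, 3)
#     weights = emb(lam)                     # shape (1, 4); each row sums to 1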

class EinsumTTLRegularizer(tf.keras.regularizers.Regularizer):
    """
    Regularizer for the Einsum layer of the TTL layer class, penalizing high-frequency
    components of the weights vector.

    Attributes:
        strength (float): The regularization strength.
        midpoint (int): Index demarcating the inner and outer boundaries, i.e. x[:midpoint]
            contains data for the inner boundary, and x[midpoint:] contains data for the
            outer boundary. The regularization is designed so it does not penalize
            variations across this index.
    """
def __init__(self, strength, midpoint):
self.strength = strength
self.midpoint = midpoint
    def __call__(self, x):
        # Total-variation penalty: sum of absolute first differences within each half of
        # the weight vector. The jump across the midpoint (the inner/outer boundary) is
        # intentionally not penalized.
        diff = tf.abs(x[1:self.midpoint - 1] - x[0:self.midpoint - 2]) \
            + tf.abs(x[self.midpoint + 1:2 * self.midpoint - 1] - x[self.midpoint:2 * self.midpoint - 2])
        return self.strength * tf.reduce_sum(diff)
def get_config(self):
return {'strength': self.strength, 'midpoint': self.midpoint}
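
# Illustration (hypothetical numbers, kept as a comment): with midpoint = 4 on a
# length-8 vector, only first differences *within* each half are penalized; the jump
# between x[3] and x[4] is ignored.
#
#     reg = EinsumTTLRegularizer(strength = 0.005, midpoint = 4)
#     smooth = tf.constant([0., 1., 2., 3., 3., 2., 1., 0.])
#     jagged = tf.constant([0., 3., 0., 3., 3., 0., 3., 0.])
#     reg(smooth)  # 0.005 * 4  = 0.02
#     reg(jagged)  # 0.005 * 12 = 0.06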

def cosine_initializer(kx = 1.0):
    """
    Initializer for the Einsum layer of the TTL layer class. Sets the weights to a random
    linear combination of |cos(kx * x)| and |cos(2 * kx * x)| (negated), where x is a
    uniform grid on [-pi, pi] with one point per weight.

    Args:
        kx (float, optional): Frequency of the cosine terms. Defaults to 1.0.

    Returns:
        _initializer: Weight initializer function.
    """
def _initializer(shape, dtype = None):
x_values = np.linspace(-np.pi, np.pi, shape[0])
cos_values = np.random.uniform(-0.1, 0.3) * np.abs(np.cos(kx * x_values)) \
+ np.random.uniform(-0.05, 0.05) * np.abs(np.cos(2.0 * kx * x_values))
return tf.convert_to_tensor(-cos_values, dtype = dtype)
return _initializer
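
# Usage sketch (hypothetical shape, kept as a comment): each weight vector starts as a
# smooth, non-positive cosine-like profile over a uniform grid on [-pi, pi].
#
#     init = cosine_initializer(kx = 1.0)
#     w0 = init((16,), dtype = tf.float32)  # shape (16,) tensor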

class EinsumTTL(tf.keras.layers.Layer):
    """
    Layer that contracts the input tensor over its last dimension before passing it to
    the TTL. If regularization is enabled, it applies an `EinsumTTLRegularizer` to the
    kernels.

    Attributes:
        nx2, nx3 (int): Shape parameters characterizing the input tensor dimensions.
            The shape of the input tensor is (2 * nx2, nx3 // 2).
        W (int): Number of einsum contractions.
        kernels (list): List of weight vectors, one per einsum contraction.
        regularization_strength (float): The strength of the regularization, if used.
        use_regularization (bool): Flag indicating whether regularization is used.
    """
def __init__(self, nx2, nx3, W, use_regularization, regularization_strength = 0.005, **kwargs):
super(EinsumTTL, self).__init__(**kwargs)
self.nx2 = nx2
self.nx3 = nx3
self.W = W
self.kernels = []
self.regularization_strength = regularization_strength
self.use_regularization = use_regularization
if self.use_regularization:
regularizer = EinsumTTLRegularizer(self.regularization_strength, self.nx3 // 4)
else:
regularizer = None
        # Cycle through a fixed set of base frequencies, one per contraction
        base_frequencies = [1.0, 0.5, 2.0, 3.0]
        initializer_values = [base_frequencies[i % len(base_frequencies)] for i in range(W)]
for i in range(W):
self.kernels.append(self.add_weight(
name = f'w{i + 1}',
shape = (nx3 // 2,),
regularizer = regularizer,
initializer = cosine_initializer(initializer_values[i])
))
def call(self, inputs):
parts = []
for w in self.kernels:
part_a = tf.einsum('abc,c->ab', inputs[:, :self.nx2, :self.nx3 // 4], w[:self.nx3 // 4]) + \
tf.einsum('abc,c->ab', inputs[:, :self.nx2, self.nx3 // 4:self.nx3 // 2],
tf.reverse(w[:self.nx3 // 4], axis = [0]))
part_b = tf.einsum('abc,c->ab', inputs[:, self.nx2:, :self.nx3 // 4], w[self.nx3 // 4:self.nx3 // 2]) + \
tf.einsum('abc,c->ab', inputs[:, self.nx2:, self.nx3 // 4:self.nx3 // 2],
tf.reverse(w[self.nx3 // 4:self.nx3 // 2], axis = [0]))
parts.extend([part_a, part_b])
return tf.concat(parts, axis = 1)
    def get_config(self):
        config = super().get_config()
        config.update({'nx2': self.nx2, 'nx3': self.nx3, 'W': self.W,
                       'use_regularization': self.use_regularization,
                       'regularization_strength': self.regularization_strength})
        return config
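
# Shape sketch (hypothetical sizes, kept as a comment): each of the W kernels yields two
# (batch, nx2) contractions, concatenated into a (batch, 2 * nx2 * W) output.
#
#     nx2, nx3, W = 8, 16, 2
#     layer = EinsumTTL(nx2, nx3, W, use_regularization = False)
#     x = tf.random.normal((5, 2 * nx2, nx3 // 2))  # (batch, 2*nx2, nx3//2)
#     y = layer(x)                                  # (5, 2 * nx2 * W) = (5, 32)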

class TTL(tf.keras.layers.Layer):
    """
    TTL (Tensor Train Layer) is a custom TensorFlow Keras layer that builds a model
    based on the given configuration. This layer is designed to work with tensor train
    decomposition in neural networks.

    Attributes:
        config (dict): Configuration dictionary containing parameters for the model.
            'nx1', 'nx2', 'nx3': Integers, dimensions of the finite-difference grid.
            'shape1': List of integers, defines the shape of the output tensor in the
                tensor train format. The length of shape1 must match the length of shape2.
            'shape2': List of integers, specifies the shape of the input tensor in the
                tensor train format. The length of shape2 must match the length of shape1.
            'ranks': List of integers, represents the ranks in the tensor train
                decomposition. The length of this list determines the complexity and the
                number of parameters in the tensor train layer.
            'W' (int): Number of weight vectors to use in the initial EinsumTTL layer.
                Setting W = 0 means that no EinsumTTL is used.
            'use_regularization' (bool, optional, default: False): Whether regularization
                is used in the EinsumTTL.
            'regularization_strength' (float, optional, default: 0.0): Strength of the
                regularization.
        model (tf.keras.Sequential): The Sequential model built from the configuration.

    Methods:
        load_config(config): Validates and loads the configuration.
        build_model(): Builds the underlying Sequential model.
        call(inputs): Forward pass of the layer.
    """
def __init__(self, config, **kwargs):
super(TTL, self).__init__(**kwargs)
self.model = Sequential()
self.nx1 = None
self.nx2 = None
self.nx3 = None
self.shape1 = None
self.shape2 = None
self.ranks = None
self.W = None
self.use_regularization = None
self.regularization_strength = None
self._required_keys = ['nx1', 'nx2', 'nx3', 'shape1', 'shape2', 'ranks', 'W']
config.setdefault('use_regularization', False)
config.setdefault('regularization_strength', 0.0)
self.load_config(config)
self.config = config
self.build_model()
    def load_config(self, config):
        missing_keys = [key for key in self._required_keys if key not in config]
        if missing_keys:
            raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")
        if not isinstance(config['use_regularization'], bool):
            raise TypeError('use_regularization must be a boolean.')
        self.use_regularization = config['use_regularization']
        self.regularization_strength = config['regularization_strength']
        for key in ['nx1', 'nx2', 'nx3', 'W']:
            if not isinstance(config[key], int):
                raise TypeError(f"{key} must be an integer.")
for key in ['nx1', 'nx2', 'nx3']:
if config[key] <= 0:
raise ValueError(f"{key} must be positive.")
if config['W'] < 0:
raise ValueError("W must be non-negative.")
nx1, nx2, nx3 = config['nx1'], config['nx2'], config['nx3']
self.nx1 = nx1
self.nx2 = nx2
self.nx3 = nx3
W = config['W']
self.W = W
input_dim = 2 * nx2 * W
if W == 0:
input_dim = nx2 * nx3
shape1, shape2 = config['shape1'], config['shape2']
if len(shape1) != len(shape2):
raise ValueError(
f'shape1 and shape2 must have the same length. '
f'Received: shape1 = {shape1}, shape2 = {shape2}.'
)
elif np.prod(np.array(shape1)) != nx1 * nx2:
raise ValueError(
f'prod(shape1) must be equal to the output dimension of the TTL '
f'(nx1 * nx2,). Received: prod(shape1) = {np.prod(np.array(shape1))}, '
f'nx1 * nx2 = {nx1 * nx2}.'
)
elif np.prod(np.array(shape2)) != input_dim:
raise ValueError(
f'prod(shape2) must be equal to the input dimension of the TTL '
f'(2 * nx2 * W or nx2 * nx3 if W = 0). '
f'Received: prod(shape2) = {np.prod(np.array(shape2))}, required input dimension = {input_dim}.'
)
else:
self.shape1 = shape1
self.shape2 = shape2
self.ranks = config['ranks']
    def build_model(self):
        if self.W == 0:
            # No einsum contractions: flatten the full (2 * nx2, nx3 // 2) input directly
            self.model.add(Flatten(input_shape = (2 * self.nx2, self.nx3 // 2)))
        else:
            # Contract the last input dimension with W weight vectors, then flatten
            self.model.add(EinsumTTL(self.nx2, self.nx3, self.W, self.use_regularization,
                                     regularization_strength = self.regularization_strength,
                                     input_shape = (2 * self.nx2, self.nx3 // 2)))
            self.model.add(Flatten())
        # Tensor-train dense layer mapping shape2 -> shape1, followed by a reshape to the grid
        tt_layer = t3f.nn.KerasDense(input_dims = self.shape2, output_dims = self.shape1,
                                     tt_rank = self.ranks, use_bias = False, activation = 'linear')
        self.model.add(tt_layer)
        self.model.add(Reshape((self.nx1, self.nx2)))
def call(self, inputs):
return self.model(inputs)
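
if __name__ == '__main__':
    # Minimal smoke test with a hypothetical configuration; the sizes below are
    # illustrative only and must satisfy the consistency checks in TTL.load_config.
    demo_config = {
        'nx1': 8, 'nx2': 8, 'nx3': 16,
        'W': 2,
        'shape1': [4, 4, 4],  # prod(shape1) = 64 = nx1 * nx2
        'shape2': [4, 4, 2],  # prod(shape2) = 32 = 2 * nx2 * W
        'ranks': 4,           # t3f accepts a single TT-rank or a list of ranks
    }
    ttl = TTL(demo_config)
    x = tf.random.normal((5, 2 * demo_config['nx2'], demo_config['nx3'] // 2))
    y = ttl(x)
    print(y.shape)  # expected: (5, 8, 8), i.e. (batch, nx1, nx2)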