import os
import logging

# Suppress TensorFlow C++ logging; this must be set before TensorFlow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Reshape, Flatten
import t3f

logging.getLogger('tensorflow').setLevel(logging.FATAL)
class SoftmaxEmbeddingLayer(tf.keras.layers.Layer):
    """
    Parameter embedding layer that generates the weights used for stacking the tensor networks.
    It takes the parameter array, lambda = (ell, a1, a2), as input and outputs K numbers that
    sum to 1.

    Attributes:
        output_dim (int): The dimension of the output.
        d (int): Number of intermediate dense (expansion) layers.
        expansion_dim (int): The dimension used for expanding the input in intermediate layers.
    """

    def __init__(self, output_dim, d, expansion_dim = 30, **kwargs):
        super(SoftmaxEmbeddingLayer, self).__init__(**kwargs)
        self.reduction_layer = None
        self.expansion_layers = None
        self.output_dim = output_dim
        self.expansion_dim = expansion_dim
        self.d = d  # Number of dense layers

    def build(self, input_shape):
        # Expansion layers to increase dimensionality
        self.expansion_layers = [tf.keras.layers.Dense(self.expansion_dim, activation = 'relu')
                                 for _ in range(self.d)]
        # Reduction layer to bring dimensionality back to the desired output dimension
        self.reduction_layer = tf.keras.layers.Dense(self.output_dim)

    def call(self, inputs):
        expanded = inputs
        for layer in self.expansion_layers:
            expanded = layer(expanded)
        return tf.nn.softmax(self.reduction_layer(expanded))

    def get_config(self):
        return {'output_dim': self.output_dim, 'd': self.d, 'expansion_dim': self.expansion_dim}
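
# Usage sketch (illustrative, not part of the original module): the layer maps a
# batch of parameter triples to output_dim convex combination weights. The values
# below are made-up examples.
#
#     embed = SoftmaxEmbeddingLayer(output_dim = 4, d = 2)
#     lam = tf.constant([[0.5, 1.0, -0.3]])    # one parameter triple (ell, a1, a2)
#     w = embed(lam)                           # shape (1, 4)
#     tf.debugging.assert_near(tf.reduce_sum(w, axis = -1), 1.0)  # weights sum to 1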
class EinsumTTLRegularizer(tf.keras.regularizers.Regularizer):
    """
    Regularizer for the Einsum layer of the TTL layer class, penalizing high-frequency
    components of the weights vector.

    Attributes:
        strength (float): The regularization strength.
        midpoint (int): Index demarcating the inner and outer boundaries, i.e. x[:midpoint]
            contains data for the inner boundary, and x[midpoint:] contains data for the
            outer boundary. The regularization is designed so it does not penalize
            variations across this index.
    """

    def __init__(self, strength, midpoint):
        self.strength = strength
        self.midpoint = midpoint

    def __call__(self, x):
        # Sum of absolute first differences within each half of the weight vector.
        # The two halves are differenced separately, so the jump between
        # x[midpoint - 1] and x[midpoint] is never penalized.
        diff = tf.abs(x[1:self.midpoint - 1] - x[0:self.midpoint - 2]) \
            + tf.abs(x[self.midpoint + 1:2 * self.midpoint - 1] - x[self.midpoint:2 * self.midpoint - 2])
        return self.strength * tf.reduce_sum(diff)

    def get_config(self):
        return {'strength': self.strength, 'midpoint': self.midpoint}
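
# Illustrative check (made-up numbers): with midpoint = 3 the penalty only sees
# differences inside each half of x, never the jump across the midpoint.
#
#     reg = EinsumTTLRegularizer(strength = 1.0, midpoint = 3)
#     x = tf.constant([0.0, 1.0, 1.0, 5.0, 5.0, 5.0])
#     # inner half [0, 1, 1]: |1 - 0| = 1; outer half [5, 5, 5]: |5 - 5| = 0;
#     # the 1 -> 5 jump across the midpoint contributes nothing
#     reg(x)  # -> 1.0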
def cosine_initializer(kx = 1.0):
    """
    Initializer for the Einsum layer of the TTL layer class. Sets the weights to the negative
    of a randomly weighted combination of |cos(kx * x)| and |cos(2 * kx * x)|, where x is
    sampled uniformly on [-pi, pi].

    Args:
        kx (float, optional): Frequency of the cosine terms. Defaults to 1.0.

    Returns:
        _initializer: Weight initializer function.
    """
    def _initializer(shape, dtype = None):
        x_values = np.linspace(-np.pi, np.pi, shape[0])
        cos_values = np.random.uniform(-0.1, 0.3) * np.abs(np.cos(kx * x_values)) \
            + np.random.uniform(-0.05, 0.05) * np.abs(np.cos(2.0 * kx * x_values))
        return tf.convert_to_tensor(-cos_values, dtype = dtype)
    return _initializer
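
# Usage sketch (illustrative): the factory returns a Keras-style initializer that
# can be passed wherever an `initializer` argument is accepted.
#
#     init = cosine_initializer(kx = 2.0)
#     w0 = init((16,), dtype = tf.float32)  # 16 samples of the cosine profile on [-pi, pi]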
class EinsumTTL(tf.keras.layers.Layer):
    """
    Layer that contracts the input tensor over the second dimension before passing it to the TTL.
    If regularization is enabled, it applies an `EinsumTTLRegularizer` to the kernels.

    Attributes:
        nx2, nx3 (int): Shape parameters characterizing the input tensor dimensions.
            The shape of the input tensor is (2 * nx2, nx3 // 2).
        W (int): Number of einsum contractions.
        kernels (list): List of weight vectors, one per einsum contraction.
        regularization_strength (float): The strength of the regularization if used.
        use_regularization (bool): Flag to indicate whether regularization is used.
    """

    def __init__(self, nx2, nx3, W, use_regularization, regularization_strength = 0.005, **kwargs):
        super(EinsumTTL, self).__init__(**kwargs)
        self.nx2 = nx2
        self.nx3 = nx3
        self.W = W
        self.kernels = []
        self.regularization_strength = regularization_strength
        self.use_regularization = use_regularization
        if self.use_regularization:
            regularizer = EinsumTTLRegularizer(self.regularization_strength, self.nx3 // 4)
        else:
            regularizer = None
        # Cycle through a fixed set of base frequencies, one per weight vector.
        base_frequencies = [1.0, 0.5, 2.0, 3.0]
        initializer_values = [base_frequencies[i % len(base_frequencies)] for i in range(W)]
        for i in range(W):
            self.kernels.append(self.add_weight(
                name = f'w{i + 1}',
                shape = (nx3 // 2,),
                regularizer = regularizer,
                initializer = cosine_initializer(initializer_values[i])
            ))

    def call(self, inputs):
        parts = []
        for w in self.kernels:
            # Contract the first nx2 rows (inner boundary) over the last axis: the
            # first quarter of the kernel weights the first nx3 // 4 columns, and
            # its reversed copy weights the next nx3 // 4 columns.
            part_a = tf.einsum('abc,c->ab', inputs[:, :self.nx2, :self.nx3 // 4], w[:self.nx3 // 4]) + \
                tf.einsum('abc,c->ab', inputs[:, :self.nx2, self.nx3 // 4:self.nx3 // 2],
                          tf.reverse(w[:self.nx3 // 4], axis = [0]))
            # Same contraction for the last nx2 rows (outer boundary), using the
            # second half of the kernel.
            part_b = tf.einsum('abc,c->ab', inputs[:, self.nx2:, :self.nx3 // 4], w[self.nx3 // 4:self.nx3 // 2]) + \
                tf.einsum('abc,c->ab', inputs[:, self.nx2:, self.nx3 // 4:self.nx3 // 2],
                          tf.reverse(w[self.nx3 // 4:self.nx3 // 2], axis = [0]))
            parts.extend([part_a, part_b])
        return tf.concat(parts, axis = 1)

    def get_config(self):
        return {'nx2': self.nx2, 'nx3': self.nx3, 'W': self.W,
                'use_regularization': self.use_regularization,
                'regularization_strength': self.regularization_strength}
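
# Shape sketch (illustrative, made-up sizes): with nx2 = 4, nx3 = 8, and W = 2,
# an input of shape (batch, 2 * nx2, nx3 // 2) = (batch, 8, 4) is reduced to
# (batch, 2 * nx2 * W) = (batch, 16), i.e. one (part_a, part_b) pair of width
# nx2 per kernel.
#
#     layer = EinsumTTL(nx2 = 4, nx3 = 8, W = 2, use_regularization = False)
#     y = layer(tf.zeros((1, 8, 4)))  # y.shape == (1, 16)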
class TTL(tf.keras.layers.Layer):
    """
    TTL (Tensor Train Layer) is a custom TensorFlow Keras layer that builds a model
    based on the given configuration. This layer is designed to work with
    tensor train decomposition in neural networks.

    Attributes:
        config (dict): Configuration dictionary containing parameters for the model.
            'nx1', 'nx2', 'nx3': Integers, dimensions of the finite-difference grid.
            'shape1': List of integers, defines the shape of the output tensor in the tensor
                train format. The length of shape1 must match the length of shape2.
            'shape2': List of integers, specifies the shape of the input tensor in the tensor
                train format. The length of shape2 must match the length of shape1.
            'ranks': List of integers, represents the ranks in the tensor train decomposition.
                The length of this list determines the complexity and the number of parameters
                in the tensor train layer.
            'W': Integer, number of weight vectors to use in the initial EinsumTTL layer.
                Setting W = 0 means that no EinsumTTL is used.
            'use_regularization': Boolean, optional (default: False), indicates whether
                regularization is used in the EinsumTTL.
            'regularization_strength': Float, optional (default: 0.0), strength of the
                regularization.
        model (tf.keras.Sequential): The Sequential model built based on the provided configuration.

    Methods:
        load_config(config): Loads and validates the configuration.
        build_model(): Builds the layer.
        call(inputs): Method for the forward pass of the layer.
    """
    def __init__(self, config, **kwargs):
        super(TTL, self).__init__(**kwargs)
        self.model = Sequential()
        self.nx1 = None
        self.nx2 = None
        self.nx3 = None
        self.shape1 = None
        self.shape2 = None
        self.ranks = None
        self.W = None
        self.use_regularization = None
        self.regularization_strength = None
        self._required_keys = ['nx1', 'nx2', 'nx3', 'shape1', 'shape2', 'ranks', 'W']
        config.setdefault('use_regularization', False)
        config.setdefault('regularization_strength', 0.0)
        self.load_config(config)
        self.config = config
        self.build_model()
    def load_config(self, config):
        missing_keys = [key for key in self._required_keys if key not in config]
        if missing_keys:
            raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")
        if not isinstance(config['use_regularization'], bool):
            raise TypeError('use_regularization must be a boolean.')
        self.use_regularization = config['use_regularization']
        self.regularization_strength = config['regularization_strength']
        for key in ['nx1', 'nx2', 'nx3', 'W']:
            if not isinstance(config[key], int):
                raise TypeError(f"{key} must be an integer.")
        for key in ['nx1', 'nx2', 'nx3']:
            if config[key] <= 0:
                raise ValueError(f"{key} must be positive.")
        if config['W'] < 0:
            raise ValueError("W must be non-negative.")
        nx1, nx2, nx3 = config['nx1'], config['nx2'], config['nx3']
        self.nx1 = nx1
        self.nx2 = nx2
        self.nx3 = nx3
        W = config['W']
        self.W = W
        input_dim = 2 * nx2 * W
        if W == 0:
            input_dim = nx2 * nx3
        shape1, shape2 = config['shape1'], config['shape2']
        if len(shape1) != len(shape2):
            raise ValueError(
                f'shape1 and shape2 must have the same length. '
                f'Received: shape1 = {shape1}, shape2 = {shape2}.'
            )
        if np.prod(np.array(shape1)) != nx1 * nx2:
            raise ValueError(
                f'prod(shape1) must be equal to the output dimension of the TTL '
                f'(nx1 * nx2,). Received: prod(shape1) = {np.prod(np.array(shape1))}, '
                f'nx1 * nx2 = {nx1 * nx2}.'
            )
        if np.prod(np.array(shape2)) != input_dim:
            raise ValueError(
                f'prod(shape2) must be equal to the input dimension of the TTL '
                f'(2 * nx2 * W, or nx2 * nx3 if W = 0). '
                f'Received: prod(shape2) = {np.prod(np.array(shape2))}, required input dimension = {input_dim}.'
            )
        self.shape1 = shape1
        self.shape2 = shape2
        self.ranks = config['ranks']
    def build_model(self):
        if self.W == 0:
            # No einsum contraction: flatten the raw (2 * nx2, nx3 // 2) input.
            self.model.add(Flatten(input_shape = (2 * self.nx2, self.nx3 // 2)))
        else:
            # Contract the input with W learned weight vectors, then flatten.
            self.model.add(EinsumTTL(self.nx2, self.nx3, self.W, self.use_regularization,
                                     regularization_strength = self.regularization_strength,
                                     input_shape = (2 * self.nx2, self.nx3 // 2)))
            self.model.add(Flatten())
        # Dense layer in tensor train format, followed by reshaping to the grid.
        tt_layer = t3f.nn.KerasDense(input_dims = self.shape2, output_dims = self.shape1,
                                     tt_rank = self.ranks, use_bias = False, activation = 'linear')
        self.model.add(tt_layer)
        self.model.add(Reshape((self.nx1, self.nx2)))

    def call(self, inputs):
        return self.model(inputs)
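
# Minimal usage sketch (illustrative; the config values below are made-up but
# satisfy the consistency checks in load_config).
if __name__ == '__main__':
    demo_config = {
        'nx1': 16, 'nx2': 8, 'nx3': 16,  # finite-difference grid dimensions
        'W': 2,                          # input dimension: 2 * nx2 * W = 32
        'shape1': [8, 16],               # prod(shape1) = 128 = nx1 * nx2
        'shape2': [4, 8],                # prod(shape2) = 32 = 2 * nx2 * W
        'ranks': [1, 4, 1],              # TT-ranks, one per core boundary
    }
    ttl = TTL(demo_config)
    out = ttl(tf.zeros((1, 2 * demo_config['nx2'], demo_config['nx3'] // 2)))
    print(out.shape)  # expected: (1, 16, 8)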