# Source code for nnodely.layers.equationlearner

import inspect

from nnodely.basic.relation import NeuObj, Stream
from nnodely.support.utils import check, enforce_types

from nnodely.layers.linear import Linear
from nnodely.layers.part import Select, Concatenate
from nnodely.layers.fuzzify import Fuzzify
from nnodely.layers.parametricfunction import ParamFun
from nnodely.layers.activation import Relu, ELU, Identity, Sigmoid
from nnodely.layers.trigonometric import Sin, Cos, Tan, Tanh, Cosh, Sech
from nnodely.layers.arithmetic import Add, Mul, Sub, Neg, Pow, Sum

# Base relation name; each EquationLearner instance is named by appending NeuObj.count to it.
equationlearner_relation_name = 'EquationLearner'

# Layer classes that may be listed, uninstantiated, in EquationLearner's ``functions``.
# The number of inputs each one consumes is read from its ``__init__`` signature.
Available_functions = [
    # from nnodely.layers.trigonometric
    Sin, Cos, Tan, Cosh, Tanh, Sech,
    # from nnodely.layers.arithmetic
    Add, Mul, Sub, Neg, Pow, Sum,
    # from nnodely.layers.part
    Concatenate,
    # from nnodely.layers.activation
    Relu, ELU, Identity, Sigmoid,
]

# Layer types that must be passed as already-constructed instances in ``functions``.
Initialized_functions = [
    ParamFun,
    Fuzzify,
]

class EquationLearner(NeuObj):
    """
    Represents a nnodely implementation of the Task-Parametrized Equation Learner block.

    See also:
        Task-Parametrized Equation Learner official paper:
        `Equation Learner <https://www.sciencedirect.com/science/article/pii/S0921889022001981>`_

    Parameters
    ----------
    functions : list
        A list of callable functions to be used as activation functions. Each entry is
        either one of the layer classes in ``Available_functions`` (e.g. ``Sin``, ``Add``)
        or an instance of one of the types in ``Initialized_functions``
        (``ParamFun``, ``Fuzzify``).
    linear_in : Linear, optional
        A Linear layer applied to the (concatenated) input before the activation
        functions. Component ``i`` of its output feeds the ``i``-th activation argument,
        so it is expected to provide ``n_activations`` components. If not provided, a
        new ``Linear`` layer with bias and ``output_dimension=n_activations`` is created
        on every call.
    linear_out : Linear, optional
        A Linear layer applied to the concatenated activation outputs. Can be omitted.

    Attributes
    ----------
    relation_name : str
        The name of the relation.
    linear_in : Linear or None
        The Linear layer to process the input.
    linear_out : Linear or None
        The Linear layer to process the output.
    functions : list
        A copy of the list of activation functions.
    func_parameters : dict
        A dictionary mapping function indices to the number of input arguments they require.
    n_activations : int
        The total number of activation arguments, i.e. the sum of ``func_parameters``.

    Examples
    --------
    .. include:: /examples_basics/layer_module_ex/eql.rst
    """

    @enforce_types
    def __init__(self, functions:list, *, linear_in:Linear|None = None, linear_out:Linear|None = None) -> None:
        """
        Validate ``functions`` and record how many inputs each of them consumes.

        Raises
        ------
        TypeError
            If an entry of ``functions`` is not callable.
        ValueError
            If an entry is neither an accepted instance nor in ``Available_functions``,
            or if the functions require no inputs at all.
        """
        self.relation_name = equationlearner_relation_name
        self.linear_in = linear_in
        self.linear_out = linear_out
        # Keep a private copy: if the caller later mutates their list, ``functions``
        # would otherwise no longer match the argument counts in ``func_parameters``.
        self.functions = list(functions)
        super().__init__(equationlearner_relation_name + str(NeuObj.count))

        # Position in ``functions`` -> number of input arguments that function consumes.
        self.func_parameters = {}
        for func_idx, func in enumerate(self.functions):
            check(callable(func), TypeError, 'The activation functions must be callable')
            self.func_parameters[func_idx] = self._count_arguments(func)
        self.n_activations = sum(self.func_parameters.values())
        check(self.n_activations > 0, ValueError, 'At least one activation function must be provided')

    @staticmethod
    def _count_arguments(func) -> int:
        """Return how many components of the input linear layer ``func`` consumes."""
        func_type = type(func)
        if func_type in Initialized_functions:
            if func_type == ParamFun:
                # The arguments of ``param_fun`` that correspond to the ParamFun's
                # parameters/constants are not fed from the input layer.
                arg_names = inspect.getfullargspec(func.param_fun).args
                n_bound = len(func.parameters_and_constants) if func.parameters_and_constants else 0
                return len(arg_names) - n_bound
            # Other pre-built instances (currently Fuzzify): count the arguments of
            # ``__call__``. It is a bound method, so ``self`` is normally already absent;
            # it is filtered anyway.
            call_params = inspect.signature(func.__call__).parameters.values()
            return len([param for param in call_params if param.name != "self"])

        # Uninstantiated layer classes: their constructor arguments are the inputs.
        check(func in Available_functions, ValueError, f'The function {func} is not available for the EquationLearner operation')
        init_params = inspect.signature(func.__init__).parameters.values()
        return len([param for param in init_params if param.name != "self"])

    def __call__(self, inputs):
        """
        Build the equation-learner graph on ``inputs``.

        Parameters
        ----------
        inputs : object or tuple/list of objects
            A single input, or a tuple/list of inputs. Every input must expose a ``dim``
            mapping with an ``'sw'`` or ``'tw'`` entry, and all these entries must be equal.

        Returns
        -------
        The concatenation of all activation outputs, in the order of ``functions``,
        passed through ``linear_out`` when one was given.

        Raises
        ------
        ValueError
            If no input is given, or the inputs have different time dimensions.
        """
        # A single input is accepted as well as a tuple (or list) of inputs.
        if not isinstance(inputs, (tuple, list)):
            inputs = (inputs,)
        check(len(inputs) > 0, ValueError, 'At least one input must be provided')
        time_dims = {x.dim['sw'] if 'sw' in x.dim else x.dim['tw'] for x in inputs}
        check(len(time_dims) == 1, ValueError, 'All inputs must have the same time dimension')

        concatenated_input = inputs[0]
        for inp in inputs[1:]:
            concatenated_input = Concatenate(concatenated_input, inp)

        if self.linear_in is not None:
            linear_layer = self.linear_in(concatenated_input)
        else:
            linear_layer = Linear(output_dimension=self.n_activations, b=True)(concatenated_input)

        # Hand each function its own consecutive slice of linear-layer components.
        out = None
        start = 0
        for func_idx, func in enumerate(self.functions):
            n_args = self.func_parameters[func_idx]
            arguments = [Select(linear_layer, start + arg_idx) for arg_idx in range(n_args)]
            start += n_args
            result = func(*arguments)
            out = result if func_idx == 0 else Concatenate(out, result)

        if self.linear_out is not None:
            out = self.linear_out(out)
        return out