# Source code for nnodely.parametricfunction

import inspect, copy, textwrap, torch, math

import torch.nn as nn
import numpy as np

from collections.abc import Callable

from nnodely.relation import NeuObj, Stream, toStream
from nnodely.model import Model
from nnodely.parameter import Parameter, Constant
from nnodely.utils import check, merge, enforce_types


# Relation identifier for parametric functions: it is used as the prefix of the
# generated function/stream names, stored as the relation type in the
# 'Relations' section of the JSON, and used as the attribute name under which
# the layer factory is attached to Model at the bottom of this module.
paramfun_relation_name = 'ParamFun'

class ParamFun(NeuObj):
    """
    Represents a parametric function in the neural network model.

    Parameters
    ----------
    param_fun : Callable
        The parametric function to be used.
    constants : list or dict or None, optional
        A list or dictionary of constants to be used in the function. Default is None.
    parameters_dimensions : list or dict or None, optional
        A list or dictionary specifying the dimensions of the parameters. Default is None.
    parameters : list or dict or None, optional
        A list or dictionary of parameters to be used in the function. Default is None.
    map_over_batch : bool, optional
        A boolean indicating whether to map the function over the batch dimension. Default is False.

    Attributes
    ----------
    relation_name : str
        The name of the relation.
    param_fun : Callable
        The parametric function to be used.
    constants : list or dict or None
        A list or dictionary of constants to be used in the function.
    parameters_dimensions : list or dict or None
        A list or dictionary specifying the dimensions of the parameters.
    parameters : list or dict or None
        A list or dictionary of parameters to be used in the function.
    map_over_batch : bool
        A boolean indicating whether to map the function over the batch dimension.
    output_dimension : dict
        A dictionary containing the output dimensions of the function.
    json : dict
        A dictionary containing the configuration of the function.

    Example
    -------
        >>> input1 = Input('input1')
        >>> input2 = Input('input2')
        >>> def my_function(x, y, param1, const1):
        >>>     return param1 * x + const1 * y
        >>> param_fun = ParamFun(my_function, constants={'const1': 1.0}, parameters_dimensions={'param1': 1})
        >>> result = param_fun(input1, input2)
    """

    # NOTE(review): the `-> Stream` return annotation is kept unchanged for
    # interface compatibility with @enforce_types; __init__ itself returns None.
    @enforce_types
    def __init__(self, param_fun:Callable,
                 constants:list|dict|None = None,
                 parameters_dimensions:list|dict|None = None,
                 parameters:list|dict|None = None,
                 map_over_batch:bool = False) -> Stream:
        self.relation_name = paramfun_relation_name

        # input parameters
        self.param_fun = param_fun
        self.constants = constants
        self.parameters_dimensions = parameters_dimensions
        self.parameters = parameters
        self.map_over_batch = map_over_batch

        self.output_dimension = {}
        super().__init__('F' + paramfun_relation_name + str(NeuObj.count))

        # The function source is stored in the JSON; double quotes are replaced
        # with single quotes before storing it.
        code = textwrap.dedent(inspect.getsource(param_fun)).replace('\"', '\'')
        self.json['Functions'][self.name] = {
            'code': code,
            'name': param_fun.__name__,
        }
        self.json['Functions'][self.name]['params_and_consts'] = []

    def __call__(self, *obj):
        """
        Apply the parametric function to the given inputs.

        On the first call the number of inputs is recorded, the remaining
        function arguments are bound to parameters/constants and the output
        dimension is inferred. Every call must leave the same number of
        function arguments to be covered by parameters/constants.

        Parameters
        ----------
        *obj
            The inputs of the function. ``int``, ``float`` and ``list`` values
            are treated as constants; everything is converted with ``toStream``.

        Returns
        -------
        Stream
            A stream whose JSON contains this function, the merged JSON of the
            inputs and the new ``ParamFun`` relation.

        Raises
        ------
        TypeError
            If an input cannot be converted to a ``Stream``.
        ValueError
            If the function is called with a different number of inputs.
        """
        stream_name = paramfun_relation_name + str(Stream.count)

        funinfo = inspect.getfullargspec(self.param_fun)
        n_function_input = len(funinfo.args)
        n_call_input = len(obj)
        n_new_constants_and_params = n_function_input - n_call_input
        fun_json = self.json['Functions'][self.name]

        # First call only: bind parameters/constants and infer the dimensions.
        if 'n_input' not in fun_json:
            fun_json['n_input'] = n_call_input
            self.__set_params_and_consts(n_new_constants_and_params)

            input_dimensions = []
            input_types = []
            for o in obj:
                # The type is recorded BEFORE the conversion to Stream.
                obj_type = Constant if type(o) in (int, float, list) else type(o)
                o = toStream(o)
                check(type(o) is Stream, TypeError,
                      f"The type of {o} is {type(o)} and is not supported for ParamFun operation.")
                input_types.append(obj_type)
                input_dimensions.append(o.dim)

            fun_json['in_dim'] = copy.deepcopy(input_dimensions)
            self.__infer_output_dimensions(input_types, input_dimensions)
            fun_json['out_dim'] = copy.deepcopy(self.output_dimension)

        # Every function argument that is not an input must already be bound.
        missing_params = n_new_constants_and_params - len(fun_json['params_and_consts'])
        check(missing_params == 0, ValueError, "The function is called with different number of inputs.")

        stream_json = copy.deepcopy(self.json)
        input_names = []
        for o in obj:
            o = toStream(o)
            check(type(o) is Stream, TypeError,
                  f"The type of {o} is {type(o)} and is not supported for ParamFun operation.")
            stream_json = merge(stream_json, o.json)
            input_names.append(o.name)

        output_dimension = copy.deepcopy(self.output_dimension)
        stream_json['Relations'][stream_name] = [paramfun_relation_name, input_names, self.name]
        return Stream(stream_name, stream_json, output_dimension)

    def __register_constant(self, const, where):
        """Bind a ``Constant`` object or a constant name (``str``); ``where`` ('list'/'dict') is used in the error text."""
        bound = self.json['Functions'][self.name]['params_and_consts']
        if type(const) is Constant:
            bound.append(const.name)
            self.json['Constants'][const.name] = copy.deepcopy(const.json['Constants'][const.name])
        elif type(const) is str:
            bound.append(const)
            self.json['Constants'][const] = {'dim': 1}
        else:
            check(False, TypeError,
                  f'The element inside the "constants" {where} must be a Constant or str')

    def __register_parameter(self, param, where):
        """Bind a ``Parameter`` object or a parameter name (``str``); ``where`` ('list'/'dict') is used in the error text."""
        bound = self.json['Functions'][self.name]['params_and_consts']
        if type(param) is Parameter:
            bound.append(param.name)
            self.json['Parameters'][param.name] = copy.deepcopy(param.json['Parameters'][param.name])
        elif type(param) is str:
            bound.append(param)
            self.json['Parameters'][param] = {'dim': 1}
        else:
            check(False, TypeError,
                  f'The element inside the "parameters" {where} must be a Parameter or str')

    def __set_params_and_consts(self, n_new_constants_and_params):
        """
        Bind the function arguments that are not call inputs to parameters and constants.

        List-form ``constants`` / ``parameters`` / ``parameters_dimensions`` are
        processed first, in that order. The remaining trailing arguments are
        then resolved by name from the dict-form options; an argument found in
        none of them becomes a new parameter of dimension 1.

        Parameters
        ----------
        n_new_constants_and_params : int
            Number of function arguments not covered by the call inputs.

        Raises
        ------
        TypeError
            If an element has an unsupported type, or a parameter is given both
            in ``parameters`` and in ``parameters_dimensions``.
        ValueError
            If ``parameters`` is a list while ``parameters_dimensions`` is set,
            or if some dict entries do not match any function argument.
        """
        funinfo = inspect.getfullargspec(self.param_fun)
        fun_json = self.json['Functions'][self.name]

        # Create the missing constants from list
        if type(self.constants) is list:
            for const in self.constants:
                self.__register_constant(const, 'list')

        # Create the missing parameters from list
        if type(self.parameters) is list:
            check(self.parameters_dimensions is None, ValueError,
                  '\"parameters_dimensions\" must be None if \"parameters\" is set using list')
            for param in self.parameters:
                self.__register_parameter(param, 'list')
        elif type(self.parameters_dimensions) is list:
            n_dims = len(self.parameters_dimensions)
            for i, param_dim in enumerate(self.parameters_dimensions):
                # The parameter name is built from the index of the trailing
                # function argument it corresponds to.
                idx = i + len(funinfo.args) - n_dims
                param_name = self.name + str(idx)
                fun_json['params_and_consts'].append(param_name)
                # An int is stored as-is (as the dict form does); any other
                # entry is converted with list(). Previously an int entry
                # failed with "'int' object is not iterable".
                self.json['Parameters'][param_name] = {
                    'dim': param_dim if isinstance(param_dim, int) else list(param_dim)
                }

        # Create the missing parameters and constants from dict
        missing_params = n_new_constants_and_params - len(fun_json['params_and_consts'])
        uses_dict = any(type(opt) is dict
                        for opt in (self.constants, self.parameters, self.parameters_dimensions))
        if missing_params or uses_dict:
            n_input = len(funinfo.args) - missing_params
            # Number of dict entries that still have to be matched to an argument.
            n_elem_dict = sum(len(opt)
                              for opt in (self.constants, self.parameters, self.parameters_dimensions)
                              if type(opt) is dict)
            for i, key in enumerate(funinfo.args):
                if i < n_input:
                    continue
                if type(self.parameters) is dict and key in self.parameters:
                    if self.parameters_dimensions:
                        # A parameter object/name already carries its own
                        # dimension, so it must NOT also be listed in
                        # parameters_dimensions (the condition was inverted).
                        check(key not in self.parameters_dimensions, TypeError,
                              f'The parameter {key} must be removed from \"parameters_dimensions\".')
                    self.__register_parameter(self.parameters[key], 'dict')
                    n_elem_dict -= 1
                elif type(self.parameters_dimensions) is dict and key in self.parameters_dimensions:
                    param_name = self.name + key
                    dim = self.parameters_dimensions[key]
                    check(isinstance(dim, (list, tuple, int)), TypeError,
                          'The element inside the \"parameters_dimensions\" dict must be a tuple or int')
                    fun_json['params_and_consts'].append(param_name)
                    self.json['Parameters'][param_name] = {'dim': list(dim) if type(dim) is tuple else dim}
                    n_elem_dict -= 1
                elif type(self.constants) is dict and key in self.constants:
                    self.__register_constant(self.constants[key], 'dict')
                    n_elem_dict -= 1
                else:
                    # Argument not described anywhere: new parameter of dimension 1.
                    param_name = self.name + key
                    fun_json['params_and_consts'].append(param_name)
                    self.json['Parameters'][param_name] = {'dim': 1}
            check(n_elem_dict == 0, ValueError, 'Some of the input parameters are not used in the function.')

    def __infer_output_dimensions(self, input_types, input_dimensions):
        """
        Infer ``self.output_dimension`` by calling the function on random tensors.

        Random tensors are built for every input (with a leading batch axis of
        size 5) and for every bound parameter/constant (without batch axis),
        the function is evaluated once and the shape of the result is
        inspected. Also stores ``'map_over_dim'`` in the function JSON.

        Parameters
        ----------
        input_types : list
            Type recorded for each call input.
        input_dimensions : list of dict
            Dimension dictionary (``'dim'`` plus optional ``'tw'``/``'sw'``) of each call input.

        Raises
        ------
        ValueError
            If the first output axis is not the batch size or the output has
            more than one axis after the window axis.
        """
        batch_dim = 5
        fun_json = self.json['Functions'][self.name]

        # Work on copies so the caller's lists are not modified.
        all_inputs_dim = list(input_dimensions)
        all_inputs_type = list(input_types)
        params_and_consts = self.json['Constants'] | self.json['Parameters']
        for name in fun_json['params_and_consts']:
            all_inputs_dim.append(params_and_consts[name])
            all_inputs_type.append(Constant)

        # Find a sampling rate (1, 10, 100, ...) for which every time window
        # corresponds to an integer number of samples.
        time_windows = [d['tw'] for d in all_inputs_dim if 'tw' in d]
        n_samples_sec = 0.1
        is_int = False
        while not is_int:
            n_samples_sec *= 10
            is_int = all(math.isclose(tw * n_samples_sec, round(tw * n_samples_sec))
                         for tw in time_windows)

        # Build input with right dimensions
        inputs = []
        inputs_win_type = []
        inputs_win = []
        input_map_dim = ()
        for t, dim in zip(all_inputs_type, all_inputs_dim):
            window = 'tw' if 'tw' in dim else ('sw' if 'sw' in dim else None)
            if window == 'tw':
                dim_win = round(dim[window] * n_samples_sec)
            elif window == 'sw':
                dim_win = dim[window]
            else:
                dim_win = 1

            if t in (Parameter, Constant):
                # Parameters and constants have no batch axis and are not mapped.
                if self.map_over_batch:
                    input_map_dim += (None,)
                if type(dim['dim']) is list:
                    inputs.append(torch.rand(size=(dim_win,) + tuple(dim['dim'])))
                else:
                    inputs.append(torch.rand(size=(dim_win, dim['dim'])))
            else:
                inputs.append(torch.rand(size=(batch_dim, dim_win, dim['dim'])))
                if self.map_over_batch:
                    input_map_dim += (0,)

            inputs_win_type.append(window)
            inputs_win.append(dim_win)

        if self.map_over_batch:
            fun_json['map_over_dim'] = list(input_map_dim)
            function_to_call = torch.func.vmap(self.param_fun, in_dims=input_map_dim)
        else:
            fun_json['map_over_dim'] = False
            function_to_call = self.param_fun

        out_shape = function_to_call(*inputs).shape
        check(out_shape[0] == batch_dim, ValueError, "The batch output dimension it is not correct.")
        out_dim = list(out_shape[2:])
        check(len(out_dim) == 1, ValueError, "The output dimension of the function is bigger than a vector.")

        # If the output window matches the window of a (non parameter/constant)
        # input, the output inherits that input's window type and value; when
        # several inputs match, the last one wins. Otherwise the output window
        # is expressed in samples.
        out_win_type = 'sw'
        out_win = out_shape[1]
        for idx, win in enumerate(inputs_win):
            if out_shape[1] == win and all_inputs_type[idx] not in (Parameter, Constant):
                out_win_type = inputs_win_type[idx]
                out_win = all_inputs_dim[idx][out_win_type]
        self.output_dimension = {'dim': out_dim[0], out_win_type: out_win}
def return_standard_inputs(json, model_def, xlim = None, num_points = 1000):
    """
    Build the tuple of tensors needed to evaluate a one- or two-input function.

    Parameters
    ----------
    json : dict
        Function description; the keys ``'n_input'``, ``'in_dim'`` and
        ``'params_and_consts'`` are read.
    model_def : dict
        Model definition; ``['Info']['SampleTime']``, ``['Parameters']`` and
        ``['Constants']`` are read.
    xlim : sequence or None, optional
        Limits of the inputs: ``(min, max)`` for one input, one ``(min, max)``
        pair per input for two inputs. Default is None, i.e. the range [0, 1].
    num_points : int, optional
        Number of points per input. Default is 1000.

    Returns
    -------
    tuple
        The input tensors (a flattened grid when there are two inputs) followed
        by one tensor per parameter/constant built from its ``'values'``.

    Raises
    ------
    ValueError
        If the function does not have one or two inputs, an input has dimension
        or window different from 1, or ``xlim`` has the wrong shape.
    """
    n_input = json['n_input']
    check(n_input == 1 or n_input == 2, ValueError, "The function must have only one or two inputs.")

    axes = []
    for i in range(n_input):
        dim = json['in_dim'][i]
        check(dim['dim'] == 1, ValueError, "The input dimension must be 1.")
        if 'tw' in dim:
            # A time window equal to the sample time is a window of one sample.
            check(dim['tw'] == model_def['Info']['SampleTime'], ValueError, "The input window must be 1.")
        elif 'sw' in dim:
            check(dim['sw'] == 1, ValueError, "The input window must be 1.")

        if xlim is None:
            x_value = np.linspace(0, 1, num=num_points)
        elif n_input == 2:
            check(np.array(xlim).shape == (n_input, 2), ValueError,
                  "The xlim must have the same shape as the number of inputs.")
            x_value = np.linspace(xlim[i][0], xlim[i][1], num=num_points)
        else:
            check(np.array(xlim).shape == (2,), ValueError,
                  "The xlim must have the same shape as the number of inputs.")
            x_value = np.linspace(xlim[0], xlim[1], num=num_points)
        axes.append(torch.from_numpy(x_value))

    # Each input gets two trailing singleton axes: (points, 1, 1).
    if n_input == 2:
        x0_value, x1_value = torch.meshgrid(axes[0], axes[1], indexing="xy")
        x0_value = x0_value.flatten().unsqueeze(1).unsqueeze(1)
        x1_value = x1_value.flatten().unsqueeze(1).unsqueeze(1)
        fun_inputs = (x0_value, x1_value)
    else:
        fun_inputs = (axes[0].unsqueeze(1).unsqueeze(1),)

    for key in json['params_and_consts']:
        # Parameters take precedence over constants with the same name.
        val = model_def['Parameters'][key] if key in model_def['Parameters'] else model_def['Constants'][key]
        fun_inputs += (torch.from_numpy(np.array(val['values'])),)
    return fun_inputs


def return_function(json, fun_inputs):
    """
    Define the stored function and evaluate it on ``fun_inputs``.

    Parameters
    ----------
    json : dict
        Function description; ``'code'`` (source text) and ``'name'`` are read.
    fun_inputs : tuple
        Positional arguments passed to the function.

    Returns
    -------
    tuple
        ``(output, args)`` where ``args`` is the list of argument names of the function.

    Raises
    ------
    ValueError
        If axis 1 or axis 2 of the output does not have size 1.
    """
    # NOTE(review): exec() runs the source text stored in the model definition
    # inside this module's globals; only load model definitions you trust.
    exec(json['code'], globals())
    function_to_call = globals()[json['name']]
    output = function_to_call(*fun_inputs)
    # Tensors in this module are laid out as (batch, window, dim): axis 1 is
    # the window and axis 2 the feature dimension (the messages were swapped).
    check(output.shape[1] == 1, ValueError, "The output window must be 1.")
    check(output.shape[2] == 1, ValueError, "The output dimension must be 1.")
    funinfo = inspect.getfullargspec(function_to_call)
    return output, funinfo.args


class Parametric_Layer(nn.Module):
    """
    Torch module that evaluates a stored parametric function.

    Parameters
    ----------
    func : dict
        Function description; ``'name'`` and ``'code'`` are read.
    params_and_consts : list
        Values appended to the inputs at every call.
    map_over_batch : list or bool
        A list of ``in_dims`` enables ``torch.func.vmap`` with those
        dimensions; any other value disables the mapping.
    """

    def __init__(self, func, params_and_consts, map_over_batch):
        super().__init__()
        self.name = func['name']
        self.params_and_consts = params_and_consts
        if type(map_over_batch) is list:
            self.map_over_batch = True
            self.input_map_dim = tuple(map_over_batch)
        else:
            self.map_over_batch = False

        ## Add the function to the globals
        # The function is defined in this module's globals, decorated with
        # torch.fx.wrap. A failure is only reported (best effort); forward()
        # will then look the name up in globals and may fail there.
        # NOTE(review): exec() runs source text coming from the model
        # definition; only load model definitions you trust.
        try:
            code = 'import torch\n@torch.fx.wrap\n' + func['code']
            exec(code, globals())
        except Exception as e:
            print(f"An error occurred: {e}")

    def forward(self, *inputs):
        """Call the stored function on ``inputs`` followed by the parameters and constants."""
        args = list(inputs) + self.params_and_consts
        # Retrieve the function object from the globals dictionary
        function_to_call = globals()[self.name]
        # Call the function using the retrieved function object
        if self.map_over_batch:
            function_to_call = torch.func.vmap(function_to_call, in_dims=self.input_map_dim)
        return function_to_call(*args)


def createParamFun(self, *func_params):
    """Layer factory: ``func_params`` are (function description, params and constants, map_over_batch)."""
    return Parametric_Layer(func=func_params[0], params_and_consts=func_params[1], map_over_batch=func_params[2])


# Attach the factory to Model under the relation name.
setattr(Model, paramfun_relation_name, createParamFun)