{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Parametric Functions\n", "\n", "Create custom parametric functions inside the neural network model." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>-- nnodely_v1.5.0 --<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n" ] } ], "source": [ "# uncomment the command below to install the nnodely package\n", "#!pip install nnodely\n", "\n", "from nnodely import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic usage\n", "\n", "Create a simple parametric function that has two inputs, K1 and K2, and two parameters, p1 and p2, each of size 1.\n", "\n", "The output dimension should be defined by the user; if it is not specified, it is assumed to be 1." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def myFun(K1,K2,p1,p2):\n", " import torch\n", " return p1*K1+p2*torch.sin(K2)\n", "\n", "x = Input('x')\n", "F = Input('F')\n", "\n", "parfun = ParamFun(myFun)\n", "out = Output('out',parfun(x.last(),F.last()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Parameter dimension\n", "\n", "In this case the size of the parameter is specified explicitly: p1 has dimensions (1, 4), i.e. four elements. The time dimension of the output is not set explicitly; it depends on the input." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33m[check_names] The name 'x' is already in defined as NeuObj but it is overwritten.\u001b[0m\n", "\u001b[33m[check_names] The name 'F' is already in defined as NeuObj but it is overwritten.\u001b[0m\n", "\u001b[33m[check_names] The name 'out' is already in defined as NeuObj but it is overwritten.\u001b[0m\n" ] } ], "source": [ "def myFun(K1,K2,p1):\n", " import torch\n", " return torch.stack([K1,2*K1,3*K1,4*K1],dim=2).squeeze(-1)*p1+K2\n", "\n", "x = Input('x')\n", "F = Input('F')\n", "parfun = ParamFun(myFun, parameters_and_constants = {'p1':(1,4)})\n", "out = Output('out',parfun(x.last(),F.last()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Passing custom parameters\n", "\n", "Create a custom parameter to use inside the parametric function. Here the parametric function takes a parameter of size 1 with sw = 1.\n", "\n", "The function has two arguments: the first, K1, is an input and the second, p1, is bound to the custom parameter K. The function multiplies the input K1 by p1 (which is effectively the custom parameter K)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33m[check_names] The name 'x' is already in defined as NeuObj but it is overwritten.\u001b[0m\n", "\u001b[33m[check_names] The name 'out' is already in defined as NeuObj but it is overwritten.\u001b[0m\n" ] } ], "source": [ "def myFun(K1,p1):\n", " return K1*p1\n", "\n", "x = Input('x')\n", "K = Parameter('k', dimensions = 1, sw = 1,values=[[2.0]])\n", "parfun = ParamFun(myFun, parameters_and_constants=[K])\n", "out = Output('out',parfun(x.sw(1)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Parametric function with different parameters\n", "\n", "The same parametric function can be called with different parameters on each call." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33m[check_names] The name 'x' is already in defined as NeuObj but it is overwritten.\u001b[0m\n", "\u001b[33m[check_names] The name 'out' is already in defined as NeuObj but it is overwritten.\u001b[0m\n" ] } ], "source": [ "def myFun(K1,p1):\n", " return K1*p1\n", "\n", "K = Parameter('k1', dimensions = 1, tw = 1, values=[[2.0],[3.0],[4.0],[5.0]])\n", "R = Parameter('r1', dimensions = 1, tw = 1, values=[[5.0],[4.0],[3.0],[2.0]])\n", "\n", "x = Input('x')\n", "parfun = ParamFun(myFun)\n", "out = Output('out',parfun(x.tw(1),K)+parfun(x.tw(1),R))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Parametric functions with Constants\n", "\n", "Parametric functions also work with constant values." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33m[check_names] The name 'x' is already in defined as NeuObj but it is overwritten.\u001b[0m\n", "\u001b[33m[check_names] The name 'out' is already in defined as NeuObj but it is overwritten.\u001b[0m\n" ] } ], "source": [ "def myFun(K1,p1):\n", " return K1*p1\n", "\n", "parfun = ParamFun(myFun)\n", "x = Input('x')\n", "c = Constant('c',values=[[5.0],[4.0],[3.0],[2.0]])\n", "out = Output('out',parfun(x.sw(4),c))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Mapping over a batch\n", "\n", "Setting the argument 'map_over_batch' to True maps the parametric function over the batch dimension, so the function is written for a single sample and its inputs arrive without the batch dimension." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33m[check_names] The name 'x' is already in defined as NeuObj but it is overwritten.\u001b[0m\n", "k:torch.Size([4, 1])\n", "p1:torch.Size([1])\n", "\u001b[33m[check_names] The name 'out' is already in defined as NeuObj but it is overwritten.\u001b[0m\n" ] } ], "source": [ "def myFun(k, p1):\n", " print(f'k:{k.shape}')\n", " print(f'p1:{p1.shape}')\n", " return k * p1\n", "\n", "x = Input('x')\n", "parfun = ParamFun(myFun, map_over_batch=True)\n", "out = Output('out',parfun(x.sw(4)))" ] } ], "metadata": { "kernelspec": { "display_name": "venv (3.11.4)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 2 }