From a1a2d0fbc4341a1e56fdb92115d89c19b501743a Mon Sep 17 00:00:00 2001 From: Claudio Scheer Date: Wed, 8 Apr 2020 17:38:45 -0300 Subject: [PATCH] Add multiplication and sum nodes --- src/nodes/multiplication.py | 24 ++++++++++++++++-------- src/nodes/sum.py | 16 ++++++++++++---- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/src/nodes/multiplication.py b/src/nodes/multiplication.py index ea83930..aaaa1ac 100644 --- a/src/nodes/multiplication.py +++ b/src/nodes/multiplication.py @@ -2,18 +2,26 @@ from . import BaseNode import numpy as np +def _f(q, z): + return q * z + + class Multiplication(BaseNode): - def __init__(self, x, y): + def __init__(self, q, z): """ - x * y + q * y """ super(Multiplication, self).__init__() - self.x = x - self.y = y + self.q = q + self.z = z def forward(self): - self.result = self.x * self.y - return self.result + self.forward_result = _f(self.q, self.z) + return self.forward_result - def backward(self): - pass + def backward(self, previous): + # store the partial derivative for each input + self.local_gradient_q = (_f(self.q + 1e-7, self.z) - self.forward_result) / 1e-7 + self.local_gradient_z = (_f(self.q, self.z + 1e-7) - self.forward_result) / 1e-7 + self.backward_gradient_q = previous * self.local_gradient_q + self.backward_gradient_z = previous * self.local_gradient_z diff --git a/src/nodes/sum.py b/src/nodes/sum.py index db7622c..0d950b7 100644 --- a/src/nodes/sum.py +++ b/src/nodes/sum.py @@ -2,6 +2,10 @@ from . import BaseNode import numpy as np +def _f(x, y): + return x + y + + class Sum(BaseNode): def __init__(self, x, y): """ @@ -12,8 +16,12 @@ class Sum(BaseNode): self.y = y def forward(self): - self.result = self.x + self.y - return self.result + self.forward_result = _f(self.x, self.y) + return self.forward_result - def backward(self): - pass + def backward(self, previous): + # store the partial derivative for each input + self.local_gradient_x = (_f(self.x + 1e-7, self.y) - self.forward_result) / 1e-7 + self.local_gradient_y = (_f(self.x, self.y + 1e-7) - self.forward_result) / 1e-7 + self.backward_gradient_x = previous * self.local_gradient_x + self.backward_gradient_y = previous * self.local_gradient_y -- GitLab