diff --git a/src/nodes/multiplication.py b/src/nodes/multiplication.py index ea839308f5c509bf0fb34206ae8e7c909e9605d3..aaaa1aca165f89cbc5c71cf5eb8013bca9d09390 100644 --- a/src/nodes/multiplication.py +++ b/src/nodes/multiplication.py @@ -2,18 +2,26 @@ from . import BaseNode import numpy as np +def _f(q, z): + return q * z + + class Multiplication(BaseNode): - def __init__(self, x, y): + def __init__(self, q, z): """ - x * y + q * y """ super(Multiplication, self).__init__() - self.x = x - self.y = y + self.q = q + self.z = z def forward(self): - self.result = self.x * self.y - return self.result + self.forward_result = _f(self.q, self.z) + return self.forward_result - def backward(self): - pass + def backward(self, previous): + # store the partial derivative for each input + self.local_gradient_q = (_f(self.q + 1e-7, self.z) - self.forward_result) / 1e-7 + self.local_gradient_z = (_f(self.q, self.z + 1e-7) - self.forward_result) / 1e-7 + self.backward_gradient_q = previous * self.local_gradient_q + self.backward_gradient_z = previous * self.local_gradient_z diff --git a/src/nodes/sum.py b/src/nodes/sum.py index db7622ce96271eff64a8bc880c89b75b52881d54..0d950b70df6f467562cfb215fa84e385d3563768 100644 --- a/src/nodes/sum.py +++ b/src/nodes/sum.py @@ -2,6 +2,10 @@ from . import BaseNode import numpy as np +def _f(x, y): + return x + y + + class Sum(BaseNode): def __init__(self, x, y): """ @@ -12,8 +16,12 @@ class Sum(BaseNode): self.y = y def forward(self): - self.result = self.x + self.y - return self.result + self.forward_result = _f(self.x, self.y) + return self.forward_result - def backward(self): - pass + def backward(self, previous): + # store the partial derivative for each input + self.local_gradient_x = (_f(self.x + 1e-7, self.y) - self.forward_result) / 1e-7 + self.local_gradient_y = (_f(self.x, self.y + 1e-7) - self.forward_result) / 1e-7 + self.backward_gradient_x = previous * self.local_gradient_x + self.backward_gradient_y = previous * self.local_gradient_y