import numpy as np
from typing import Any
[docs]class Loss(object):
[docs] def forward(self, y, yhat):
raise NotImplementedError()
[docs] def backward(self, y, yhat):
raise NotImplementedError()
[docs]class Module(object):
def __init__(self):
self._parameters = {}
self._gradient = {}
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.forward(*args, **kwds)
[docs] def calculate_gain(self):
if self.__class__.__name__.lower() == "tanh":
return 5 / 3
elif self.__class__.__name__.lower() == "relu":
return np.sqrt(2)
else:
return 1.0
[docs] def zero_grad(self):
"""Réinitialise le gradient."""
raise NotImplementedError()
[docs] def forward(self, X):
"""Passe forward."""
raise NotImplementedError()
[docs] def update_parameters(self, learning_rate=1e-3):
r"""Update the parameters according to the calculated gradient and the learning
rate.
"""
# self._parameters -= learning_rate * self._gradient
raise NotImplementedError()
[docs] def backward_update_gradient(self, input, delta):
r"""Update gradient value given module.
.. math::
\frac{\partial L}{\partial w_i^h}=\sum_k \frac{\partial L}{\partial z_k^h}
\frac{\partial z_k^h}{\partial w_i^h}=\sum_k \delta_k^h
\frac{\partial z_k^h}{\partial w_i^h},
\text { let } \nabla_{\mathbf{w}^h} L=\left(\begin{array}{ccc}
\frac{\partial z_1^h}{\partial w_1^h} & \frac{\partial z_2^h}{\partial w_1^h}
& \cdots \\ \frac{\partial z_1^h}{\partial w_2^h} & \ddots & \\
\vdots & \end{array}\right) \nabla_{\mathbf{z}^h L}
"""
raise NotImplementedError()
[docs] def backward_delta(self, input, delta):
r"""Calculates the derivative of the error and the next delta (derivative of the
module with respect to the to the inputs).
.. math::
\delta_j^{h-1}=\frac{\partial L}{\partial z_j^{h-1}}=\sum_k
\frac{\partial L}{\partial z_k^h} \frac{\partial z_k^h}{\partial z_j^{h-1}},
\text { let } \nabla_{\mathbf{z}^{h-1}} L=\left(\begin{array}{ccc}
\frac{\partial z_1^h}{z_1^{h-1}} & \frac{\partial z_2^h}{z_1^{h-1}} & \cdots \\
\frac{\partial z_2^h}{z_2^{h-1}} & \ddots & \cdots \\
\vdots & \end{array}\right) \nabla_{\mathbf{z}^h L}
"""
raise NotImplementedError()