'''
SGD optimizer
'''
from tigercontrol.models.optimizers.core import Optimizer
from tigercontrol.models.optimizers.losses import mse
class SGD(Optimizer):
    """
    Description: Stochastic Gradient Descent optimizer. Each call to update
        moves the parameters one step of size learning_rate against the
        gradient of the loss.
    Args:
        pred (function): a prediction function implemented with jax.numpy
        loss (function): specifies loss function to be used; defaults to MSE
        learning_rate (float): learning rate
        hyperparameters (dict): optional extra hyperparameters, stored as-is on
            the instance; a fresh empty dict is used when omitted
    Returns:
        None
    """

    def __init__(self, pred=None, loss=mse, learning_rate=0.0001, hyperparameters=None):
        # NOTE(review): the base-class __init__ is not invoked here; the
        # attributes are assigned directly. Confirm this is intentional.
        self.initialized = False
        self.lr = learning_rate
        # A literal `{}` default would be one dict shared by every instance
        # created without the argument; build a new one per instance instead.
        # A caller-supplied dict is still stored by reference, as before.
        self.hyperparameters = {} if hyperparameters is None else hyperparameters
        self.pred = pred
        self.loss = loss
        # Only wire up prediction when both callables pass validation;
        # otherwise the optimizer is left with initialized == False.
        # (The validators and set_predict are not defined in this class --
        # presumably inherited from Optimizer.)
        if self._is_valid_pred(pred, raise_error=False) and self._is_valid_loss(loss, raise_error=False):
            self.set_predict(pred, loss=loss)

    def update(self, params, x, y, loss=None):
        """
        Description: Updates parameters based on correct value, loss and learning rate.
        Args:
            params (list/numpy.ndarray): Parameters of model pred method
            x (float): input to model
            y (float): true label
            loss (function): loss function, forwarded unchanged to
                self.gradient; defaults to None
        Returns:
            Updated parameters in same shape as input (a new list when params
            is a list, otherwise the result of the array arithmetic)
        Raises:
            AssertionError: if the optimizer has not been initialized
        """
        # Raised explicitly rather than via `assert` so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        if not self.initialized:
            raise AssertionError("SGD optimizer has not been initialized")
        grad = self.gradient(params, x, y, loss=loss)  # defined in optimizers core class
        # isinstance (not an exact type match) so list subclasses also take
        # the element-wise path.
        if isinstance(params, list):
            return [w - self.lr * dw for (w, dw) in zip(params, grad)]
        return params - self.lr * grad