tigercontrol.models.optimizers.OGD

class tigercontrol.models.optimizers.OGD(pred=None, loss=<function mse>, learning_rate=1.0, hyperparameters={})

Description: Ordinary Gradient Descent optimizer.

Parameters:
    pred (function): a prediction function implemented with jax.numpy
    loss (function): loss function to be used; defaults to MSE
    learning_rate (float): learning rate

Returns: None
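A minimal usage sketch based only on the signatures documented on this page. The linear predictor, parameter shapes, and data below are hypothetical, and update is assumed to return the new parameter values:

    import jax.numpy as jnp
    from tigercontrol.models.optimizers import OGD

    # Hypothetical linear predictor: params is a weight vector, x an input vector.
    def predict(params, x):
        return jnp.dot(params, x)

    # Construct the optimizer with the prediction function and a learning rate.
    # The loss defaults to MSE as documented above.
    optimizer = OGD(pred=predict, learning_rate=0.1)

    params = jnp.zeros(3)        # hypothetical initial parameters
    x, y = jnp.ones(3), 2.0      # hypothetical input and correct value

    # One gradient step; update is assumed to return the updated parameters.
    params = optimizer.update(params, x, y)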
__init__(pred=None, loss=<function mse>, learning_rate=1.0, hyperparameters={})

Initialize self. See help(type(self)) for accurate signature.

Methods

__init__([pred, loss, learning_rate, …])
    Initialize self.

gradient(params, x, y[, loss])
    Description: Computes the gradient of the loss with respect to params at the observed (x, y).

set_loss(new_loss)
    Description: Updates the internally stored loss function.

set_predict(pred[, loss])
    Description: Updates the internally stored pred and loss functions.
    Parameters:
        pred (function): predict function, must take params and x as input
        loss (function): loss function

update(params, x, y[, loss])
    Description: Updates parameters based on the correct value, loss, and learning rate, as sketched below.
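The update step amounts to plain gradient descent on the prediction loss. The sketch below illustrates that rule under stated assumptions (params is a single jax array rather than a pytree, and the mse defined here matches the default loss); it is not the library's internal code:

    import jax
    import jax.numpy as jnp

    def mse(y_pred, y_true):
        # Mean squared error, assumed to match the default loss above.
        return jnp.mean((y_pred - y_true) ** 2)

    def ogd_step(pred, params, x, y, loss=mse, learning_rate=1.0):
        # Gradient of the loss with respect to params at the observed (x, y),
        # i.e. what gradient(params, x, y, loss) is expected to compute.
        g = jax.grad(lambda p: loss(pred(p, x), y))(params)
        # Descent step with the stored learning rate,
        # i.e. what update(params, x, y, loss) is expected to apply.
        return params - learning_rate * g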