tigercontrol.models.optimizers.OGD
class tigercontrol.models.optimizers.OGD(pred=None, loss=<function mse>, learning_rate=1.0, hyperparameters={})

Ordinary Gradient Descent optimizer.

Parameters:
    pred (function): a prediction function implemented with jax.numpy
    loss (function): the loss function to use; defaults to MSE
    learning_rate (float): the learning rate

Returns: None
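As a rough guide to what one optimizer step computes, the rule below writes out a plain gradient-descent step with the default MSE loss. It assumes a constant step size and no other hyperparameters; here θ stands for `params`, η for `learning_rate`, and (x_t, y_t) for the input and target passed to `update`. It is a sketch of ordinary gradient descent, not a transcription of the library's code.

$$
\ell_t(\theta) = \operatorname{mean}\!\big( (\mathrm{pred}(\theta, x_t) - y_t)^2 \big),
\qquad
\theta_{t+1} = \theta_t - \eta \, \nabla_\theta \ell_t(\theta_t)
$$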
__init__(pred=None, loss=<function mse>, learning_rate=1.0, hyperparameters={})
    Initialize self. See help(type(self)) for accurate signature.
Methods

__init__([pred, loss, learning_rate, …]): Initialize self.
gradient(params, x, y[, loss]): Updates parameters based on the correct value, the loss, and the learning rate.
set_loss(new_loss): Updates the internally stored loss function.
set_predict(pred[, loss]): Updates the internally stored pred and loss functions.
    pred (function): prediction function; must take params and x as input
    loss (function): loss function
update(params, x, y[, loss]): Updates parameters based on the correct value, the loss, and the learning rate.
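To make the method list concrete, here is a minimal JAX sketch of a class that mirrors the documented method names. It is not tigercontrol's implementation: `OGDSketch`, `mse`, and `linear_pred` are hypothetical names, the `hyperparameters` argument is omitted, and it assumes that `update` returns the new parameters after one plain gradient step with a constant learning rate, and that a `loss` passed to `gradient` or `update` overrides the stored loss for that call only.

```python
import jax
import jax.numpy as jnp


def mse(y_pred, y_true):
    # Mean squared error; assumed to match the documented default loss.
    return jnp.mean((y_pred - y_true) ** 2)


class OGDSketch:
    """Hypothetical stand-in that mirrors the documented OGD method names."""

    def __init__(self, pred=None, loss=mse, learning_rate=1.0):
        self.pred = pred
        self.loss = loss
        self.lr = learning_rate

    def set_loss(self, new_loss):
        # Replace the stored loss used by later gradient/update calls.
        self.loss = new_loss

    def set_predict(self, pred, loss=None):
        # Replace the stored prediction function; optionally the loss too.
        self.pred = pred
        if loss is not None:
            self.loss = loss

    def gradient(self, params, x, y, loss=None):
        # A per-call loss overrides the stored one (assumed behavior).
        loss_fn = self.loss if loss is None else loss

        def objective(p):
            return loss_fn(self.pred(p, x), y)

        # Gradient of the loss with respect to params (any JAX pytree).
        return jax.grad(objective)(params)

    def update(self, params, x, y, loss=None):
        grads = self.gradient(params, x, y, loss)
        # Plain gradient step with a constant learning rate.
        return jax.tree_util.tree_map(lambda p, g: p - self.lr * g, params, grads)


def linear_pred(params, x):
    # Hypothetical linear model written with jax.numpy.
    return jnp.dot(x, params)


opt = OGDSketch(pred=linear_pred, learning_rate=0.1)
x, y = jnp.array([1.0, 2.0]), 3.0
params = opt.update(jnp.zeros(2), x, y)
print(linear_pred(params, x))  # approximately 3.0 for this particular choice of x and learning rate
```

For this linear model the residual shrinks by a factor of (1 - 2·η·‖x‖²) per step; with η = 0.1 and x = [1, 2] that factor is 0, so one update reaches a zero-residual solution. In general several steps are needed, and too large a learning rate makes the factor exceed 1 in magnitude, so the iterates diverge.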