tigercontrol.models.optimizers.SGD
class tigercontrol.models.optimizers.SGD(pred=None, loss=<function mse>, learning_rate=0.0001, hyperparameters={})

Description: Stochastic Gradient Descent optimizer.

:param pred: a prediction function implemented with jax.numpy
:type pred: function
:param loss: the loss function to use; defaults to MSE
:type loss: function
:param learning_rate: learning rate
:type learning_rate: float

Returns: None
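To make the roles of pred, loss, and learning_rate concrete, here is a minimal sketch of a single gradient-descent step written directly in JAX. It does not call tigercontrol and is not the library's implementation; the names `mse`, `pred`, and `sgd_step`, the linear model, and the tuple layout of `params` are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def mse(y_pred, y_true):
    # Mean squared error, standing in for the default loss (assumed form).
    return jnp.mean((y_pred - y_true) ** 2)


def pred(params, x):
    # Hypothetical linear prediction function; takes params and x as input.
    w, b = params
    return jnp.dot(x, w) + b


def sgd_step(params, x, y, learning_rate=0.0001):
    # Differentiate the loss with respect to params, then step against the gradient.
    def objective(p):
        return mse(pred(p, x), y)

    grads = jax.grad(objective)(params)
    return jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )


params = (jnp.zeros(3), jnp.array(0.0))
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array(2.0)
new_params = sgd_step(params, x, y, learning_rate=0.01)
```

The update is the plain rule params ← params − learning_rate · ∇loss, applied leaf by leaf so that params can be any JAX pytree.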
__init__(pred=None, loss=<function mse>, learning_rate=0.0001, hyperparameters={})
Initialize self. See help(type(self)) for accurate signature.
Methods
__init__([pred, loss, learning_rate, …])
    Initialize self.
gradient(params, x, y[, loss])
    Description: Updates parameters based on correct value, loss and learning rate.
set_loss(new_loss)
    Description: Updates the internal loss.
set_predict(pred[, loss])
    Description: Updates the internally stored pred and loss functions.
    :param pred: predict function, must take params and x as input
    :type pred: function
    :param loss: loss function.
update(params, x, y[, loss])
    Description: Updates parameters based on correct value, loss and learning rate.
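The sketch below shows one way the methods listed above could fit together. `SketchSGD` is a hypothetical stand-in that only mirrors the method names and argument order from this page; it is not tigercontrol's source. In particular, it assumes that `gradient` returns the gradient of the loss, that `update` returns new parameters rather than mutating them, and that the optional `loss` argument overrides the stored loss for that call only.

```python
import jax
import jax.numpy as jnp


def mse(y_pred, y_true):
    # Assumed form of the default loss.
    return jnp.mean((y_pred - y_true) ** 2)


def mae(y_pred, y_true):
    # Alternative loss used to demonstrate set_loss.
    return jnp.mean(jnp.abs(y_pred - y_true))


class SketchSGD:
    """Hypothetical stand-in, not tigercontrol.models.optimizers.SGD."""

    def __init__(self, pred=None, loss=mse, learning_rate=0.0001):
        self.pred = pred
        self.loss = loss
        self.lr = learning_rate

    def set_loss(self, new_loss):
        self.loss = new_loss

    def set_predict(self, pred, loss=None):
        # pred must take (params, x); loss is replaced only if given.
        self.pred = pred
        if loss is not None:
            self.loss = loss

    def gradient(self, params, x, y, loss=None):
        # Assumption: returns d(loss)/d(params) without changing params.
        loss_fn = self.loss if loss is None else loss
        return jax.grad(lambda p: loss_fn(self.pred(p, x), y))(params)

    def update(self, params, x, y, loss=None):
        # One descent step: params - learning_rate * gradient.
        grads = self.gradient(params, x, y, loss=loss)
        return jax.tree_util.tree_map(
            lambda p, g: p - self.lr * g, params, grads
        )


opt = SketchSGD(pred=lambda w, x: jnp.dot(x, w), learning_rate=0.1)
w = jnp.zeros(2)
x, y = jnp.array([1.0, 1.0]), jnp.array(1.0)

w_mse = opt.update(w, x, y)  # step under the default MSE loss
opt.set_loss(mae)
w_mae = opt.update(w, x, y)  # same step under the swapped-in MAE loss
```

Swapping the loss changes the step size here because the MSE gradient scales with the prediction error while the MAE gradient depends only on its sign.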