tigercontrol.models.optimizers.Optimizer

class tigercontrol.models.optimizers.Optimizer(pred=None, loss=<function mse>, learning_rate=1.0, hyperparameters={})

Description: Core class for model optimizers

Parameters:
  • pred (function) – a prediction function implemented with jax.numpy; must take params and x as input
  • loss (function) – the loss function to use; defaults to mse (mean squared error)
  • learning_rate (float) – the learning rate; defaults to 1.0
  • hyperparameters (dict) – additional optimizer hyperparameters
Returns:

None
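The parameter contract above can be illustrated in plain Python. This is a sketch under stated assumptions, not tigercontrol code: pred, mse, grad_mse_linear and sgd_step are hypothetical helpers, loss is assumed to take (y_pred, y), Python lists stand in for jax.numpy arrays, and the gradient is derived by hand instead of by JAX. It shows how pred, loss, and learning_rate could combine in one plain gradient-descent step; the page itself does not specify the update rule.

```python
# Hypothetical helpers (not tigercontrol APIs). Python lists stand in for
# jax.numpy arrays, and the gradient is derived by hand instead of with JAX.


def pred(params, x):
    """Linear predictor: dot product of params and x (takes params and x)."""
    return sum(p * xi for p, xi in zip(params, x))


def mse(y_pred, y):
    """Squared error for one scalar prediction; (y_pred, y) order is assumed."""
    return (y_pred - y) ** 2


def grad_mse_linear(params, x, y):
    """Gradient of mse(pred(params, x), y) w.r.t. params: 2 * (p.x - y) * x_i."""
    err = pred(params, x) - y
    return [2.0 * err * xi for xi in x]


def sgd_step(params, x, y, learning_rate=1.0):
    """One plain gradient-descent step: params - learning_rate * gradient."""
    g = grad_mse_linear(params, x, y)
    return [p - learning_rate * gi for p, gi in zip(params, g)]


params = [0.0, 0.0]
x, y = [1.0, 2.0], 3.0
new_params = sgd_step(params, x, y, learning_rate=0.1)
loss_before = mse(pred(params, x), y)     # 9.0
loss_after = mse(pred(new_params, x), y)  # close to 0
```

With the signature's default learning_rate=1.0, the same step would move params to [6.0, 12.0] and predict 30.0 for a target of 3.0, overshooting badly on this toy problem, which is why a smaller rate is passed explicitly here.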

__init__(pred=None, loss=<function mse>, learning_rate=1.0, hyperparameters={})

Initialize self. See help(type(self)) for accurate signature.

Methods

__init__([pred, loss, learning_rate, …]) Initialize self.
gradient(params, x, y[, loss]) Description: Updates parameters based on the correct value y, the loss, and the learning rate.
set_loss(new_loss) Description: Updates the internally stored loss function.
set_predict(pred[, loss]) Description: Updates the internally stored pred and loss functions. pred (function): predict function; must take params and x as input. loss (function): loss function.
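The methods listed above can be sketched as a small dependency-free class. Everything here is an assumption for illustration rather than tigercontrol's implementation: SketchOptimizer, linear, and the eps parameter are hypothetical, loss is assumed to take (y_pred, y), omitting loss in set_predict keeps the current loss (the real default may differ), and gradient is estimated with central finite differences instead of automatic differentiation. The documented gradient description mentions updating parameters; this sketch only returns the gradient and leaves any update to the caller.

```python
# Hypothetical sketch mirroring the documented method names; not tigercontrol's
# implementation. Central finite differences replace automatic differentiation.


def mse(y_pred, y):
    """Squared error for one scalar prediction ((y_pred, y) order assumed)."""
    return (y_pred - y) ** 2


class SketchOptimizer:
    def __init__(self, pred=None, loss=mse, learning_rate=1.0,
                 hyperparameters=None, eps=1e-6):
        self.lr = learning_rate
        # Copy into a fresh dict so instances never share one mutable default.
        self.hyperparameters = dict(hyperparameters or {})
        self.eps = eps  # finite-difference step size (sketch-only parameter)
        self.pred = None
        self.loss = loss
        if pred is not None:
            self.set_predict(pred)

    def set_predict(self, pred, loss=None):
        """Store pred(params, x); in this sketch, loss=None keeps the current loss."""
        self.pred = pred
        if loss is not None:
            self.loss = loss

    def set_loss(self, new_loss):
        """Replace the stored loss function."""
        self.loss = new_loss

    def gradient(self, params, x, y, loss=None):
        """Estimate d loss(pred(params, x), y) / d params by central differences.

        A loss passed here overrides the stored one for this call only.
        """
        if self.pred is None:
            raise ValueError("set_predict must be called before gradient")
        loss_fn = self.loss if loss is None else loss

        def objective(p):
            return loss_fn(self.pred(p, x), y)

        grads = []
        for i in range(len(params)):
            up, down = list(params), list(params)
            up[i] += self.eps
            down[i] -= self.eps
            grads.append((objective(up) - objective(down)) / (2 * self.eps))
        return grads


def linear(params, x):
    """Hypothetical predictor: dot product of params and x."""
    return sum(p * xi for p, xi in zip(params, x))


opt = SketchOptimizer(pred=linear)
g_mse = opt.gradient([0.0, 0.0], [1.0, 2.0], 3.0)  # approx [-6.0, -12.0]
g_abs = opt.gradient([0.0, 0.0], [1.0, 2.0], 3.0,
                     loss=lambda y_pred, y: abs(y_pred - y))  # approx [-1.0, -2.0]
```

Central differences are exact up to rounding for the quadratic MSE case, which is why g_mse lands at about [-6.0, -12.0]; for non-smooth losses such as the absolute error they are only reliable away from the kink at y_pred == y.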