Source code for tigercontrol.models.time_series.rnn

"""
Recurrent neural network model
"""

import jax
import jax.numpy as np
import jax.experimental.stax as stax
import tigercontrol
from tigercontrol.utils.random import generate_key
from tigercontrol.models.time_series import TimeSeriesModel
from tigercontrol.models.optimizers import *
from tigercontrol.models.optimizers.losses import mse

class RNN(TimeSeriesModel):
    """
    Description: Produces outputs from a randomly initialized recurrent neural network.
    """

    compatibles = set(['TimeSeries'])

    def __init__(self):
        self.initialized = False
        self.uses_regressors = True

    def initialize(self, n=1, m=1, l=32, h=64, optimizer=OGD, loss=mse, lr=0.003):
        """
        Description: Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            l (int): Length of memory for update step purposes.
            h (int): Default value 64. Hidden dimension of RNN.
            optimizer (class): optimizer choice
            loss (class): loss choice
            lr (float): learning rate for update
        """
        self.T = 0
        self.initialized = True
        self.n, self.m, self.l, self.h = n, m, l, h

        # initialize parameters
        glorot_init = stax.glorot()  # returns a function that initializes weights
        W_h = glorot_init(generate_key(), (h, h))
        W_x = glorot_init(generate_key(), (h, n))
        W_out = glorot_init(generate_key(), (m, h))
        b_h = np.zeros(h)
        self.params = [W_h, W_x, W_out, b_h]
        self.hid = np.zeros(h)
        self.x = np.zeros((l, n))

        # private helper methods

        @jax.jit
        def _update_x(self_x, x):
            # slide the window of past inputs and append the newest observation
            new_x = np.roll(self_x, -self.n)
            new_x = jax.ops.index_update(new_x, jax.ops.index[-1, :], x)
            return new_x

        @jax.jit
        def _fast_predict(carry, x):
            params, hid = carry  # unroll tuple in carry
            W_h, W_x, W_out, b_h = params
            next_hid = np.tanh(np.dot(W_h, hid) + np.dot(W_x, x) + b_h)
            y = np.dot(W_out, next_hid)
            return (params, next_hid), y

        @jax.jit
        def _predict(params, x):
            # run the recurrence over the stored window and return the last output
            _, y = jax.lax.scan(_fast_predict, (params, np.zeros(h)), x)
            return y[-1]

        self.transform = lambda x: float(x) if (self.m == 1) else x
        self._update_x = _update_x
        self._fast_predict = _fast_predict
        self._predict = _predict
        self._store_optimizer(optimizer, self._predict)

    def to_ndarray(self, x):
        """
        Description: If x is a scalar, transform it to a 1-dimensional numpy.ndarray;
            otherwise, leave it unchanged.
        Args:
            x (float/numpy.ndarray)
        Returns:
            A numpy.ndarray representation of x
        """
        x = np.asarray(x)
        if np.ndim(x) == 0:
            x = x[None]
        return x

    def predict(self, x, timeline=1):
        """
        Description: Predict next value given observation.
        Args:
            x (float/numpy.ndarray): Observation
        Returns:
            Predicted value for the next time-step
        """
        assert self.initialized
        self.x = self._update_x(self.x, self.to_ndarray(x))
        carry, y = self._fast_predict((self.params, self.hid), self.to_ndarray(x))
        _, self.hid = carry
        return y

    def forecast(self, x, timeline=1):
        """
        Description: Forecast values 'timeline' timesteps into the future.
        Args:
            x (float/numpy.ndarray): Value at current time-step
            timeline (int): timeline for forecast
        Returns:
            Forecasted values 'timeline' timesteps into the future
        """
        assert self.initialized
        self.x = self._update_x(self.x, self.to_ndarray(x))
        carry, x = self._fast_predict((self.params, self.hid), self.to_ndarray(x))
        _, self.hid = carry
        hid = self.hid
        pred = [self.transform(x)]
        for t in range(timeline - 1):
            # feed the previous prediction back in, rolling the hidden state
            # forward without overwriting the state kept for real observations
            (_, hid), x = self._fast_predict((self.params, hid), self.to_ndarray(x))
            pred.append(self.transform(x))
        return pred

    def update(self, y):
        """
        Description: Updates parameters.
        Args:
            y (int/numpy.ndarray): True value at current time-step
        Returns:
            None
        """
        self.params = self.optimizer.update(self.params, self.x, y)
        return

    def help(self):
        """
        Description: Prints information about this class and its methods.
        Args:
            None
        Returns:
            None
        """
        print(RNN_help)
# string to print when calling help() method
RNN_help = """

-------------------- *** --------------------

Id: RNN
Description: Implements a Recurrent Neural Network model.

Methods:

    initialize(n, m, l=32, h=64, optimizer=OGD, loss=mse, lr=0.003)
        Description: Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            l (int): Length of memory for update step purposes.
            h (int): Default value 64. Hidden dimension of RNN.
            optimizer (class): optimizer choice
            loss (class): loss choice
            lr (float): learning rate for update

    predict(x)
        Description: Predict next value given observation.
        Args:
            x (int/numpy.ndarray): Observation
        Returns:
            Predicted value for the next time-step

    update(y)
        Description: Updates parameters.
        Args:
            y (int/numpy.ndarray): True value at current time-step
        Returns:
            None

    help()
        Description: Prints information about this class and its methods.
        Args:
            None
        Returns:
            None

-------------------- *** --------------------

"""
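
The snippet below is a minimal usage sketch, not part of the library source: it wires the RNN above into an online predict/update loop on a made-up sine-wave series, assuming the OGD optimizer and mse loss imported at the top of the module behave as a standard online gradient-descent update with mean-squared-error loss.

# Illustrative usage sketch (editor's example, not part of the module)
from tigercontrol.models.time_series.rnn import RNN
from tigercontrol.models.optimizers import OGD
from tigercontrol.models.optimizers.losses import mse
import jax.numpy as np

model = RNN()
model.initialize(n=1, m=1, l=32, h=64, optimizer=OGD, loss=mse, lr=0.003)

# toy univariate series, made up purely for illustration
series = np.sin(np.arange(100) / 10.0)

for t in range(len(series) - 1):
    y_pred = model.predict(series[t])   # one-step-ahead prediction
    model.update(series[t + 1])         # online update on the observed value

future = model.forecast(series[-1], timeline=5)  # roll out 5 steps ahead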