# Source code for tigercontrol.problems.time_series.rnn_time_series

"""
Recurrent neural network output
"""
import jax
import jax.numpy as np
import jax.random as random
import jax.experimental.stax as stax
import tigercontrol
from tigercontrol.utils.random import generate_key
from tigercontrol.problems.control import ControlProblem

class RNN_TimeSeries(ControlProblem):
    """
    Description: Produces outputs from a randomly initialized recurrent neural network.
    At each step a random input x is sampled and fed through the RNN, producing
    an output y.
    """

    compatibles = set(['TimeSeries'])

    def __init__(self):
        self.initialized = False

    def initialize(self, n, m, h=64):
        """
        Description: Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first (x, y) pair of the time-series, i.e. the result of calling
            step() once immediately after initialization.
        """
        self.T = 0
        self.initialized = True
        # NOTE(review): presumably signals to callers that step() returns the
        # input (regressor) x alongside the output y -- confirm against ControlProblem.
        self.has_regressors = True
        self.n, self.m, self.h = n, m, h

        glorot_init = stax.glorot()  # returns a function that initializes weights
        self.W_h = glorot_init(generate_key(), (h, h))    # hidden -> hidden
        self.W_x = glorot_init(generate_key(), (h, n))    # input -> hidden
        self.W_out = glorot_init(generate_key(), (m, h))  # hidden -> output
        self.b_h = np.zeros(h)
        self.hid = np.zeros(h)

        def _step(x, hid):
            # Standard Elman recurrence: new hidden state, then a linear read-out.
            next_hid = np.tanh(np.dot(self.W_h, hid) + np.dot(self.W_x, x) + self.b_h)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, y)

        # _step reads the weights from self via closure; jax.jit treats such
        # closed-over values as constants when it traces, so the compiled
        # function keeps the weights set above (re-run initialize() to change them).
        self._step = jax.jit(_step)
        return self.step()

    def step(self):
        """
        Description: Samples a new random input and produces the next output of the RNN.
        Args:
            None
        Returns:
            A tuple (x, y): x is the n-dimensional input drawn from a standard
            normal distribution for this step, and y is the m-dimensional RNN
            output computed from the updated hidden state.
        """
        assert self.initialized
        self.T += 1
        x = random.normal(generate_key(), shape=(self.n,))
        self.hid, y = self._step(x, self.hid)
        return x, y

    def hidden(self):
        """
        Description: Return the current hidden state of the RNN, i.e. the state
            produced by the most recent call to step().
        Args:
            None
        Returns:
            h: The hidden state (an h-dimensional vector).
        """
        assert self.initialized
        return self.hid

    def help(self):
        """
        Description: Prints information about this class and its methods.
        Args:
            None
        Returns:
            None
        """
        print(RNN_TimeSeries_help)
# string to print when calling help() method
RNN_TimeSeries_help = """

-------------------- *** --------------------

Id: RNN-Control-v0
Description: Produces outputs from a randomly initialized recurrent neural network.

Methods:

    initialize(n, m, h=64)
        Description:
            Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first (x, y) pair of the time-series, as returned by step().

    step()
        Description:
            Samples a random input x and produces the next output of the RNN.
        Args:
            None
        Returns:
            A tuple (x, y): the n-dimensional input sampled for this step and
            the m-dimensional RNN output computed from the updated hidden state.

    hidden()
        Description:
            Return the current hidden state of the RNN.
        Args:
            None
        Returns:
            h: The hidden state.

    help()
        Description:
            Prints information about this class and its methods.
        Args:
            None
        Returns:
            None

-------------------- *** --------------------

"""