# Source code for tigercontrol.problems.control.lstm_control

"""
Long short-term memory (LSTM) output
"""
import jax
import jax.numpy as np
import jax.experimental.stax as stax
import tigercontrol
from tigercontrol.utils.random import generate_key
from tigercontrol.problems.control import ControlProblem

class LSTM_Control(ControlProblem):
    """
    Description: Produces outputs from a randomly initialized LSTM
        (long short-term memory) recurrent neural network.
    """

    def __init__(self):
        # No weights exist until initialize() is called; step() and hidden()
        # check this flag before touching any state.
        self.initialized = False

    def initialize(self, n, m, h=64):
        """
        Description: Randomly initialize the LSTM weights and reset its state.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of the LSTM.
        Returns:
            The output computed from the initial hidden state. The hidden
            state starts at zero, so this is W_out applied to a zero vector.
        """
        self.T = 0
        self.initialized = True
        self.n, self.m, self.h = n, m, h

        glorot_init = stax.glorot()  # returns a function that initializes weights
        self.W_hh = glorot_init(generate_key(), (4 * h, h))  # maps h_t to gates
        self.W_xh = glorot_init(generate_key(), (4 * h, n))  # maps x_t to gates

        # Gate bias, laid out in the same order the gates are split in _step:
        # [input, forget, cell, output]. Only the forget-gate slice starts at
        # one (forget gate biased initialization). The vector is assembled by
        # concatenation instead of jax.ops.index_update, which newer JAX
        # releases no longer provide; the resulting values are identical.
        self.b_h = np.concatenate((np.zeros(h), np.ones(h), np.zeros(2 * h)))

        self.W_out = glorot_init(generate_key(), (m, h))  # maps h_t to output
        self.cell = np.zeros(h)  # long-term memory
        self.hid = np.zeros(h)  # short-term memory

        def _step(x, hid, cell):
            # Logistic sigmoid written out explicitly.
            sigmoid = lambda z: 1. / (1. + np.exp(-z))
            gate = np.dot(self.W_hh, hid) + np.dot(self.W_xh, x) + self.b_h
            i, f, g, o = np.split(gate, 4)  # order: input, forget, cell, output
            next_cell = sigmoid(f) * cell + sigmoid(i) * np.tanh(g)
            next_hid = sigmoid(o) * np.tanh(next_cell)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, next_cell, y)

        # The weights are captured by the closure; calling initialize() again
        # builds (and jit-wraps) a fresh closure over the new weights.
        self._step = jax.jit(_step)
        return np.dot(self.W_out, self.hid)

    def step(self, x):
        """
        Description: Takes an input and produces the next output of the LSTM,
            updating the stored hidden and cell states.
        Args:
            x (numpy.ndarray): LSTM input, an n-dimensional real-valued vector.
        Returns:
            The output of the LSTM after consuming x on top of the state
            accumulated from all previous inputs.
        Raises:
            AssertionError: If initialize() has not been called, or if x does
                not have shape (n,).
        """
        assert self.initialized, "initialize() must be called before step()"
        assert x.shape == (self.n,), "x must have shape (n,)"
        self.T += 1
        self.hid, self.cell, y = self._step(x, self.hid, self.cell)
        return y

    def hidden(self):
        """
        Description: Return the current internal state of the LSTM.
        Args:
            None
        Returns:
            (hid, cell): The hidden (short-term) state and the cell
                (long-term) state.
        Raises:
            AssertionError: If initialize() has not been called.
        """
        assert self.initialized, "initialize() must be called before hidden()"
        return (self.hid, self.cell)

    def help(self):
        """
        Description: Prints information about this class and its methods.
        Args:
            None
        Returns:
            None
        """
        print(LSTM_Control_help)
# string to print when calling help() method
# The initialize() signature, the step() return text and the hidden() return
# text below are kept in sync with LSTM_Control: initialize(n, m, h=64), no
# input-window parameter, and hidden() returning the pair (hid, cell).
# NOTE(review): the Id below still says "RNN-Control-v0"; the registry entry
# for this class is not visible from here -- confirm the intended Id.
LSTM_Control_help = """

-------------------- *** --------------------

Id: RNN-Control-v0
Description: Produces outputs from a randomly initialized recurrent neural network.

Methods:

    initialize(n, m, h=64)
        Description:
            Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series

    step(x)
        Description:
            Takes an input and produces the next output of the RNN.
        Args:
            x (numpy.ndarray): RNN input, an n-dimensional real-valued vector.
        Returns:
            The output of the RNN computed on all inputs so far, including the new x.

    hidden()
        Description:
            Return the hidden state of the RNN when computed on all inputs so far.
        Args:
            None
        Returns:
            (hid, cell): The hidden state and the cell state.

    help()
        Description:
            Prints information about this class and its methods.
        Args:
            None
        Returns:
            None

-------------------- *** --------------------

"""