# Source code for tigercontrol.problems.control.pendulum

"""
Non-PyBullet implementation of Pendulum
"""
import jax
import jax.numpy as np
import jax.random as random

import os
import tigercontrol
from tigercontrol.utils import generate_key, get_tigercontrol_dir
from tigercontrol.problems.control import ControlProblem

# necessary for rendering
from gym.envs.classic_control import rendering


class Pendulum(ControlProblem):
    """Torque-controlled pendulum with JAX-based dynamics.

    The state is a length-2 array ``[theta, thdot]`` (angle and angular
    velocity) and the control is a length-1 array holding the torque.
    The applied torque is clipped to ``[-max_torque, max_torque]`` and the
    angular velocity to ``[-max_speed, max_speed]``.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 30,
    }

    def __init__(self, g=10.0):
        """Set up constants and build the jitted dynamics.

        Args:
            g: Gravity constant used in the dynamics (default 10.0).
        """
        self.initialized = False
        self.max_speed = 8
        self.max_torque = 2.
        self.dt = .05
        self.g = g
        self.viewer = None  # created lazily on the first render() call
        self.action_space = (1,)
        self.observation_space = (2,)

        @jax.jit
        def angle_normalize(x):
            # Shift by one full turn when outside [-pi, pi]. This is a single
            # shift, not a modulo, so it only wraps angles within 2*pi of
            # the interval.
            x = np.where(x > np.pi, x - 2*np.pi, x)
            x = np.where(x < -np.pi, x + 2*np.pi, x)
            return x
        self.angle_normalize = angle_normalize

        # NOTE: the jitted closure reads self.g, self.dt, self.max_torque and
        # self.max_speed while tracing; changing those attributes after the
        # function has been compiled may not affect already-traced calls.
        @jax.jit
        def dynamics(x, u):
            th, thdot = x
            g = self.g
            m = 1.
            l = 1.
            dt = self.dt
            # Only the first entry of the (clipped) control is used as torque.
            u = np.clip(u, -self.max_torque, self.max_torque)[0]
            newthdot = thdot + (-3*g/(2*l) * np.sin(th + np.pi) + 3./(m*l**2)*u) * dt
            # The angle is advanced with the unclipped new velocity; the
            # velocity itself is clipped afterwards.
            newth = self.angle_normalize(th + newthdot*dt)
            newthdot = np.clip(newthdot, -self.max_speed, self.max_speed)
            return np.array([newth, newthdot])
        self.dynamics = dynamics

    def initialize(self):
        """Mark the problem as initialized and return a fresh random state."""
        self.initialized = True
        return self.reset()

    def step(self, u):
        """Advance the pendulum one time step under control ``u``.

        Args:
            u: Indexable control array; its first entry is the torque. It is
                clipped to ``[-max_torque, max_torque]`` before use.

        Returns:
            Tuple ``(state, reward, done, info)`` where ``state`` is the new
            ``[theta, thdot]`` array, ``reward`` is the negated cost,
            ``done`` is always ``False`` and ``info`` is an empty dict.
        """
        # Clip once so that the cost charges for the torque actually applied
        # rather than for the (possibly out-of-range) requested control.
        torque = np.clip(u, -self.max_torque, self.max_torque)
        self.last_u = torque[0]  # for rendering
        # dynamics() clips internally as well, so passing the raw u is
        # equivalent to passing the clipped torque.
        state = self.dynamics(self.state, u)
        # The cost is evaluated on the post-step state.
        costs = (self.angle_normalize(state[0])**2
                 + .1*state[1]**2
                 + .001*(torque**2))
        self.state = state
        return self.state, -costs, False, {}

    def reset(self):
        """Sample a new random state and clear the last control.

        The angle is drawn uniformly from ``[-pi, pi)`` and the angular
        velocity from ``[-1, 1)``.

        Returns:
            The new state array ``[theta, thdot]``.
        """
        theta = random.uniform(generate_key(), minval=-np.pi, maxval=np.pi)
        thdot = random.uniform(generate_key(), minval=-1., maxval=1.)
        self.state = np.array([theta, thdot])
        self.last_u = 0.0
        return self.state

    def render(self, mode='human'):
        """Draw the current state using the gym classic-control viewer.

        Args:
            mode: ``'rgb_array'`` asks the viewer to return the frame as an
                RGB array; any other value renders without requesting one.

        Returns:
            Whatever ``viewer.render`` returns for the requested mode.
        """
        if self.viewer is None:
            # Lazily build the window and the static geometry on first use.
            self.viewer = rendering.Viewer(500, 500)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
            rod = rendering.make_capsule(1, .2)
            rod.set_color(.8, .3, .3)
            self.pole_transform = rendering.Transform()
            rod.add_attr(self.pole_transform)
            self.viewer.add_geom(rod)
            axle = rendering.make_circle(.05)
            axle.set_color(0, 0, 0)
            self.viewer.add_geom(axle)
            fname = os.path.join(get_tigercontrol_dir(),
                                 "problems/control/assets/clockwise.png")
            self.img = rendering.Image(fname, 1., 1.)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)

        # NOTE(review): indentation was lost in the scraped source; the torque
        # image is re-added on every frame here on the assumption that
        # one-time geometry is dropped after each render - confirm.
        self.viewer.add_onetime(self.img)
        self.pole_transform.set_rotation(self.state[0] + np.pi/2)
        if self.last_u:
            # Scale the arrow image by the last torque; the sign flips it
            # horizontally.
            self.imgtrans.scale = (-self.last_u/2, np.abs(self.last_u)/2)

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def close(self):
        """Close the viewer window, if one was opened."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None