# Source code for tigercontrol.problems.control.cartpole

"""
Non-PyBullet implementation of CartPole
"""
import jax
import jax.numpy as np
import jax.random as random

import tigercontrol
from tigercontrol.utils import generate_key
from tigercontrol.problems.control import ControlProblem

# necessary for rendering
from gym.envs.classic_control import rendering


class CartPole(ControlProblem):
    """
    Description:
        A pole is attached by an un-actuated joint to a cart, which moves along
        a frictionless track. The pendulum starts upright, and the goal is to
        prevent it from falling over by increasing and reducing the cart's
        velocity.

    State, shape (4,): [x, x_dot, theta, theta_dot]. These are the cart
    position, cart velocity, pole angle (radians) and pole angular velocity.

    Action: a single force command. It is clipped to [-1, 1] and scaled by
    ``force_mag`` inside ``dynamics``.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 50
    }

    def __init__(self):
        """Set up the physical constants, failure thresholds and the jitted
        one-step ``dynamics`` function.

        ``initialize()`` must be called before ``step()``.
        """
        self.initialized = False
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = (self.masspole + self.masscart)
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = (self.masspole * self.length)
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates
        # self.kinematics_integrator = 'euler' # use euler by default

        # Angle at which to fail the episode
        self.theta_threshold_radians = 12 * 2 * np.pi / 360
        self.x_threshold = 2.4

        self.action_space = (1,)
        self.observation_space = (4,)

        self.viewer = None
        self.state = None

        self.steps_beyond_done = None

        @jax.jit
        def dynamics(x_0, u):
            """Return the state after one explicit-Euler step of length tau.

            x_0 is the (4,) state vector. u is the force command: a scalar or
            an array whose first element is used.
            """
            x, x_dot, theta, theta_dot = np.split(x_0, 4)
            # ravel so that both scalar and (1,)/(1, 1) commands can be indexed.
            # iLQR may struggle with clipping due to lack of gradient
            force = self.force_mag * np.clip(np.ravel(u), -1.0, 1.0)[0]
            costh = np.cos(theta)
            sinth = np.sin(theta)
            temp = (force + self.polemass_length * theta_dot * theta_dot * sinth) / self.total_mass
            thetaacc = (self.gravity * sinth - costh * temp) / (
                self.length * (4.0 / 3.0 - self.masspole * costh * costh / self.total_mass))
            xacc = temp - self.polemass_length * thetaacc * costh / self.total_mass
            # Explicit Euler: positions are advanced with the *old* velocities.
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
            return np.concatenate((x, x_dot, theta, theta_dot))

        self.dynamics = dynamics

    def initialize(self):
        """Mark the problem as initialized and return the initial state from reset()."""
        self.initialized = True
        return self.reset()

    def step(self, action):
        """Advance the simulation by one step of length ``self.tau``.

        Args:
            action: force command. It can be a scalar or any array holding a
                single value, and is flattened before being passed to
                ``self.dynamics``.

        Returns:
            (state, reward, done, info): the new (4,) state; reward 1.0 while
            the episode runs and on the step where it ends, 0.0 afterwards;
            whether the cart or the pole left its allowed range; an empty dict.
        """
        assert self.initialized
        # dynamics() indexes the first element of the command, so every action
        # is normalised to a flat 1-D array. A bare scalar, or a 0-d array,
        # would otherwise fail on that indexing.
        action = np.ravel(np.asarray(action))
        self.state = self.dynamics(self.state, action)
        x, _, theta, _ = np.split(self.state, 4)

        done = bool(np.abs(x[0]) > self.x_threshold
                    or np.abs(theta[0]) > self.theta_threshold_radians)

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            # Stepping after the episode ended: warn once and give no reward.
            if self.steps_beyond_done == 0:
                print("Warning: step() called after problem is 'done'.")
            self.steps_beyond_done += 1
            reward = 0.0

        return self.state, reward, done, {}

    def reset(self):
        """Reset the episode bookkeeping and return the initial state.

        A state is drawn uniformly from [-0.05, 0.05]^4. It is then immediately
        replaced by the fixed state [0.0, 0.03, 0.03, 0.03], so every episode
        starts from the same point.
        NOTE(review): confirm that the fixed override is intentional. The random
        draw still calls generate_key(), which presumably advances a shared
        RNG key.
        """
        self.state = random.uniform(generate_key(), shape=(4,), minval=-0.05, maxval=0.05)
        self.steps_beyond_done = None
        self.state = np.array([0.0, 0.03, 0.03, 0.03])
        return self.state

    def render(self, mode='human'):
        """Draw the cart and pole with gym's classic_control renderer.

        The viewer and its geometry are built lazily on the first call.
        Returns None if no state exists yet. Otherwise returns the viewer's
        render result, which is an RGB array when ``mode == 'rgb_array'``.
        """
        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        carty = 100  # TOP OF CART
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            self.viewer = rendering.Viewer(screen_width, screen_height)
            l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
            axleoffset = cartheight / 4.0
            cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)
            l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
            pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            pole.set_color(.8, .6, .4)
            self.poletrans = rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)
            self.axle = rendering.make_circle(polewidth / 2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(.5, .5, .8)
            self.viewer.add_geom(self.axle)
            self.track = rendering.Line((0, carty), (screen_width, carty))
            self.track.set_color(0, 0, 0)
            self.viewer.add_geom(self.track)

            self._pole_geom = pole

        if self.state is None:
            return None

        # Edit the pole polygon vertex
        pole = self._pole_geom
        l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
        pole.v = [(l, b), (l, t), (r, t), (r, b)]

        x = self.state
        # Convert to plain floats: the state holds jax arrays, but the
        # renderer's transforms are fed to OpenGL calls.
        cartx = float(x[0]) * scale + screen_width / 2.0  # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-float(x[2]))

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def close(self):
        """Close the render window, if one was opened."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None