"""
Non-PyBullet implementation of CartPole
"""
import jax
import jax.numpy as np
import jax.random as random
import tigercontrol
from tigercontrol.utils import generate_key
from tigercontrol.problems.control import ControlProblem
# necessary for rendering
from gym.envs.classic_control import rendering
class CartPole(ControlProblem):
    """
    Description:
        A pole is attached by an un-actuated joint to a cart, which moves along
        a frictionless track. The pendulum starts upright, and the goal is to
        prevent it from falling over by increasing and reducing the cart's
        velocity.

    State (shape (4,)):
        [x, x_dot, theta, theta_dot] -- cart position, cart velocity, pole
        angle and pole angular velocity.

    Action (shape (1,)):
        A value that is clipped to [-1, 1] and multiplied by ``force_mag`` to
        give the force applied to the cart.

    Reward:
        1.0 for every step up to and including the step on which the episode
        terminates; 0.0 for any further step taken without calling reset().

    Termination:
        The cart position leaves [-x_threshold, x_threshold] or the pole angle
        leaves [-theta_threshold_radians, theta_threshold_radians].
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 50
    }

    def __init__(self):
        """Set the physical parameters and build the jitted ``dynamics`` function."""
        self.initialized = False

        # Physical parameters of the cart-pole system.
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = (self.masspole + self.masscart)
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = (self.masspole * self.length)
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates

        # Angle at which to fail the episode (12 degrees expressed in radians).
        self.theta_threshold_radians = 12 * 2 * np.pi / 360
        # Cart position beyond which the episode fails.
        self.x_threshold = 2.4

        self.action_space = (1,)
        self.observation_space = (4,)

        self.viewer = None
        self.state = None
        # None while the episode is running; counts steps taken after termination.
        self.steps_beyond_done = None

        @jax.jit
        def dynamics(x_0, u):
            """Advance state ``x_0`` by one explicit-Euler step of length ``tau``.

            Args:
                x_0: state array [x, x_dot, theta, theta_dot].
                u: action array; its first element, after clipping to [-1, 1],
                    is scaled by ``force_mag``.

            Returns:
                The next state, in the same [x, x_dot, theta, theta_dot] layout.
            """
            # NOTE: because this function is jitted, the attributes of ``self``
            # read below are captured as constants when it is traced; changing
            # them afterwards does not affect an already-compiled trace.
            x, x_dot, theta, theta_dot = np.split(x_0, 4)
            force = self.force_mag * np.clip(u, -1.0, 1.0)[0]  # iLQR may struggle with clipping due to lack of gradient
            costh = np.cos(theta)
            sinth = np.sin(theta)
            temp = (force + self.polemass_length * theta_dot * theta_dot * sinth) / self.total_mass
            thetaacc = (self.gravity * sinth - costh * temp) / (
                self.length * (4.0 / 3.0 - self.masspole * costh * costh / self.total_mass))
            xacc = temp - self.polemass_length * thetaacc * costh / self.total_mass
            # Explicit Euler integration: positions are updated with the old velocities.
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
            state = np.concatenate((x, x_dot, theta, theta_dot))
            return state

        self.dynamics = dynamics

    def initialize(self):
        """Mark the problem as initialized and return the state produced by reset()."""
        self.initialized = True
        return self.reset()

    def step(self, action):
        """Apply ``action`` for one time step.

        Args:
            action: scalar or array-like. It is converted to an array of at
                least one dimension; ``dynamics`` uses only its first element.

        Returns:
            (state, reward, done, info): the new state, the step reward, whether
            the cart or pole left its allowed range, and an empty dict.

        Raises:
            AssertionError: if initialize() has not been called.
        """
        assert self.initialized
        # dynamics() indexes the clipped action with [0]; make sure scalars are
        # promoted to 1-D so they do not raise an IndexError.
        action = np.atleast_1d(np.asarray(action))
        self.state = self.dynamics(self.state, action)
        x, _, theta, _ = np.split(self.state, 4)
        done = bool(x < -self.x_threshold
                    or x > self.x_threshold
                    or theta < -self.theta_threshold_radians
                    or theta > self.theta_threshold_radians)

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell! The terminating step is still rewarded.
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            # Stepping past termination: warn once, count the extra steps and
            # give no reward (previously ``reward`` was left unbound here).
            if self.steps_beyond_done == 0:
                print("Warning: step() called after problem is 'done'.")
            self.steps_beyond_done += 1
            reward = 0.0
        return self.state, reward, done, {}

    def reset(self):
        """Reset the episode bookkeeping and return the initial state."""
        self.state = random.uniform(generate_key(), shape=(4,), minval=-0.05, maxval=0.05)
        self.steps_beyond_done = None
        # NOTE(review): the random draw above is immediately overwritten by a
        # fixed initial state, so reset() is deterministic. This looks like a
        # debugging leftover -- confirm with callers before changing it (the
        # draw also calls generate_key(), which may have side effects).
        self.state = np.array([0.0, 0.03, 0.03, 0.03])
        return self.state

    def render(self, mode='human'):
        """Draw the current state using gym's classic-control viewer.

        Args:
            mode: 'rgb_array' asks the viewer to return the frame as an RGB
                array; any other value (default 'human') renders without
                requesting the array.

        Returns:
            None if no state has been set yet, otherwise the value returned by
            ``rendering.Viewer.render``.
        """
        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        carty = 100  # TOP OF CART
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            # Build the scene graph once; later calls only update transforms.
            self.viewer = rendering.Viewer(screen_width, screen_height)
            l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
            axleoffset = cartheight / 4.0
            cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)
            l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
            pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            pole.set_color(.8, .6, .4)
            self.poletrans = rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)
            self.axle = rendering.make_circle(polewidth / 2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(.5, .5, .8)
            self.viewer.add_geom(self.axle)
            self.track = rendering.Line((0, carty), (screen_width, carty))
            self.track.set_color(0, 0, 0)
            self.viewer.add_geom(self.track)

            self._pole_geom = pole

        if self.state is None:
            return None

        # Edit the pole polygon vertex
        pole = self._pole_geom
        l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
        pole.v = [(l, b), (l, t), (r, t), (r, b)]

        x = self.state
        # Convert jax scalars to Python floats before handing them to the
        # OpenGL-backed transforms.
        cartx = float(x[0]) * scale + screen_width / 2.0  # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-float(x[2]))

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def close(self):
        """Close the viewer, if render() opened one."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None