强化学习系列文章(三十):训练利器Gym Wrapper
强化学习系列文章(三十):训练利器Gym Wrapper在训练LunarLander环境的智能体算法时,学习到CleanRL的PPO代码,是我目前测试过训练速度最快的PPO版本。我认为主要贡献之一是采用了成熟的gym.wrapper技术,现总结这项技术的学习笔记。wrapper介绍主要分3类wrapper,分别是action,observation,reward。分别继承ActionWrapper
·
强化学习系列文章(三十):训练利器Gym Wrapper
在训练LunarLander环境的智能体算法时,学习到CleanRL的PPO代码,这是我目前测试过训练速度最快的PPO版本。我认为其主要贡献之一是采用了成熟的gym.wrapper技术,现将这项技术的学习笔记总结如下。
https://www.gymlibrary.ml/content/wrappers/
wrapper介绍
主要分3类wrapper,分别针对action、observation、reward。分别继承ActionWrapper、ObservationWrapper、RewardWrapper三个类,即可设计编写定制化的封装对象。
wrapper示例
# Build the base environment and inspect its native action bounds.
import gym
from gym.wrappers import RescaleAction
base_env = gym.make("BipedalWalker-v3")
base_env.action_space
# Box([-1. -1. -1. -1.], [1. 1. 1. 1.], (4,), float32)
# Wrap the environment so the exposed action space becomes [0, 1] per dimension.
wrapped_env = RescaleAction(base_env, min_action=0, max_action=1)
wrapped_env.action_space
# Box([0. 0. 0. 0.], [1. 1. 1. 1.], (4,), float32)
解除wrapper
要想获得封装之前的环境,只需要访问封装后环境的env属性:
wrapped_env.env  # the environment as it was before this wrapper was applied
如果环境被多层封装,想要直接获得最底层的环境对象,则调用:
wrapped_env.unwrapped  # the innermost environment, beneath every wrapper layer
wrapper总览
from gym.wrappers.monitor import Monitor
from gym.wrappers.time_limit import TimeLimit
from gym.wrappers.filter_observation import FilterObservation
from gym.wrappers.atari_preprocessing import AtariPreprocessing
from gym.wrappers.time_aware_observation import TimeAwareObservation
from gym.wrappers.rescale_action import RescaleAction
from gym.wrappers.flatten_observation import FlattenObservation
from gym.wrappers.gray_scale_observation import GrayScaleObservation
from gym.wrappers.frame_stack import LazyFrames
from gym.wrappers.frame_stack import FrameStack
from gym.wrappers.transform_observation import TransformObservation
from gym.wrappers.transform_reward import TransformReward
from gym.wrappers.resize_observation import ResizeObservation
from gym.wrappers.clip_action import ClipAction
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
from gym.wrappers.normalize import NormalizeObservation, NormalizeReward
from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
常用Wrapper介绍
AtariPreprocessing
:
标准的Atari游戏预处理封装器,一劳永逸,非常好用。主要有如下的预处理函数。
* NoopReset: obtain initial state by taking random number of no-ops on reset. 游戏开始之初执行随机次数(不超过noop_max)的no-op动作,以增加初始状态的随机性。
* Frame skipping: 4 by default. 跳帧:每个action连续重复执行frame_skip帧(默认4帧),智能体只观测跳帧结束时的画面。注意这不是帧堆叠——若要堆叠多帧画面作为1个state以提供time sequence信息,需要另外使用FrameStack封装器。
* Max-pooling: most recent two observations. 对每次跳帧的最后两帧画面逐像素取最大值,用于消除Atari游戏中部分物体隔帧渲染造成的画面闪烁。
* Termination signal when a life is lost: turned off by default. Not recommended by Machado et al. (2018).
* Resize to a square image: 84x84 by default. 图像尺寸归一化。
* Grayscale observation: optional. 灰度化。
* Scale observation: optional.
Args:
env (Env): environment
noop_max (int): max number of no-ops
frame_skip (int): the frequency at which the agent experiences the game.
screen_size (int): resize Atari frame
terminal_on_life_loss (bool): if True, then step() returns done=True whenever a life is lost.
grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation is returned.
grayscale_newaxis (bool): if True and grayscale_obs=True, then a channel axis is added to grayscale observations to make them 3-dimensional.
scale_obs (bool): if True, then observation normalized in range [0,1] is returned. It also limits memory optimization benefits of FrameStack Wrapper.
ClipAction
:
class ClipAction(ActionWrapper):
    """Action wrapper that clamps continuous actions to the action-space bounds.

    Args:
        env (Env): environment to wrap; its ``action_space`` must be a ``Box``.
    """

    def __init__(self, env):
        # Clipping relies on the space's ``low``/``high`` arrays, so only
        # Box action spaces are accepted.
        assert isinstance(env.action_space, Box)
        super().__init__(env)

    def action(self, action):
        """Return ``action`` clipped element-wise to the space's low/high."""
        space = self.action_space
        return np.clip(action, space.low, space.high)
NormalizeObservation
:
统计过去历史step的observation,更新均值和方差。
class NormalizeObservation(gym.core.Wrapper):
    """Normalize observations using running statistics of everything seen so far.

    Each observation returned by ``step`` and ``reset`` is first fed into
    ``self.obs_rms`` and then shifted by ``obs_rms.mean`` and divided by
    ``sqrt(obs_rms.var + epsilon)``.

    Args:
        env (Env): environment to wrap. If it exposes ``is_vector_env=True``
            the observations are treated as already batched.
        epsilon (float): small constant added to the variance before the
            square root to avoid division by zero.
    """

    def __init__(
        self,
        env,
        epsilon=1e-8,
    ):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        # Statistics are tracked per single observation, not per batch.
        if self.is_vector_env:
            obs_shape = self.single_observation_space.shape
        else:
            obs_shape = self.observation_space.shape
        # NOTE(review): RunningMeanStd is defined elsewhere; assumed to expose
        # update(batch), .mean and .var as used below.
        self.obs_rms = RunningMeanStd(shape=obs_shape)
        self.epsilon = epsilon

    def step(self, action):
        """Step the wrapped env and return the transition with a normalized observation."""
        obs, rews, dones, infos = self.env.step(action)
        return self._normalize_batched_or_single(obs), rews, dones, infos

    def reset(self, **kwargs):
        """Reset the wrapped env and return the normalized initial observation.

        Keyword arguments are forwarded unchanged to the wrapped env's
        ``reset`` so callers can pass options the inner environment supports.
        """
        obs = self.env.reset(**kwargs)
        return self._normalize_batched_or_single(obs)

    def _normalize_batched_or_single(self, obs):
        """Normalize ``obs``, adding and removing a batch axis for non-vector envs."""
        if self.is_vector_env:
            return self.normalize(obs)
        # A single observation is wrapped into a batch of one so that
        # ``normalize`` always receives batched input.
        return self.normalize(np.array([obs]))[0]

    def normalize(self, obs):
        """Update the running statistics with ``obs`` and return it normalized."""
        self.obs_rms.update(obs)
        return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
NormalizeReward
:
class NormalizeReward(gym.core.Wrapper):
    """Scale rewards by the running standard deviation of a discounted return.

    Args:
        env (Env): environment to wrap.
        gamma (float): discount factor applied to the accumulated return.
        epsilon (float): small constant added to the variance before the
            square root to avoid division by zero.
    """

    def __init__(self, env, gamma=0.99, epsilon=1e-8):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        # NOTE(review): RunningMeanStd is defined elsewhere; assumed to expose
        # update(batch) and .var as used in ``normalize``.
        self.return_rms = RunningMeanStd(shape=())
        # One discounted-return accumulator per sub-environment.
        self.returns = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        """Step the wrapped env and return the transition with a rescaled reward."""
        obs, rews, dones, infos = self.env.step(action)
        single_env = not self.is_vector_env
        # Non-vector envs yield a scalar reward; lift it to a batch of one.
        reward_batch = np.array([rews]) if single_env else rews
        self.returns = self.returns * self.gamma + reward_batch
        scaled = self.normalize(reward_batch)
        # Restart the accumulator for every episode that just finished.
        self.returns[dones] = 0.0
        if single_env:
            scaled = scaled[0]
        return obs, scaled, dones, infos

    def normalize(self, rews):
        """Update the return statistics and divide ``rews`` by their std."""
        self.return_rms.update(self.returns)
        denom = np.sqrt(self.return_rms.var + self.epsilon)
        return rews / denom
TransformReward
:
自定义函数做reward变换,例如:
# Clip every reward into the range [-10, 10].
env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
class TransformReward(RewardWrapper):
    r"""Reward wrapper that passes every reward through a user-supplied function.

    Example::

        >>> import gym
        >>> env = gym.make('CartPole-v1')
        >>> env = TransformReward(env, lambda r: 0.01*r)
        >>> env.reset()
        >>> observation, reward, done, info = env.step(env.action_space.sample())
        >>> reward
        0.01

    Args:
        env (Env): environment to wrap
        f (callable): function applied to each reward; its result is returned
            in place of the original reward
    """

    def __init__(self, env, f):
        super().__init__(env)
        assert callable(f)
        self.f = f

    def reward(self, reward):
        """Return the result of applying the stored function to ``reward``."""
        transform = self.f
        return transform(reward)
TransformObservation
:
自定义函数做observation变换,例如:
# Clip every observation element into the range [-10, 10].
env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
实例
def wrapper(env_id, seed):
    """Create ``env_id`` and wrap it with the preprocessing stack, then seed it.

    Layers are applied innermost first: episode statistics, action clipping,
    observation normalization, observation clipping to [-10, 10], reward
    normalization, reward clipping to [-10, 10].

    Args:
        env_id: environment id passed to ``gym.make``.
        seed: seed applied to the environment and to its action and
            observation spaces.

    Returns:
        The fully wrapped, seeded environment.
    """
    # Order matters: each layer wraps the result of the previous one.
    layers = (
        gym.wrappers.RecordEpisodeStatistics,
        gym.wrappers.ClipAction,
        gym.wrappers.NormalizeObservation,
        lambda inner: gym.wrappers.TransformObservation(
            inner, lambda obs: np.clip(obs, -10, 10)
        ),
        gym.wrappers.NormalizeReward,
        lambda inner: gym.wrappers.TransformReward(
            inner, lambda reward: np.clip(reward, -10, 10)
        ),
    )

    env = gym.make(env_id)
    for layer in layers:
        env = layer(env)

    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    return env
更多推荐
已为社区贡献2条内容
所有评论(0)