Reinforcement Learning Series (30): Gym Wrappers, a Powerful Tool for Training

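While training agents on the LunarLander environment, I studied CleanRL's PPO code. In my tests it trains faster than any other PPO implementation I have tried so far. I believe one of the main reasons is its use of gym's mature wrapper mechanism, so these are my study notes on that technique. Official documentation: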
https://www.gymlibrary.ml/content/wrappers/

Introduction to wrappers

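There are three main kinds of wrappers: action, observation, and reward wrappers. They subclass ActionWrapper, ObservationWrapper, and RewardWrapper respectively, and you can subclass these base classes to write your own customized wrappers. Wrappers that need to modify several things at once (such as NormalizeObservation below) subclass the general gym.Wrapper instead.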
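To see how the three base classes differ, here is a minimal gym-free sketch. All classes below (ToyEnv, Wrapper, ObservationWrapper, ActionWrapper, RewardWrapper, and the three custom wrappers) are hypothetical stand-ins that only mirror the structure of gym's classes, using the old 4-tuple step API from this article. Each base class overrides step (and, for observations, reset) and routes one element through a single hook method, so a custom wrapper only has to implement observation(), action(), or reward().

```python
# Simplified re-implementation of the wrapper pattern (NOT gym's source).
# Uses the old 4-tuple step API: (obs, reward, done, info).

class ToyEnv:
    """Hypothetical 1-D env: the state moves by the action; reward = -|state|."""
    def __init__(self):
        self.state = 0.0

    def reset(self):
        self.state = 0.0
        return self.state

    def step(self, action):
        self.state += action
        return self.state, -abs(self.state), False, {}


class Wrapper:
    """Forwards reset/step to the wrapped env unchanged."""
    def __init__(self, env):
        self.env = env

    def reset(self):
        return self.env.reset()

    def step(self, action):
        return self.env.step(action)


class ObservationWrapper(Wrapper):
    """Passes every observation (from reset and step) through observation()."""
    def reset(self):
        return self.observation(self.env.reset())

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self.observation(obs), reward, done, info


class ActionWrapper(Wrapper):
    """Passes the agent's action through action() before the inner env sees it."""
    def step(self, action):
        return self.env.step(self.action(action))


class RewardWrapper(Wrapper):
    """Passes every reward through reward()."""
    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return obs, self.reward(reward), done, info


# Custom wrappers only implement the one hook they care about.
class DoubleObs(ObservationWrapper):
    def observation(self, obs):
        return 2 * obs


class NegateAction(ActionWrapper):
    def action(self, action):
        return -action


class ScaleReward(RewardWrapper):
    def reward(self, reward):
        return 0.1 * reward


env = ScaleReward(NegateAction(DoubleObs(ToyEnv())))
env.reset()
print(env.step(1.0))  # (-2.0, -0.1, False, {})
```

Reading the call chain from the outside in: ScaleReward forwards the action, NegateAction turns 1.0 into -1.0, the base env moves to state -1.0, DoubleObs reports -2.0, and ScaleReward finally scales the reward -1.0 to -0.1.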

Wrapper example

import gym
from gym.wrappers import RescaleAction
base_env = gym.make("BipedalWalker-v3")
base_env.action_space
# Box([-1. -1. -1. -1.], [1. 1. 1. 1.], (4,), float32)
wrapped_env = RescaleAction(base_env, min_action=0, max_action=1)
wrapped_env.action_space
# Box([0. 0. 0. 0.], [1. 1. 1. 1.], (4,), float32)
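RescaleAction only changes what the agent sees. The agent now outputs actions in [min_action, max_action], and the wrapper maps each action linearly back to the original bounds of the base environment before passing it on. The numpy sketch below shows that mapping. The function name rescale is ours, not part of gym, and the final clip is only a guard against floating-point overshoot.

```python
import numpy as np

def rescale(action, min_action, max_action, low, high):
    """Linearly map an action from [min_action, max_action] to the env's [low, high]."""
    action = low + (high - low) * (action - min_action) / (max_action - min_action)
    return np.clip(action, low, high)  # guard against tiny numerical overshoot

# Original BipedalWalker-v3 bounds from the example above: [-1, 1] in all 4 dimensions.
low, high = -np.ones(4), np.ones(4)
agent_action = np.array([0.0, 0.25, 0.5, 1.0])  # the agent acts in [0, 1]
env_action = rescale(agent_action, 0.0, 1.0, low, high)
# env_action holds -1.0, -0.5, 0.0, 1.0: the endpoints map to the original bounds.
```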

Unwrapping

To get the environment as it was before the outermost wrapper was applied, access the env attribute of the wrapped environment:

wrapped_env.env

If the environment is wrapped in several layers and you want the innermost base environment directly, use:

wrapped_env.unwrapped
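The difference between the two attributes comes from how they are defined: env points exactly one layer down, while unwrapped recurses until it reaches the base environment. Below is a minimal sketch with hypothetical stand-in classes BaseEnv and Wrapper. gym's Env and Wrapper define unwrapped in the same recursive way, but everything else is omitted here.

```python
class BaseEnv:
    """Hypothetical innermost environment."""
    @property
    def unwrapped(self):
        return self                  # the base env is its own unwrapped env


class Wrapper:
    """Hypothetical minimal wrapper: it only stores the env it wraps."""
    def __init__(self, env):
        self.env = env               # exactly one layer down

    @property
    def unwrapped(self):
        return self.env.unwrapped    # recurse until the base env is reached


base = BaseEnv()
wrapped = Wrapper(Wrapper(Wrapper(base)))  # three layers of wrapping
print(wrapped.env is base)           # False: .env removes only the outermost layer
print(wrapped.env.env.env is base)   # True
print(wrapped.unwrapped is base)     # True
```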

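Overview of built-in wrappers (import paths from the gym version used in this article; they may differ in later versions)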

from gym.wrappers.monitor import Monitor
from gym.wrappers.time_limit import TimeLimit
from gym.wrappers.filter_observation import FilterObservation
from gym.wrappers.atari_preprocessing import AtariPreprocessing
from gym.wrappers.time_aware_observation import TimeAwareObservation
from gym.wrappers.rescale_action import RescaleAction
from gym.wrappers.flatten_observation import FlattenObservation
from gym.wrappers.gray_scale_observation import GrayScaleObservation
from gym.wrappers.frame_stack import LazyFrames
from gym.wrappers.frame_stack import FrameStack
from gym.wrappers.transform_observation import TransformObservation
from gym.wrappers.transform_reward import TransformReward
from gym.wrappers.resize_observation import ResizeObservation
from gym.wrappers.clip_action import ClipAction
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
from gym.wrappers.normalize import NormalizeObservation, NormalizeReward
from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule

Commonly used wrappers

AtariPreprocessing:

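This is the standard Atari preprocessing wrapper. A single wrapper applies all the usual preprocessing steps, which makes it very convenient. It performs the following steps:

* NoopReset: obtain initial state by taking random number of no-ops on reset. On reset, a random number (at most noop_max) of no-op actions is executed, so episodes start from more varied states.
* Frame skipping: 4 by default. The agent's action is repeated for frame_skip consecutive frames and the rewards are summed, so the agent acts once every 4 frames. This is not frame stacking: combining several frames into one state to provide temporal information is done separately with FrameStack.
* Max-pooling: most recent two observations. The returned frame is the pixel-wise maximum of the last two raw frames. Some Atari games draw certain sprites only on every other frame, so a single frame can miss objects; taking the maximum over two frames removes this flicker.
* Termination signal when a life is lost: turned off by default. Not recommended by Machado et al. (2018).
* Resize to a square image: 84x84 by default (screen_size).
* Grayscale observation: optional. Converts RGB frames to grayscale.
* Scale observation: optional. Divides pixel values by 255 so observations lie in [0, 1].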
Args:
    env (Env): environment
    noop_max (int): max number of no-ops
    frame_skip (int): the frequency at which the agent experiences the game.
    screen_size (int): resize Atari frame
    terminal_on_life_loss (bool): if True, then step() returns done=True whenever a life is lost.
    grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation is returned.
    grayscale_newaxis (bool): if True and grayscale_obs=True, then a channel axis is added to grayscale observations to make them 3-dimensional.
    scale_obs (bool): if True, then observation normalized in range [0,1] is returned. It also limits memory optimization benefits of FrameStack Wrapper.
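The frame-skipping and max-pooling steps can be sketched as follows. This is a simplified illustration, not gym's implementation: the emulator interface step_fn(action) -> (frame, reward, done) and the flickering toy game are hypothetical. The function repeats one action for frame_skip frames, sums the rewards, and returns the pixel-wise maximum of the last two frames.

```python
import numpy as np

def skip_and_max_pool(step_fn, action, frame_skip=4):
    """Simplified sketch of frame skipping + max-pooling.

    step_fn(action) -> (frame, reward, done) is a hypothetical single-frame emulator step.
    """
    total_reward = 0.0
    last_two = []                       # the two most recent raw frames
    done = False
    for _ in range(frame_skip):         # repeat the same action frame_skip times
        frame, reward, done = step_fn(action)
        total_reward += reward          # rewards of all skipped frames are summed
        last_two = (last_two + [frame])[-2:]
        if done:
            break
    # Pixel-wise max over the last two frames (a single frame if the episode ended early).
    obs = np.maximum(last_two[0], last_two[-1])
    return obs, total_reward, done


# Toy emulator: a sprite at pixel (0, 0) is drawn only on odd frames (flicker).
frame_count = 0

def flicker_step(action):
    global frame_count
    frame_count += 1
    frame = np.zeros((2, 2), dtype=np.uint8)
    if frame_count % 2 == 1:
        frame[0, 0] = 255
    return frame, 1.0, False


obs, reward, done = skip_and_max_pool(flicker_step, action=0)
# Frame 4 alone would hide the sprite; max-pooling with frame 3 keeps it visible.
print(obs[0, 0], reward)  # 255 4.0
```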
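ClipAction:

Clips continuous actions to the bounds of the Box action space before they reach the environment. The assertion rules out discrete action spaces.

import numpy as np
from gym import ActionWrapper
from gym.spaces import Box

class ClipAction(ActionWrapper):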
    r"""Clip the continuous action within the valid bound."""

    def __init__(self, env):
        assert isinstance(env.action_space, Box)
        super(ClipAction, self).__init__(env)

    def action(self, action):
        return np.clip(action, self.action_space.low, self.action_space.high)
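ClipAction matters mostly for Gaussian policies such as the one in continuous-action PPO, whose samples are unbounded. The core operation is a single np.clip call whose bounds broadcast per action dimension. In the numpy sketch below, the 2-D bounds are hypothetical stand-ins for env.action_space.low and env.action_space.high.

```python
import numpy as np

# Hypothetical 2-D Box bounds; ClipAction reads them from env.action_space.low / .high.
low = np.array([-1.0, -2.0])
high = np.array([1.0, 2.0])

rng = np.random.default_rng(0)
raw = rng.normal(loc=0.0, scale=3.0, size=(5, 2))  # unbounded Gaussian-policy samples
clipped = np.clip(raw, low, high)                  # bounds broadcast per dimension
```

Only the action passed to the environment is clipped; the unclipped sample remains available to the agent, for example for computing log-probabilities.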
NormalizeObservation:

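Maintains a running mean and variance of all observations seen so far (updated on every reset and step) and returns each observation standardized with them, i.e. (obs - mean) / sqrt(var + epsilon).

import gym
import numpy as np
from gym.wrappers.normalize import RunningMeanStd

class NormalizeObservation(gym.core.Wrapper):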
    def __init__(
        self,
        env,
        epsilon=1e-8,
    ):
        super(NormalizeObservation, self).__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        if self.is_vector_env:
            self.obs_rms = RunningMeanStd(shape=self.single_observation_space.shape)
        else:
            self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        if self.is_vector_env:
            obs = self.normalize(obs)
        else:
            obs = self.normalize(np.array([obs]))[0]
        return obs, rews, dones, infos

    def reset(self):
        obs = self.env.reset()
        if self.is_vector_env:
            obs = self.normalize(obs)
        else:
            obs = self.normalize(np.array([obs]))[0]
        return obs

    def normalize(self, obs):
        self.obs_rms.update(obs)
        return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
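The RunningMeanStd helper used above never stores past observations. It merges the mean and variance of each new batch into its running statistics with the parallel-variance formula. Below is a numpy sketch written in the style of that helper: a tiny initial pseudo-count of 1e-4 with mean 0 and variance 1, with axis 0 as the batch axis. The function normalize and the synthetic data are illustrative only.

```python
import numpy as np

class RunningMeanStd:
    """Running mean/variance that merges whole batches (parallel variance formula)."""

    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon                 # tiny pseudo-count avoids division by zero

    def update(self, x):
        batch_mean = np.mean(x, axis=0)      # statistics of the new batch (axis 0 = batch)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        # Merge old and new statistics exactly, without keeping past data.
        self.mean = self.mean + delta * batch_count / tot_count
        m2 = (self.var * self.count + batch_var * batch_count
              + np.square(delta) * self.count * batch_count / tot_count)
        self.var = m2 / tot_count
        self.count = tot_count


def normalize(rms, obs, epsilon=1e-8):
    rms.update(obs)                          # update first, as NormalizeObservation does
    return (obs - rms.mean) / np.sqrt(rms.var + epsilon)


rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=2.0, size=(1000, 3))  # synthetic 3-D observations
rms = RunningMeanStd(shape=(3,))
for batch in np.split(data, 10):             # 10 batches of 100 observations
    normalized = normalize(rms, batch)       # early batches only see partial statistics
```

After all batches, rms.mean and rms.var match the statistics of the full dataset up to the negligible effect of the pseudo-count. Early in training, however, observations are normalized with statistics that are still changing.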
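NormalizeReward:

Unlike NormalizeObservation, this wrapper does not standardize the reward itself. It keeps a running discounted return per environment, tracks the variance of that return, and divides each reward by the return's running standard deviation without subtracting a mean. The return is reset to zero when an episode ends.

import gym
import numpy as np
from gym.wrappers.normalize import RunningMeanStd

class NormalizeReward(gym.core.Wrapper):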
    def __init__(
        self,
        env,
        gamma=0.99,
        epsilon=1e-8,
    ):
        super(NormalizeReward, self).__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        self.return_rms = RunningMeanStd(shape=())
        self.returns = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        if not self.is_vector_env:
            rews = np.array([rews])
        self.returns = self.returns * self.gamma + rews
        rews = self.normalize(rews)
        self.returns[dones] = 0.0
        if not self.is_vector_env:
            rews = rews[0]
        return obs, rews, dones, infos

    def normalize(self, rews):
        self.return_rms.update(self.returns)
        return rews / np.sqrt(self.return_rms.var + self.epsilon)
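Below is a scalar sketch of what NormalizeReward computes for a single environment. The class name ReturnScaler is hypothetical. Its return statistics start from the same pseudo-count (1e-4), mean (0), and variance (1) as RunningMeanStd, merging one sample at a time. The final check shows the point of the wrapper: multiplying every reward by a constant leaves the scaled rewards essentially unchanged, so the agent sees rewards on a similar scale regardless of the environment's raw reward magnitude.

```python
import numpy as np

class ReturnScaler:
    """Hypothetical single-env version of NormalizeReward's logic."""

    def __init__(self, gamma=0.99, epsilon=1e-8):
        self.gamma = gamma
        self.epsilon = epsilon
        self.ret = 0.0                                    # running discounted return
        self.mean, self.var, self.count = 0.0, 1.0, 1e-4  # running stats of the return

    def _update(self, x):
        # Merge one sample into the running mean/variance (a batch of size 1).
        delta = x - self.mean
        tot = self.count + 1
        self.mean += delta / tot
        self.var = (self.var * self.count + delta ** 2 * self.count / tot) / tot
        self.count = tot

    def __call__(self, reward, done):
        self.ret = self.ret * self.gamma + reward
        self._update(self.ret)
        scaled = reward / np.sqrt(self.var + self.epsilon)  # divide only, no mean shift
        if done:
            self.ret = 0.0                                  # new episode, new return
        return scaled


rng = np.random.default_rng(0)
rewards = rng.normal(size=500)
small, large = ReturnScaler(), ReturnScaler()
for r in rewards:
    s_small = small(r, done=False)
    s_large = large(100.0 * r, done=False)  # same task, rewards 100x larger
print(np.isclose(s_small, s_large, rtol=1e-3))  # True: the reward scale is factored out
```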
TransformReward:

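Applies an arbitrary user-defined function to the reward, for example clipping it to [-10, 10]:

env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))

from gym import RewardWrapper

class TransformReward(RewardWrapper):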
    r"""Transform the reward via an arbitrary function.
    Example::
        >>> import gym
        >>> env = gym.make('CartPole-v1')
        >>> env = TransformReward(env, lambda r: 0.01*r)
        >>> env.reset()
        >>> observation, reward, done, info = env.step(env.action_space.sample())
        >>> reward
        0.01
    Args:
        env (Env): environment
        f (callable): a function that transforms the reward
    """

    def __init__(self, env, f):
        super(TransformReward, self).__init__(env)
        assert callable(f)
        self.f = f
    def reward(self, reward):
        return self.f(reward)
TransformObservation:

Applies an arbitrary user-defined function to the observation, for example clipping it to [-10, 10]:

env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))

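Example

The function below stacks the wrappers above in the style of CleanRL's PPO for continuous actions. RecordEpisodeStatistics is applied first, directly on the base environment, so the episode returns it records in info["episode"] are the raw, unnormalized rewards. ClipAction asserts a Box action space, so this stack is meant for continuous-control tasks such as BipedalWalker-v3.

import gym
import numpy as np

def wrapper(env_id, seed):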
    env = gym.make(env_id)
    env = gym.wrappers.RecordEpisodeStatistics(env)
    env = gym.wrappers.ClipAction(env)
    env = gym.wrappers.NormalizeObservation(env)
    env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
    env = gym.wrappers.NormalizeReward(env)
    env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    return env