 from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
 
 
+class StickyActionEnv(gym.Wrapper):
+    """
+    Sticky action.
+
+    Paper: https://arxiv.org/abs/1709.06009
+    Official implementation: https://github.com/mgbellemare/Arcade-Learning-Environment
+
+    :param env: Environment to wrap
+    :param action_repeat_probability: Probability of repeating the last action
+    """
+
+    def __init__(self, env: gym.Env, action_repeat_probability: float) -> None:
+        super().__init__(env)
+        self.action_repeat_probability = action_repeat_probability
+        assert env.unwrapped.get_action_meanings()[0] == "NOOP"
+
+    def reset(self, **kwargs) -> GymObs:
+        self._sticky_action = 0  # NOOP
+        return self.env.reset(**kwargs)
+
+    def step(self, action: int) -> GymStepReturn:
+        if self.np_random.random() >= self.action_repeat_probability:
+            self._sticky_action = action
+        return self.env.step(self._sticky_action)
+
+
 class NoopResetEnv(gym.Wrapper):
     """
     Sample initial states by taking a random number of no-ops on reset.
     No-op is assumed to be action 0.
 
-    :param env: the environment to wrap
-    :param noop_max: the maximum value of no-ops to run
+    :param env: Environment to wrap
+    :param noop_max: Maximum value of no-ops to run
     """
 
     def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
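
For intuition, here is a minimal usage sketch of the new StickyActionEnv (not part of the diff itself). It assumes gym and the Atari ROMs are installed, and the env id is only an illustrative choice:

import gym

env = gym.make("BreakoutNoFrameskip-v4")  # any ALE env whose action 0 is NOOP
env = StickyActionEnv(env, action_repeat_probability=0.25)

obs = env.reset()
for _ in range(10):
    # With probability 0.25, the previously executed action is repeated
    # instead of the action requested here.
    obs, reward, done, info = env.step(env.action_space.sample())
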
@@ -47,7 +73,7 @@ class FireResetEnv(gym.Wrapper):
     """
     Take action on reset for environments that are fixed until firing.
 
-    :param env: the environment to wrap
+    :param env: Environment to wrap
     """
 
     def __init__(self, env: gym.Env) -> None:
@@ -71,7 +97,7 @@ class EpisodicLifeEnv(gym.Wrapper):
     Make end-of-life == end-of-episode, but only reset on true game over.
     Done by DeepMind for the DQN and co. since it helps value estimation.
 
-    :param env: the environment to wrap
+    :param env: Environment to wrap
     """
 
     def __init__(self, env: gym.Env) -> None:
@@ -120,9 +146,11 @@ def reset(self, **kwargs) -> np.ndarray:
 class MaxAndSkipEnv(gym.Wrapper):
     """
     Return only every ``skip``-th frame (frameskipping)
+    and return the max between the two last frames.
 
-    :param env: the environment
-    :param skip: number of ``skip``-th frame
+    :param env: Environment to wrap
+    :param skip: Number of ``skip``-th frame
+        The same action will be taken ``skip`` times.
     """
 
     def __init__(self, env: gym.Env, skip: int = 4) -> None:
@@ -159,9 +187,9 @@ def step(self, action: int) -> GymStepReturn:
 
 class ClipRewardEnv(gym.RewardWrapper):
     """
-    Clips the reward to {+1, 0, -1} by its sign.
+    Clip the reward to {+1, 0, -1} by its sign.
 
-    :param env: the environment
+    :param env: Environment to wrap
     """
 
     def __init__(self, env: gym.Env) -> None:
@@ -182,9 +210,9 @@ class WarpFrame(gym.ObservationWrapper):
     Convert to grayscale and warp frames to 84x84 (default)
     as done in the Nature paper and later work.
 
-    :param env: the environment
-    :param width:
-    :param height:
+    :param env: Environment to wrap
+    :param width: New frame width
+    :param height: New frame height
     """
 
     def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None:
@@ -213,20 +241,29 @@ class AtariWrapper(gym.Wrapper):
 
     Specifically:
 
-    * NoopReset: obtain initial state by taking random number of no-ops on reset.
+    * Noop reset: obtain initial state by taking a random number of no-ops on reset.
     * Frame skipping: 4 by default
     * Max-pooling: most recent two observations
     * Termination signal when a life is lost.
     * Resize to a square image: 84x84 by default
     * Grayscale observation
     * Clip reward to {-1, 0, 1}
+    * Sticky actions: disabled by default
+
+    See https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/
+    for a visual explanation.
+
+    .. warning::
+        Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``.
 
-    :param env: gym environment
-    :param noop_max: max number of no-ops
-    :param frame_skip: the frequency at which the agent experiences the game.
-    :param screen_size: resize Atari frame
-    :param terminal_on_life_loss: if True, then step() returns done=True whenever a life is lost.
+    :param env: Environment to wrap
+    :param noop_max: Max number of no-ops
+    :param frame_skip: Frequency at which the agent experiences the game.
+        This corresponds to repeating the action ``frame_skip`` times.
+    :param screen_size: Resize Atari frame
+    :param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost.
     :param clip_reward: If True (default), the reward is clipped to {-1, 0, 1} depending on its sign.
+    :param action_repeat_probability: Probability of repeating the last action
     """
 
     def __init__(
@@ -237,9 +274,15 @@ def __init__(
         screen_size: int = 84,
         terminal_on_life_loss: bool = True,
         clip_reward: bool = True,
+        action_repeat_probability: float = 0.0,
     ) -> None:
-        env = NoopResetEnv(env, noop_max=noop_max)
-        env = MaxAndSkipEnv(env, skip=frame_skip)
+        if action_repeat_probability > 0.0:
+            env = StickyActionEnv(env, action_repeat_probability)
+        if noop_max > 0:
+            env = NoopResetEnv(env, noop_max=noop_max)
+        # frame_skip=1 is the same as no frame-skip (action repeat)
+        if frame_skip > 1:
+            env = MaxAndSkipEnv(env, skip=frame_skip)
         if terminal_on_life_loss:
             env = EpisodicLifeEnv(env)
         if "FIRE" in env.unwrapped.get_action_meanings():
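
Putting it together, a hedged usage sketch of the updated AtariWrapper (assuming gym and the ALE ROMs are available; the env id and hyperparameter values are illustrative only):

import gym
from stable_baselines3.common.atari_wrappers import AtariWrapper

# Use a NoFrameskip-v4 env, as the warning in the docstring requires.
env = gym.make("BreakoutNoFrameskip-v4")
# Standard DeepMind-style preprocessing, now with optional sticky actions.
env = AtariWrapper(env, action_repeat_probability=0.25)

obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
print(obs.shape)  # (84, 84, 1): warped, grayscale frame

With action_repeat_probability left at its default of 0.0 the wrapper behaves as before; per the new checks above, frame_skip=1 now skips MaxAndSkipEnv entirely and noop_max=0 skips NoopResetEnv.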