dm-env

1.6 · active · verified Thu Apr 16

dm-env is a Python interface for Reinforcement Learning (RL) environments, providing a foundational API for interacting with environments. It is actively maintained by DeepMind and is currently at version 1.6. The library focuses on core components such as `Environment`, `TimeStep`, and `specs` for defining actions, observations, rewards, and discounts.

Common errors

Warnings

Install

Imports

Quickstart

This quickstart defines a simple counting environment using `dm-env`'s `Environment` abstract base class. It showcases how to define action, observation, reward, and discount specifications using `dm_env.specs`, implement `_reset` and `_step` methods, and interact with the environment through `reset()` and `step()` calls. The example also highlights the `TimeStep` object and its `step_type` attribute for managing episode progression.

import numpy as np
from dm_env import Environment, TimeStep, specs, StepType

class SimpleCountingEnv(Environment):
    """A minimal dm-env environment that counts up to a fixed maximum.

    Action 1 increments an internal counter, action 0 is a no-op. The
    episode ends (reward 1.0, discount 0.0) once the counter reaches
    ``max_count``; every intermediate step yields reward 0.0 and
    discount 1.0. After a LAST timestep, the next ``step()`` call
    auto-resets and returns a FIRST timestep, ignoring its action.
    """

    def __init__(self, max_count=5):
        self._max_count = max_count
        self._current_count = 0
        # Force the first step() call to behave like a reset().
        self._reset_next_step = True

    def discount_spec(self):
        """Scalar float discount constrained to [0.0, 1.0]."""
        return specs.BoundedArray(
            shape=(), dtype=float, minimum=0.0, maximum=1.0, name='discount')

    def observation_spec(self):
        """Scalar int observation: the current count in [0, max_count]."""
        return specs.BoundedArray(
            shape=(), dtype=int, minimum=0, maximum=self._max_count, name='count')

    def action_spec(self):
        """Scalar int action: 0 leaves the count unchanged, 1 increments it."""
        return specs.BoundedArray(
            shape=(), dtype=int, minimum=0, maximum=1, name='action')

    def reward_spec(self):
        """Scalar float reward (unbounded Array spec)."""
        return specs.Array(shape=(), dtype=float, name='reward')

    def _reset(self):
        """Start a new episode at count 0 and return the FIRST timestep."""
        self._current_count = 0
        self._reset_next_step = False
        # By dm_env convention the FIRST timestep carries no reward/discount.
        return TimeStep(
            step_type=StepType.FIRST,
            reward=None,
            discount=None,
            observation=np.asarray(self._current_count, dtype=int))

    def _step(self, action):
        """Advance one step; auto-resets if the previous step ended the episode."""
        if self._reset_next_step:
            return self._reset()

        if action == 1:
            self._current_count += 1

        episode_over = self._current_count >= self._max_count
        obs = np.asarray(self._current_count, dtype=int)
        if episode_over:
            self._reset_next_step = True
            return TimeStep(
                step_type=StepType.LAST,
                reward=np.asarray(1.0, dtype=float),
                discount=np.asarray(0.0, dtype=float),
                observation=obs)
        return TimeStep(
            step_type=StepType.MID,
            reward=np.asarray(0.0, dtype=float),
            discount=np.asarray(1.0, dtype=float),
            observation=obs)

    def reset(self):
        """Public reset entry point; delegates to _reset()."""
        return self._reset()

    def step(self, action):
        """Public step entry point; delegates to _step()."""
        return self._step(action)

# --- Example Usage ---
env = SimpleCountingEnv()

# reset() returns a FIRST timestep whose reward/discount are None.
timestep = env.reset()
print(f"Initial: {timestep.observation}")

# Track the step index locally instead of reaching into the environment's
# private `_current_count` attribute — example code should only use the
# public Environment interface (reset/step and the returned TimeStep).
step_num = 0
while not timestep.last():
    action = 1  # Always try to increment
    timestep = env.step(action)
    step_num += 1
    print(f"Step {step_num}: Obs={timestep.observation}, Reward={timestep.reward}, Type={timestep.step_type.name}")

# Demonstrating reset after LAST timestep: step() auto-resets and returns
# a fresh FIRST timestep.
timestep = env.step(0)  # Action is ignored here
print(f"After last, calling step (action ignored): {timestep.observation}")

view raw JSON →