We have now released v0.3.0! Please use the latest version for the best experience.

Source code for omni.isaac.orbit.managers.termination_manager

# Copyright (c) 2022-2024, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Termination manager for computing done signals for a given world."""

from __future__ import annotations

import torch
from collections.abc import Sequence
from prettytable import PrettyTable
from typing import TYPE_CHECKING

from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import TerminationTermCfg

if TYPE_CHECKING:
    from omni.isaac.orbit.envs import RLTaskEnv


class TerminationManager(ManagerBase):
    """Manager for computing done signals for a given world.

    The termination manager computes the termination signal (also called dones) as a combination
    of termination terms. Each termination term is a function which takes the environment as an
    argument and returns a boolean tensor of shape (num_envs,). The termination manager
    computes the termination signal as the union (logical or) of all the termination terms.

    Following the `Gymnasium API <https://gymnasium.farama.org/tutorials/gymnasium_basics/handling_time_limits/>`_,
    the termination signal is computed as the logical OR of the following signals:

    * **Time-out**: This signal is set to true if the environment has ended after an externally defined condition
      (that is outside the scope of a MDP). For example, the environment may be terminated if the episode has
      timed out (i.e. reached max episode length).
    * **Terminated**: This signal is set to true if the environment has reached a terminal state defined by the
      environment. This state may correspond to task success, task failure, robot falling, etc.

    These signals can be individually accessed using the :attr:`time_outs` and :attr:`terminated` properties.

    The termination terms are parsed from a config class containing the manager's settings and each term's
    parameters. Each termination term should instantiate the :class:`TerminationTermCfg` class. The term's
    configuration :attr:`TerminationTermCfg.time_out` decides whether the term is a timeout or a termination term.
    """

    _env: RLTaskEnv
    """The environment instance."""

    def __init__(self, cfg: object, env: RLTaskEnv):
        """Initializes the termination manager.

        Args:
            cfg: The configuration object or dictionary (``dict[str, TerminationTermCfg]``).
            env: An environment object.
        """
        # NOTE(review): ``_term_names`` is read right after this call, so the base-class constructor is
        # presumably what invokes :meth:`_prepare_terms` -- confirm against ``ManagerBase``.
        super().__init__(cfg, env)
        # prepare extra info to store the latest output of each individual termination term
        self._term_dones = {
            term_name: torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
            for term_name in self._term_names
        }
        # create buffers for managing termination per environment
        self._truncated_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
        self._terminated_buf = torch.zeros_like(self._truncated_buf)

    def __str__(self) -> str:
        """Returns: A string representation for termination manager."""
        msg = f"<TerminationManager> contains {len(self._term_names)} active terms.\n"

        # create table for term information
        table = PrettyTable()
        table.title = "Active Termination Terms"
        table.field_names = ["Index", "Name", "Time Out"]
        # set alignment of table columns
        table.align["Name"] = "l"
        # add info on each term
        for index, (name, term_cfg) in enumerate(zip(self._term_names, self._term_cfgs)):
            table.add_row([index, name, term_cfg.time_out])
        # convert table to string
        msg += table.get_string()
        msg += "\n"

        return msg

    """
    Properties.
    """

    @property
    def active_terms(self) -> list[str]:
        """Name of active termination terms."""
        return self._term_names

    @property
    def dones(self) -> torch.Tensor:
        """The net termination signal. Shape is (num_envs,).

        This is the logical OR of :attr:`time_outs` and :attr:`terminated`, evaluated on access.
        """
        return self._truncated_buf | self._terminated_buf

    @property
    def time_outs(self) -> torch.Tensor:
        """The timeout signal (reaching max episode length). Shape is (num_envs,).

        This signal is set to true if the environment has ended after an externally defined condition
        (that is outside the scope of a MDP). For example, the environment may be terminated if the episode has
        timed out (i.e. reached max episode length).
        """
        return self._truncated_buf

    @property
    def terminated(self) -> torch.Tensor:
        """The terminated signal (reaching a terminal state). Shape is (num_envs,).

        This signal is set to true if the environment has reached a terminal state defined by the environment.
        This state may correspond to task success, task failure, robot falling, etc.
        """
        return self._terminated_buf

    """
    Operations.
    """

    def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]:
        """Returns the per-term termination counts and resets the class-based termination terms.

        For every termination term, the number of selected environments whose most recently computed
        value (see :meth:`compute`) is true is reported. The stored per-term values and the
        time-out / terminated buffers are not cleared here; they are overwritten by the next call
        to :meth:`compute`.

        Args:
            env_ids: The environment ids. Defaults to None, in which case
                all environments are considered.

        Returns:
            Dictionary mapping ``"Episode Termination/<term name>"`` to the count for that termination term.

            NOTE: although the return annotation mentions tensors, each value is a plain Python
            number obtained through ``.item()``.
        """
        # resolve environment ids
        if env_ids is None:
            env_ids = slice(None)
        # add to episode dict
        extras = {}
        for term_name, term_done in self._term_dones.items():
            # store how many of the selected environments triggered this term
            extras["Episode Termination/" + term_name] = torch.count_nonzero(term_done[env_ids]).item()
        # reset all the class-based termination terms
        for term_cfg in self._class_term_cfgs:
            term_cfg.func.reset(env_ids=env_ids)
        # return logged information
        return extras

    def compute(self) -> torch.Tensor:
        """Computes the termination signal as union of individual terms.

        This function calls each termination term managed by the class and performs a logical OR operation
        to compute the net termination signal. The output of each term is also stored so that it can be
        queried through :meth:`get_term` and reported by :meth:`reset`.

        Returns:
            The combined termination signal of shape (num_envs,).
        """
        # reset computation
        self._truncated_buf[:] = False
        self._terminated_buf[:] = False
        # iterate over all the termination terms
        for name, term_cfg in zip(self._term_names, self._term_cfgs):
            value = term_cfg.func(self._env, **term_cfg.params)
            # store timeout signal separately
            if term_cfg.time_out:
                self._truncated_buf |= value
            else:
                self._terminated_buf |= value
            # write in-place so that the per-term buffer keeps its identity across calls
            self._term_dones[name][:] = value
        # return combined termination signal
        return self._truncated_buf | self._terminated_buf

    def get_term(self, name: str) -> torch.Tensor:
        """Returns the termination term with the specified name.

        Args:
            name: The name of the termination term.

        Returns:
            The corresponding termination term value. Shape is (num_envs,).

        Raises:
            KeyError: If the term name is not found.
        """
        return self._term_dones[name]

    """
    Operations - Term settings.
    """

    def set_term_cfg(self, term_name: str, cfg: TerminationTermCfg):
        """Sets the configuration of the specified term into the manager.

        Args:
            term_name: The name of the termination term.
            cfg: The configuration for the termination term.

        Raises:
            ValueError: If the term name is not found.
        """
        if term_name not in self._term_names:
            raise ValueError(f"Termination term '{term_name}' not found.")
        # NOTE(review): only ``_term_cfgs`` is updated here. ``_class_term_cfgs`` (used by ``reset``) keeps
        # the entries collected in ``_prepare_terms`` -- confirm this is intended when swapping class-based terms.
        self._term_cfgs[self._term_names.index(term_name)] = cfg

    def get_term_cfg(self, term_name: str) -> TerminationTermCfg:
        """Gets the configuration for the specified term.

        Args:
            term_name: The name of the termination term.

        Returns:
            The configuration of the termination term.

        Raises:
            ValueError: If the term name is not found.
        """
        if term_name not in self._term_names:
            raise ValueError(f"Termination term '{term_name}' not found.")
        # return the configuration
        return self._term_cfgs[self._term_names.index(term_name)]

    """
    Helper functions.
    """

    def _prepare_terms(self):
        """Prepares the lists of termination term names and configurations.

        Entries of the configuration that are ``None`` are skipped. Terms whose function is an
        instance of :class:`ManagerTermBase` are additionally tracked so that :meth:`reset` can
        reset them.

        Raises:
            TypeError: If a configuration entry is not of type :class:`TerminationTermCfg`.
        """
        # containers for the parsed termination terms
        self._term_names: list[str] = []
        self._term_cfgs: list[TerminationTermCfg] = []
        self._class_term_cfgs: list[TerminationTermCfg] = []

        # check if config is dict already
        if isinstance(self.cfg, dict):
            cfg_items = self.cfg.items()
        else:
            cfg_items = self.cfg.__dict__.items()
        # iterate over all the terms
        for term_name, term_cfg in cfg_items:
            # check for non config
            if term_cfg is None:
                continue
            # check for valid config type
            if not isinstance(term_cfg, TerminationTermCfg):
                raise TypeError(
                    f"Configuration for the term '{term_name}' is not of type TerminationTermCfg."
                    f" Received: '{type(term_cfg)}'."
                )
            # resolve common parameters
            self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1)
            # add function to list
            self._term_names.append(term_name)
            self._term_cfgs.append(term_cfg)
            # check if the term is a class
            if isinstance(term_cfg.func, ManagerTermBase):
                self._class_term_cfgs.append(term_cfg)