We have now released v0.3.0! Please use the latest version for the best experience.

Source code for omni.isaac.orbit.managers.action_manager

# Copyright (c) 2022-2024, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Action manager for processing actions sent to the environment."""

from __future__ import annotations

import torch
from abc import abstractmethod
from collections.abc import Sequence
from prettytable import PrettyTable
from typing import TYPE_CHECKING

from omni.isaac.orbit.assets import AssetBase

from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import ActionTermCfg

if TYPE_CHECKING:
    from omni.isaac.orbit.envs import BaseEnv


class ActionTerm(ManagerTermBase):
    """Abstract base class that every action term derives from.

    An action term takes the raw actions sent to the environment, processes them, and
    applies the result to the asset that the term manages. Its work is split into two
    operations:

    * Processing of actions: Performed once per **environment step**. It pre-processes
      the raw actions sent to the environment.
    * Applying actions: Performed once per **simulation step**. It applies the processed
      actions to the asset managed by the term.
    """

    def __init__(self, cfg: ActionTermCfg, env: BaseEnv):
        """Set up the term and resolve the asset it acts on.

        Args:
            cfg: The configuration object.
            env: The environment instance.
        """
        # let the base class store the configuration and environment
        super().__init__(cfg, env)

        # resolve the asset by the name given in the configuration
        # NOTE(review): ``self._env`` and ``self.cfg`` are not assigned in this class;
        # presumably the base-class constructor sets them -- confirm.
        scene = self._env.scene
        self._asset: AssetBase = scene[self.cfg.asset_name]

    """
    Properties.
    """

    @property
    @abstractmethod
    def action_dim(self) -> int:
        """Dimension of the action term."""
        raise NotImplementedError

    @property
    @abstractmethod
    def raw_actions(self) -> torch.Tensor:
        """The input/raw actions sent to the term."""
        raise NotImplementedError

    @property
    @abstractmethod
    def processed_actions(self) -> torch.Tensor:
        """The actions computed by the term after applying any processing."""
        raise NotImplementedError

    """
    Operations.
    """

    @abstractmethod
    def process_actions(self, actions: torch.Tensor):
        """Process the actions sent to the environment.

        Note:
            The manager calls this once per environment step.

        Args:
            actions: The actions to process.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_actions(self):
        """Apply the actions to the asset managed by the term.

        Note:
            The manager calls this at every simulation step.
        """
        raise NotImplementedError
class ActionManager(ManagerBase):
    """Manager for processing and applying actions for a given world.

    The action manager handles the interpretation and application of user-defined
    actions on a given world. It is comprised of different action terms that decide
    the dimension of the expected actions.

    The action manager performs operations at two stages:

    * processing of actions: It splits the input actions to each term and performs any
      pre-processing needed. This should be called once at every environment step.
    * apply actions: This operation typically sets the processed actions into the assets
      in the scene (such as robots). It should be called before every simulation step.
    """

    def __init__(self, cfg: object, env: BaseEnv):
        """Initialize the action manager.

        Args:
            cfg: The configuration object or dictionary (``dict[str, ActionTermCfg]``).
            env: The environment instance.
        """
        # NOTE(review): ``total_action_dim`` below reads ``self._terms``, which is only
        # assigned in ``_prepare_terms``; presumably the base-class constructor calls
        # it -- confirm against ManagerBase.
        super().__init__(cfg, env)
        # create buffers to store the current and the previous actions
        self._action = torch.zeros((self.num_envs, self.total_action_dim), device=self.device)
        self._prev_action = torch.zeros_like(self._action)

    def __str__(self) -> str:
        """Returns: A string representation for action manager."""
        msg = f"<ActionManager> contains {len(self._term_names)} active terms.\n"

        # create table for term information
        table = PrettyTable()
        table.title = f"Active Action Terms (shape: {self.total_action_dim})"
        table.field_names = ["Index", "Name", "Dimension"]
        # set alignment of table columns
        table.align["Name"] = "l"
        table.align["Dimension"] = "r"
        # add info on each term
        for index, (name, term) in enumerate(self._terms.items()):
            table.add_row([index, name, term.action_dim])
        # convert table to string
        msg += table.get_string()
        msg += "\n"

        return msg

    """
    Properties.
    """

    @property
    def total_action_dim(self) -> int:
        """Total dimension of actions."""
        return sum(self.action_term_dim)

    @property
    def active_terms(self) -> list[str]:
        """Name of active action terms."""
        return self._term_names

    @property
    def action_term_dim(self) -> list[int]:
        """Shape of each action term."""
        return [term.action_dim for term in self._terms.values()]

    @property
    def action(self) -> torch.Tensor:
        """The actions sent to the environment. Shape is (num_envs, total_action_dim)."""
        return self._action

    @property
    def prev_action(self) -> torch.Tensor:
        """The previous actions sent to the environment. Shape is (num_envs, total_action_dim)."""
        return self._prev_action

    """
    Operations.
    """

    def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]:
        """Resets the action history.

        Args:
            env_ids: The environment ids. Defaults to None, in which case
                all environments are considered.

        Returns:
            An empty dictionary.
        """
        # resolve environment ids: a full slice selects every environment
        if env_ids is None:
            env_ids = slice(None)
        # reset the action history
        self._prev_action[env_ids] = 0.0
        self._action[env_ids] = 0.0
        # reset all action terms
        for term in self._terms.values():
            term.reset(env_ids=env_ids)
        # nothing to log here
        return {}

    def process_action(self, action: torch.Tensor):
        """Processes the actions sent to the environment.

        The input is moved to the manager's device once. The moved tensor is stored in
        the action buffer and is also the tensor that is split column-wise and handed
        to each term, so the terms and the buffer always see data on the same device.

        Note:
            This function should be called once per environment step.

        Args:
            action: The actions to process.

        Raises:
            ValueError: If the second dimension of ``action`` does not match
                :attr:`total_action_dim`.
        """
        # check if action dimension is valid
        expected_dim = self.total_action_dim
        if expected_dim != action.shape[1]:
            raise ValueError(f"Invalid action shape, expected: {expected_dim}, received: {action.shape[1]}.")
        # move the input to the manager's device before anything consumes it
        action = action.to(self.device)
        # store the input actions (current becomes previous first)
        self._prev_action[:] = self._action
        self._action[:] = action

        # split the actions column-wise and hand each slice to its term
        idx = 0
        for term in self._terms.values():
            term_dim = term.action_dim
            term.process_actions(action[:, idx : idx + term_dim])
            idx += term_dim

    def apply_action(self) -> None:
        """Applies the actions to the environment/simulation.

        Note:
            This should be called at every simulation step.
        """
        for term in self._terms.values():
            term.apply_actions()

    def get_term(self, name: str) -> ActionTerm:
        """Returns the action term with the specified name.

        Args:
            name: The name of the action term.

        Returns:
            The action term with the specified name.

        Raises:
            KeyError: If no term with the given name is registered.
        """
        return self._terms[name]

    """
    Helper functions.
    """

    def _prepare_terms(self):
        """Prepares a list of action terms.

        Each non-``None`` entry of the configuration is instantiated through its
        ``class_type`` and registered under its configuration name, in iteration order.

        Raises:
            TypeError: If an entry is not an :class:`ActionTermCfg`, or if the object
                created from it is not an :class:`ActionTerm`.
        """
        # parse action terms from the config
        self._term_names: list[str] = []
        self._terms: dict[str, ActionTerm] = {}

        # check if config is dict already
        if isinstance(self.cfg, dict):
            cfg_items = self.cfg.items()
        else:
            cfg_items = self.cfg.__dict__.items()
        for term_name, term_cfg in cfg_items:
            # a term set to None is skipped
            if term_cfg is None:
                continue
            # check valid type
            if not isinstance(term_cfg, ActionTermCfg):
                raise TypeError(
                    f"Configuration for the term '{term_name}' is not of type ActionTermCfg."
                    f" Received: '{type(term_cfg)}'."
                )
            # create the action term
            term = term_cfg.class_type(term_cfg, self._env)
            # sanity check if term is valid type
            if not isinstance(term, ActionTerm):
                raise TypeError(f"Returned object for the term '{term_name}' is not of type ActionTerm.")
            # add term name and parameters
            self._term_names.append(term_name)
            self._terms[term_name] = term