We have now released v0.3.0! Please use the latest version for the best experience.

Source code for omni.isaac.orbit_tasks.utils.wrappers.rsl_rl.exporter

# Copyright (c) 2022-2024, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

import copy
import os
import torch


def export_policy_as_jit(actor_critic: object, path: str, filename="policy.pt"):
    """Save the policy of an actor-critic module as a Torch JIT file.

    The work is delegated to the private ``_TorchPolicyExporter`` helper, which
    is built from ``actor_critic`` and then asked to export itself.

    Args:
        actor_critic: The actor-critic torch module.
        path: The path to the saving directory.
        filename: The name of exported JIT file. Defaults to "policy.pt".

    Reference:
        https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L180
    """
    _TorchPolicyExporter(actor_critic).export(path, filename)
def export_policy_as_onnx(actor_critic: object, path: str, filename="policy.onnx", verbose=False):
    """Export policy into a Torch ONNX file.

    Args:
        actor_critic: The actor-critic torch module.
        path: The path to the saving directory. It is created (including
            missing parents) when it does not exist yet.
        filename: The name of exported ONNX file. Defaults to "policy.onnx".
        verbose: Whether to print the model summary. Defaults to False.
            The flag is handed to the ``_OnnxPolicyExporter`` helper.
    """
    # ``exist_ok=True`` already makes this a no-op for an existing directory,
    # so a separate ``os.path.exists`` check is unnecessary (and race-prone).
    os.makedirs(path, exist_ok=True)
    policy_exporter = _OnnxPolicyExporter(actor_critic, verbose)
    policy_exporter.export(path, filename)
""" Helper Classes - Private. """ class _TorchPolicyExporter(torch.nn.Module): """Exporter of actor-critic into JIT file. Reference: https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L193 """ def __init__(self, actor_critic): super().__init__() self.actor = copy.deepcopy(actor_critic.actor) self.is_recurrent = actor_critic.is_recurrent if self.is_recurrent: self.rnn = copy.deepcopy(actor_critic.memory_a.rnn) self.rnn.cpu() self.register_buffer("hidden_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)) self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)) self.forward = self.forward_lstm self.reset = self.reset_memory def forward_lstm(self, x): x, (h, c) = self.rnn(x.unsqueeze(0), (self.hidden_state, self.cell_state)) self.hidden_state[:] = h self.cell_state[:] = c x = x.squeeze(0) return self.actor(x) def forward(self, x): return self.actor(x) @torch.jit.export def reset(self): pass def reset_memory(self): self.hidden_state[:] = 0.0 self.cell_state[:] = 0.0 def export(self, path, filename): os.makedirs(path, exist_ok=True) path = os.path.join(path, filename) self.to("cpu") traced_script_module = torch.jit.script(self) traced_script_module.save(path) class _OnnxPolicyExporter(torch.nn.Module): """Exporter of actor-critic into ONNX file.""" def __init__(self, actor_critic, verbose=False): super().__init__() self.verbose = verbose self.actor = copy.deepcopy(actor_critic.actor) self.is_recurrent = actor_critic.is_recurrent if self.is_recurrent: self.rnn = copy.deepcopy(actor_critic.memory_a.rnn) self.rnn.cpu() self.forward = self.forward_lstm def forward_lstm(self, x_in, h_in, c_in): x, (h, c) = self.rnn(x_in.unsqueeze(0), (h_in, c_in)) x = x.squeeze(0) return self.actor(x), h, c def forward(self, x): return self.actor(x) def export(self, path, filename): self.to("cpu") if self.is_recurrent: obs = torch.zeros(1, self.rnn.input_size) h_in = torch.zeros(self.rnn.num_layers, 1, 
self.rnn.hidden_size) c_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size) actions, h_out, c_out = self(obs, h_in, c_in) torch.onnx.export( self, (obs, h_in, c_in), os.path.join(path, filename), export_params=True, opset_version=11, verbose=self.verbose, input_names=["obs", "h_in", "c_in"], output_names=["actions", "h_out", "c_out"], dynamic_axes={}, ) else: obs = torch.zeros(1, self.actor[0].in_features) torch.onnx.export( self, obs, os.path.join(path, filename), export_params=True, opset_version=11, verbose=self.verbose, input_names=["obs"], output_names=["actions"], dynamic_axes={}, )