You can access the local variables of the training loop from inside the callback through self.locals. Data that your custom environment returns in its info dict shows up there under the "infos" key, with one entry per environment. The example below shows how to read a key from a custom dictionary called my_custom_info_dict and log it to TensorBoard when using vectorized environments.
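For context, the custom environment only needs to place that dictionary into the info dict returned by step(). A minimal sketch, assuming a hypothetical CustomInfoWrapper and an arbitrary reward computation:

import gymnasium as gym


class CustomInfoWrapper(gym.Wrapper):
    """Hypothetical wrapper exposing extra data through the info dict."""

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Anything placed in info here is visible to callbacks during
        # training via self.locals["infos"]
        info["my_custom_info_dict"] = {"my_custom_reward": float(reward)}
        return obs, reward, terminated, truncated, info

The full example with the callback and vectorized environments: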
import gymnasium as gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import TensorBoardOutputFormat
from stable_baselines3.common.vec_env import SubprocVecEnv


def make_env(env_id, rank, seed=0):
    """
    Utility function for a multiprocessed env.
    See https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb
    for more details on vectorized environments.

    :param env_id: (str) the environment ID
    :param rank: (int) index of the subprocess
    :param seed: (int) the initial seed for the RNG
    :return: (Callable) a function that creates the environment
    """
    def _init():
        # Build (or wrap) your custom environment here, i.e. the one that
        # puts my_custom_info_dict into the info dict returned by step()
        env = gym.make(env_id)
        env.reset(seed=seed + rank)
        return env
    return _init


class SummaryWriterCallback(BaseCallback):
    '''
    Snippet skeleton from the Stable Baselines3 documentation:
    https://stable-baselines3.readthedocs.io/en/master/guide/tensorboard.html#directly-accessing-the-summary-writer
    '''

    def _on_training_start(self):
        self._log_freq = 10  # log every 10 calls
        output_formats = self.logger.output_formats
        # Save a reference to the TensorBoard formatter object
        # note: the failure case (no formatter found) is not handled here;
        # it should be done with try/except.
        self.tb_formatter = next(
            formatter for formatter in output_formats
            if isinstance(formatter, TensorBoardOutputFormat)
        )

    def _on_step(self) -> bool:
        '''
        Log my_custom_reward to TensorBoard every self._log_freq calls,
        once per environment.
        '''
        if self.n_calls % self._log_freq == 0:
            # self.locals["infos"] holds one info dict per environment,
            # as returned by each environment's step() method
            for i, info in enumerate(self.locals["infos"]):
                reward = info["my_custom_info_dict"]["my_custom_reward"]
                self.tb_formatter.writer.add_scalar(
                    "rewards/env #{}".format(i + 1), reward, self.n_calls
                )
            self.tb_formatter.writer.flush()
        return True


if __name__ == "__main__":
    # SAC needs a continuous action space, so use Pendulum-v1 rather than CartPole-v1
    env_id = "Pendulum-v1"
    envs = SubprocVecEnv([make_env(env_id, i) for i in range(4)])  # 4 environments
    model = SAC("MlpPolicy", envs, tensorboard_log="/tmp/sac/", verbose=1)
    model.learn(50000, callback=SummaryWriterCallback())
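Once training is running, launch TensorBoard with tensorboard --logdir /tmp/sac/ and the per-environment curves appear under the rewards section, next to SAC's built-in metrics.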