diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 9429312f..b84df1c8 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -35,6 +35,7 @@ Please keep the lists sorted alphabetically. * Fabian Jenelten * Lorenzo Terenzi * Marko Bjelonic +* Markus Portugall * Matthijs van der Boon * Özhan Özen * Pascal Roth diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 2ce9fda6..d268f3e6 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -14,6 +14,7 @@ from rsl_rl.env import VecEnv from rsl_rl.models import MLPModel from rsl_rl.utils import check_nan, resolve_callable +from rsl_rl.utils.log_writer import LogWriter from rsl_rl.utils.logger import Logger @@ -23,7 +24,14 @@ class OnPolicyRunner: alg: PPO """The actor-critic algorithm.""" - def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device: str = "cpu") -> None: + def __init__( + self, + env: VecEnv, + train_cfg: dict, + log_dir: str | Logger | None = None, + device: str = "cpu", + writer: LogWriter | None = None, + ) -> None: """Construct the runner, algorithm, and logging stack.""" self.env = env self.cfg = train_cfg @@ -39,17 +47,21 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev alg_class: type[PPO] = resolve_callable(self.cfg["algorithm"]["class_name"]) # type: ignore self.alg = alg_class.construct_algorithm(obs, self.env, self.cfg, self.device) - # Create the logger - self.logger = Logger( - log_dir=log_dir, - cfg=self.cfg, - env_cfg=self.env.cfg, - num_envs=self.env.num_envs, - is_distributed=self.is_distributed, - gpu_world_size=self.gpu_world_size, - gpu_global_rank=self.gpu_global_rank, - device=self.device, - ) + # Use the provided logger or create a default one from the log_dir path + if isinstance(log_dir, Logger): + self.logger = log_dir + else: + self.logger = Logger( + log_dir=log_dir, + cfg=self.cfg, + env_cfg=self.env.cfg, + num_envs=self.env.num_envs, + is_distributed=self.is_distributed, + gpu_world_size=self.gpu_world_size, + gpu_global_rank=self.gpu_global_rank, + device=self.device, + writer=writer, + ) self.current_learning_iteration = 0 @@ -125,11 +137,11 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals ) # Save model - if self.logger.writer is not None and it % self.cfg["save_interval"] == 0: + if self.logger.is_logging and it % self.cfg["save_interval"] == 0: self.save(os.path.join(self.logger.log_dir, f"model_{it}.pt")) # type: ignore # Save the final model after training and stop the logging writer - if self.logger.writer is not None: + if self.logger.is_logging: self.save(os.path.join(self.logger.log_dir, f"model_{self.current_learning_iteration}.pt")) # type: ignore self.logger.stop_logging_writer() diff --git a/rsl_rl/utils/__init__.py b/rsl_rl/utils/__init__.py index e473b912..dfa539ee 100644 --- a/rsl_rl/utils/__init__.py +++ b/rsl_rl/utils/__init__.py @@ -5,6 +5,8 @@ """Helper functions.""" +from .log_writer import LogWriter +from .logger import Logger from .utils import ( check_nan, compile_model, @@ -18,6 +20,8 @@ ) __all__ = [ + "LogWriter", + "Logger", "check_nan", "compile_model", "get_param", diff --git a/rsl_rl/utils/log_writer.py b/rsl_rl/utils/log_writer.py new file mode 100644 index 00000000..4d7aa962 --- /dev/null +++ b/rsl_rl/utils/log_writer.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021-2026, ETH Zurich and NVIDIA CORPORATION +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import pathlib +from abc import ABC, abstractmethod + + +class LogWriter(ABC): + """Abstract base class for rsl_rl logging backends. + + Subclass this to implement a custom logging backend. Only :meth:`add_scalar` + is required; all other methods are no-ops by default. + + Example:: + + from rsl_rl.utils import LogWriter + + class MyWriter(LogWriter): + def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None: + db.insert(tag=tag, value=scalar_value, step=global_step) + """ + + @abstractmethod + def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None: + """Log a scalar metric. + + Args: + tag: Metric name using ``"Group/name"`` convention (e.g. ``"Train/mean_reward"``). + scalar_value: The scalar value to record. + global_step: Training iteration used as the x-axis. + """ + + @abstractmethod + def store_config(self, env_cfg: dict | object, train_cfg: dict) -> None: + """Upload environment and training configuration. Called once at training start.""" + + @abstractmethod + def save_model(self, model_path: str, it: int) -> None: + """Upload or archive a model checkpoint.""" + + @abstractmethod + def save_file(self, path: str) -> None: + """Upload or archive an arbitrary file (e.g., git diff).""" + + @abstractmethod + def save_video(self, video: pathlib.Path, it: int) -> None: + """Upload a video file.""" + + @abstractmethod + def stop(self) -> None: + """Finalize and close the logging run.""" diff --git a/rsl_rl/utils/logger.py b/rsl_rl/utils/logger.py index cf025fd1..69637b84 100644 --- a/rsl_rl/utils/logger.py +++ b/rsl_rl/utils/logger.py @@ -15,6 +15,7 @@ from collections import deque import rsl_rl +from rsl_rl.utils.log_writer import LogWriter class Logger: @@ -30,6 +31,7 @@ def __init__( gpu_world_size: int, gpu_global_rank: int, device: str, + writer: LogWriter | None = None, ) -> None: """Initialize buffers and logging state for a training run.""" self.log_dir = log_dir @@ -42,6 +44,10 @@ def __init__( self.tot_timesteps = 0 self.tot_time = 0 + self._injected_writer = writer + self.writer: LogWriter | None = None + self.logger_type: str | None = None + # Create buffers self.ep_extras = [] self.rewbuffer = deque(maxlen=100) @@ -60,39 +66,51 @@ def __init__( # Note: We only log from the process with rank 0 (main process) self.disable_logs = is_distributed and gpu_global_rank != 0 + @property + def is_logging(self) -> bool: + """True if a writer is active and emitting metrics.""" + return self.writer is not None + def init_logging_writer(self) -> None: - """Initialize the logging writer, which can be either Tensorboard, W&B or Neptune and save the code state. + """Initialize the logging writer and save the code state. - If the writer is either W&B or Neptune, the configuration and code state are uploaded as well. + If a writer was injected via the constructor, it is used directly. Otherwise + a writer is constructed from ``cfg["logger"]`` (``"tensorboard"``, ``"wandb"``, + or ``"neptune"``). Configuration and git diffs are uploaded for writers that + inherit from :class:`~rsl_rl.utils.LogWriter`. """ if self.log_dir is not None and not self.disable_logs: - self.logger_type = self.cfg.get("logger", "tensorboard") - self.logger_type = self.logger_type.lower() - if self.logger_type == "neptune": - from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter - - self.writer = NeptuneSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg) - elif self.logger_type == "wandb": - from rsl_rl.utils.wandb_utils import WandbSummaryWriter - - self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg) - elif self.logger_type == "tensorboard": - from torch.utils.tensorboard import SummaryWriter - - self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10) + if self._injected_writer is not None: + self.writer = self._injected_writer + self.logger_type = "custom" else: - raise ValueError("Logger type not found. Please choose 'wandb', 'neptune', or 'tensorboard'.") + self.logger_type = self.cfg.get("logger", "tensorboard") + self.logger_type = self.logger_type.lower() + if self.logger_type == "neptune": + from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter + + self.writer = NeptuneSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg) + elif self.logger_type == "wandb": + from rsl_rl.utils.wandb_utils import WandbSummaryWriter + + self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg) + elif self.logger_type == "tensorboard": + from torch.utils.tensorboard import SummaryWriter + + self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10) + else: + raise ValueError("Logger type not found. Please choose 'wandb', 'neptune', or 'tensorboard'.") else: self.writer = None # Save code state files_to_upload = self._store_code_state() - # Upload configuration and code state to external logging service if applicable - if self.writer is not None and self.logger_type in ["wandb", "neptune"]: - self.writer.store_config(self.env_cfg, self.cfg) # type: ignore + # Upload configuration and code state for LogWriter subclasses (no-op for plain TensorBoard) + if isinstance(self.writer, LogWriter): + self.writer.store_config(self.env_cfg, self.cfg) for path in files_to_upload: - self.writer.save_file(path) # type: ignore + self.writer.save_file(path) def process_env_step( self, @@ -146,7 +164,7 @@ def log( ) -> None: """Log the training metrics to the logging service and print them to the console. - If videos are available, they are uploaded to the logging service (W&B) as well. + If videos are available, they are uploaded to the logging service as well. """ if self.writer is not None: collection_size = self.cfg["num_steps_per_env"] * self.num_envs * self.gpu_world_size @@ -200,13 +218,10 @@ def log( self.writer.add_scalar("Rnd/weight", rnd_weight, it) # type: ignore self.writer.add_scalar("Train/mean_reward", statistics.mean(self.rewbuffer), it) self.writer.add_scalar("Train/mean_episode_length", statistics.mean(self.lenbuffer), it) - if self.logger_type != "wandb": - self.writer.add_scalar( - "Train/mean_reward/time", statistics.mean(self.rewbuffer), int(self.tot_time) - ) - self.writer.add_scalar( - "Train/mean_episode_length/time", statistics.mean(self.lenbuffer), int(self.tot_time) - ) + self.writer.add_scalar("Train/mean_reward/time", statistics.mean(self.rewbuffer), int(self.tot_time)) + self.writer.add_scalar( + "Train/mean_episode_length/time", statistics.mean(self.lenbuffer), int(self.tot_time) + ) # Print to console log_string = f"""{"#" * width}\n""" @@ -255,23 +270,23 @@ def log( ) print(log_string) - # Upload available videos - if self.logger_type == "wandb": + # Upload available videos (no-op for writers that don't support it) + if isinstance(self.writer, LogWriter): for video in pathlib.Path(self.log_dir).rglob("*.mp4"): # type: ignore - self.writer.save_video(video, it) # type: ignore + self.writer.save_video(video, it) # Clear extras buffer self.ep_extras.clear() def save_model(self, path: str, it: int) -> None: """Save the model to external logging services if specified.""" - if self.writer is not None and self.logger_type in ["neptune", "wandb"]: - self.writer.save_model(path, it) # type: ignore + if isinstance(self.writer, LogWriter): + self.writer.save_model(path, it) def stop_logging_writer(self) -> None: """Stop the logging writer.""" - if self.writer is not None and self.logger_type in ["neptune", "wandb"]: - self.writer.stop() # type: ignore + if isinstance(self.writer, LogWriter): + self.writer.stop() def _store_code_state(self) -> list[str]: """Store the current git diff of the code repositories involved in the experiment.""" diff --git a/rsl_rl/utils/neptune_utils.py b/rsl_rl/utils/neptune_utils.py index fb23b564..488eeeb6 100644 --- a/rsl_rl/utils/neptune_utils.py +++ b/rsl_rl/utils/neptune_utils.py @@ -10,13 +10,15 @@ from dataclasses import asdict from torch.utils.tensorboard import SummaryWriter +from rsl_rl.utils.log_writer import LogWriter + try: import neptune except ModuleNotFoundError: raise ModuleNotFoundError("neptune-client is required to log to Neptune.") from None -class NeptuneSummaryWriter(SummaryWriter): +class NeptuneSummaryWriter(SummaryWriter, LogWriter): """Summary writer for Neptune.""" def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None: @@ -70,6 +72,7 @@ def add_scalar( global_step: int | None = None, walltime: float | None = None, new_style: bool = False, + double_precision: bool = False, ) -> None: """Log a scalar to both TensorBoard and Neptune.""" super().add_scalar( @@ -78,6 +81,7 @@ def add_scalar( global_step=global_step, walltime=walltime, new_style=new_style, + double_precision=double_precision, ) self.run[self._map_path(tag)].log(scalar_value, step=global_step) diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py index 178eef69..f5e0fd61 100644 --- a/rsl_rl/utils/wandb_utils.py +++ b/rsl_rl/utils/wandb_utils.py @@ -11,13 +11,15 @@ from dataclasses import asdict from torch.utils.tensorboard import SummaryWriter +from rsl_rl.utils.log_writer import LogWriter + try: import wandb except ModuleNotFoundError: raise ModuleNotFoundError("wandb package is required to log to Weights and Biases.") from None -class WandbSummaryWriter(SummaryWriter): +class WandbSummaryWriter(SummaryWriter, LogWriter): """Summary writer for W&B.""" def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None: @@ -64,6 +66,7 @@ def add_scalar( global_step: int | None = None, walltime: float | None = None, new_style: bool = False, + double_precision: bool = False, ) -> None: """Log a scalar to both TensorBoard and W&B.""" super().add_scalar( @@ -72,6 +75,7 @@ def add_scalar( global_step=global_step, walltime=walltime, new_style=new_style, + double_precision=double_precision, ) wandb.log({tag: scalar_value}, step=global_step)