diff --git a/CITATION.cff b/CITATION.cff index d00bd22d..36877a74 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -3,7 +3,7 @@ title: "RSL-RL: A Learning Library for Robotics Research" message: "If you use this work, please cite the following paper." repository-code: "https://github.com/leggedrobotics/rsl_rl" license: BSD-3-Clause -version: 5.3.0 +version: 5.4.0 type: software authors: - family-names: Schwarke diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 9429312f..b84df1c8 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -35,6 +35,7 @@ Please keep the lists sorted alphabetically. * Fabian Jenelten * Lorenzo Terenzi * Marko Bjelonic +* Markus Portugall * Matthijs van der Boon * Özhan Özen * Pascal Roth diff --git a/docs/api/utils.rst b/docs/api/utils.rst index 4c2bd4fe..56bd3fb3 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -9,29 +9,34 @@ Logger :undoc-members: -Utils ------ +Log Writer +---------- -.. automodule:: rsl_rl.utils.utils +.. automodule:: rsl_rl.utils.log_writer :members: :undoc-members: -Wandb Utils ------------ +Wandb Log Writer +---------------- -.. automodule:: rsl_rl.utils.wandb_utils +.. automodule:: rsl_rl.utils.wandb_log_writer :members: :undoc-members: -Neptune Utils -------------- +Neptune Log Writer +------------------ -.. automodule:: rsl_rl.utils.neptune_utils +.. automodule:: rsl_rl.utils.neptune_log_writer :members: :undoc-members: +Utils +----- +.. automodule:: rsl_rl.utils.utils + :members: + :undoc-members: diff --git a/docs/guide/configuration.rst b/docs/guide/configuration.rst index 3c3e8308..2ee548ec 100644 --- a/docs/guide/configuration.rst +++ b/docs/guide/configuration.rst @@ -34,8 +34,10 @@ Runner Configuration Currently, RSL-RL implements two runner classes: :class:`~rsl_rl.runners.on_policy_runner.OnPolicyRunner` and -:class:`~rsl_rl.runners.distillation_runner.DistillationRunner`. 
The -:class:`~rsl_rl.runners.on_policy_runner.OnPolicyRunner` is configured as follows: +:class:`~rsl_rl.runners.distillation_runner.DistillationRunner`, which are configured as follows. + +OnPolicyRunner +^^^^^^^^^^^^^^ .. list-table:: :header-rows: 1 @@ -45,31 +47,19 @@ Currently, RSL-RL implements two runner classes: - Type - Default - Description - * - ``num_steps_per_env`` - - int - - required - - Number of environment steps collected per iteration. * - ``obs_groups`` - dict[str, list[str]] - required - Mapping from observation sets to observation groups coming from the environment. See :ref:`here ` for more details. + * - ``num_steps_per_env`` + - int + - required + - Number of environment steps collected per iteration. * - ``save_interval`` - int - required - Number of iterations between checkpoints. - * - ``logger`` - - str - - ``"tensorboard"`` - - Logging service to use. Valid values: ``"tensorboard"``, ``"wandb"``, ``"neptune"``. - * - ``wandb_project`` - - str - - required for W&B - - W&B project name used by the W&B writer. - * - ``neptune_project`` - - str - - required for Neptune - - Neptune project name used by the Neptune writer. * - ``run_name`` - str - missing @@ -83,6 +73,19 @@ Currently, RSL-RL implements two runner classes: - ``None`` - Compile mode for the PyTorch models to accelerate training. Valid values: ``None``, ``"default"``, ``"max-autotune-no-cudagraphs"``. + * - ``logger`` + - str | dict + - ``"tensorboard"`` + - Logging writer configuration. The plain strings ``"wandb"`` and ``"neptune"`` are + still accepted but deprecated. + * - ``wandb_project`` + - str + - -- + - Deprecated. Pass ``project_name`` inside the ``logger`` configuration instead. + * - ``neptune_project`` + - str + - -- + - Deprecated. Pass ``project_name`` inside the ``logger`` configuration instead. * - ``algorithm`` - dict - required @@ -96,6 +99,9 @@ Currently, RSL-RL implements two runner classes: - required - Critic model configuration. 
+DistillationRunner +^^^^^^^^^^^^^^^^^^ + For the :class:`~rsl_rl.runners.distillation_runner.DistillationRunner`, the ``actor`` and ``critic`` keys are simply replaced by ``student`` and ``teacher`` keys, respectively: @@ -120,6 +126,31 @@ replaced by ``student`` and ``teacher`` keys, respectively: - required - Teacher model configuration. +Logger +^^^^^^ +The ``logger`` key of the runner configuration defines the logging writer used to log training metrics and other +information during training. RSL-RL supports TensorBoard, Weights & Biases, and Neptune out of the box. While +TensorBoard does not require any configuration and is set using the plain string ``"tensorboard"``, the other logging +backends are configured by passing a dictionary with the following keys; any additional keys are forwarded to the writer class as constructor arguments: + +.. list-table:: + :header-rows: 1 + :class: no-wrap-type-column + + * - Key + - Type + - Default + - Description + * - ``class_name`` + - str + - required + - Logger class name. Built-in values: ``"WandbLogWriter"``, ``"NeptuneLogWriter"``; a custom :class:`~rsl_rl.utils.log_writer.LogWriter` subclass can be used as well. + * - ``project_name`` + - str + - required + - Name of the W&B or Neptune project to log to. + + Algorithm Configuration ----------------------- diff --git a/docs/guide/overview.rst b/docs/guide/overview.rst index bfb163ab..530154aa 100644 --- a/docs/guide/overview.rst +++ b/docs/guide/overview.rst @@ -14,6 +14,11 @@ Library Features RSL-RL is intentionally kept minimal and focuses on a small set of components that cover common robotics workflows while remaining easy to adapt. The following sections summarize the main features currently available. +.. note:: + Adding new algorithms, models, or loggers is straightforward and does not require modifying the library itself. + Custom classes can simply be passed as part of the configuration, enabling users to work with the pip version of the + library.
+ Algorithms ^^^^^^^^^^ @@ -198,7 +203,9 @@ not constrain the way an **Extension** may be implemented, allowing for arbitrar Utils ^^^^^ **Utils** include various helpers for the library, such as a :class:`~rsl_rl.utils.logger.Logger` to record the learning -process, or functions to resolve configuration settings. +process, or functions to resolve configuration settings. For example, the :func:`~rsl_rl.utils.utils.resolve_callable` +function allows users to pass classes via the configuration dictionary, enabling the use of custom models, loggers, etc. +without modifying the library. .. _example-integration: diff --git a/licenses/dependencies/tensorboard-license.txt b/licenses/dependencies/tensorboard-license.txt new file mode 100644 index 00000000..27e80ef7 --- /dev/null +++ b/licenses/dependencies/tensorboard-license.txt @@ -0,0 +1,203 @@ +Copyright 2017 The TensorFlow Authors. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2017, The TensorFlow Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index dedc35e4..2295aa0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "rsl-rl-lib" -version = "5.3.0" +version = "5.4.0" keywords = ["reinforcement-learning", "robotics"] maintainers = [ { name="Clemens Schwarke", email="cschwarke@ethz.ch" }, @@ -29,6 +29,7 @@ dependencies = [ "torchvision>=0.5.0", "tensordict>=0.7.0", "numpy>=1.16.4", + "tensorboard", "GitPython", "onnx", "onnxscript>=0.5.4", diff --git a/rsl_rl/utils/__init__.py b/rsl_rl/utils/__init__.py index e473b912..2e7a5d74 100644 --- a/rsl_rl/utils/__init__.py +++ b/rsl_rl/utils/__init__.py @@ -5,6 +5,8 @@ """Helper functions.""" +from .log_writer import LogWriter +from .neptune_log_writer import NeptuneLogWriter from .utils import ( check_nan, compile_model, @@ -16,8 +18,12 @@ split_and_pad_trajectories, unpad_trajectories, ) +from .wandb_log_writer import WandbLogWriter __all__ = [ + "LogWriter", + "NeptuneLogWriter", + "WandbLogWriter", "check_nan", "compile_model", "get_param", diff --git a/rsl_rl/utils/log_writer.py b/rsl_rl/utils/log_writer.py new file mode 100644 index 00000000..c713fdbf --- /dev/null +++ b/rsl_rl/utils/log_writer.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021-2026, ETH Zurich and NVIDIA CORPORATION +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import pathlib +from abc import ABC, abstractmethod + + +class LogWriter(ABC): + """Abstract base class for logging backends. + + Log writers are configured via ``cfg["logger"]``, a dict with a ``"class_name"`` key pointing to the subclass and + any additional keys forwarded as constructor kwargs. The class is resolved via + :func:`~rsl_rl.utils.resolve_callable`. Only :meth:`add_scalar` must be implemented; all other methods are no-ops. + """ + + @abstractmethod + def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None: + """Log a scalar metric. + + Args: + tag: Name of the metric. + scalar_value: Value of the metric. + global_step: Current training iteration. + """ + + def store_config(self, env_cfg: dict | object, train_cfg: dict) -> None: + """Upload environment and training configuration. Called once at training start.""" + + def save_model(self, model_path: str, it: int) -> None: + """Upload a model checkpoint.""" + + def save_file(self, path: str) -> None: + """Upload an arbitrary file.""" + + def save_video(self, video: pathlib.Path, it: int) -> None: + """Upload a video file.""" + + def stop(self) -> None: + """Finalize and close the logging run.""" diff --git a/rsl_rl/utils/logger.py b/rsl_rl/utils/logger.py index cf025fd1..b8f073cd 100644 --- a/rsl_rl/utils/logger.py +++ b/rsl_rl/utils/logger.py @@ -12,9 +12,12 @@ import statistics import time import torch +import warnings from collections import deque import rsl_rl +from rsl_rl.utils.log_writer import LogWriter +from rsl_rl.utils.utils import resolve_callable class Logger: @@ -42,6 +45,9 @@ def __init__( self.tot_timesteps = 0 self.tot_time = 0 + self.writer: LogWriter | None = None + self.logger_type: str | None = None + # Create buffers self.ep_extras = [] self.rewbuffer = deque(maxlen=100) @@ -61,38 +67,56 @@ def __init__( self.disable_logs = is_distributed and gpu_global_rank != 0 def 
init_logging_writer(self) -> None: - """Initialize the logging writer, which can be either Tensorboard, W&B or Neptune and save the code state. + """Initialize the logging writer and save the code state. - If the writer is either W&B or Neptune, the configuration and code state are uploaded as well. + .. note:: + The writer is constructed from ``cfg["logger"]``, which should be a dict with a ``"class_name"`` key plus + any additional constructor kwargs (see :class:`~rsl_rl.utils.LogWriter`). The plain string aliases + ``"wandb"`` and ``"neptune"`` are deprecated; use ``"WandbLogWriter"`` and ``"NeptuneLogWriter"`` in the + dict form instead. ``"tensorboard"`` (the default) is still accepted as a plain string. """ if self.log_dir is not None and not self.disable_logs: - self.logger_type = self.cfg.get("logger", "tensorboard") - self.logger_type = self.logger_type.lower() - if self.logger_type == "neptune": - from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter - - self.writer = NeptuneSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg) - elif self.logger_type == "wandb": - from rsl_rl.utils.wandb_utils import WandbSummaryWriter - - self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg) - elif self.logger_type == "tensorboard": + logger_cfg = self.cfg.get("logger", "tensorboard") + self.logger_type = logger_cfg if isinstance(logger_cfg, str) else logger_cfg.pop("class_name") + + # Handle deprecated plain string logger types for W&B and Neptune + if self.logger_type == "wandb" and isinstance(logger_cfg, str): + warnings.warn( + "cfg['logger'] = 'wandb' is deprecated. 
" + "Use cfg['logger'] = {'class_name': 'WandbLogWriter', 'project_name': ...} instead.", + DeprecationWarning, + stacklevel=2, + ) + self.logger_type = "WandbLogWriter" + logger_cfg = {"project_name": self.cfg.get("wandb_project")} + elif self.logger_type == "neptune" and isinstance(logger_cfg, str): + warnings.warn( + "cfg['logger'] = 'neptune' is deprecated. " + "Use cfg['logger'] = {'class_name': 'NeptuneLogWriter', 'project_name': ...} instead.", + DeprecationWarning, + stacklevel=2, + ) + self.logger_type = "NeptuneLogWriter" + logger_cfg = {"project_name": self.cfg.get("neptune_project")} + + if self.logger_type == "tensorboard": from torch.utils.tensorboard import SummaryWriter - self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10) + self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10) # type: ignore else: - raise ValueError("Logger type not found. Please choose 'wandb', 'neptune', or 'tensorboard'.") + writer_class = resolve_callable(self.logger_type) + self.writer = writer_class(log_dir=self.log_dir, **logger_cfg) # type: ignore else: self.writer = None # Save code state files_to_upload = self._store_code_state() - # Upload configuration and code state to external logging service if applicable - if self.writer is not None and self.logger_type in ["wandb", "neptune"]: - self.writer.store_config(self.env_cfg, self.cfg) # type: ignore + # Upload configuration and code state to external logging service if supported + if isinstance(self.writer, LogWriter): + self.writer.store_config(self.env_cfg, self.cfg) for path in files_to_upload: - self.writer.save_file(path) # type: ignore + self.writer.save_file(path) def process_env_step( self, @@ -200,7 +224,7 @@ def log( self.writer.add_scalar("Rnd/weight", rnd_weight, it) # type: ignore self.writer.add_scalar("Train/mean_reward", statistics.mean(self.rewbuffer), it) self.writer.add_scalar("Train/mean_episode_length", statistics.mean(self.lenbuffer), it) - if self.logger_type != "wandb": + if 
self.logger_type != "WandbLogWriter": self.writer.add_scalar( "Train/mean_reward/time", statistics.mean(self.rewbuffer), int(self.tot_time) ) @@ -255,23 +279,23 @@ def log( ) print(log_string) - # Upload available videos - if self.logger_type == "wandb": + # Upload available videos to external logging service if supported + if isinstance(self.writer, LogWriter): for video in pathlib.Path(self.log_dir).rglob("*.mp4"): # type: ignore - self.writer.save_video(video, it) # type: ignore + self.writer.save_video(video, it) # Clear extras buffer self.ep_extras.clear() def save_model(self, path: str, it: int) -> None: - """Save the model to external logging services if specified.""" - if self.writer is not None and self.logger_type in ["neptune", "wandb"]: - self.writer.save_model(path, it) # type: ignore + """Save the model to external logging service if specified.""" + if isinstance(self.writer, LogWriter): + self.writer.save_model(path, it) def stop_logging_writer(self) -> None: """Stop the logging writer.""" - if self.writer is not None and self.logger_type in ["neptune", "wandb"]: - self.writer.stop() # type: ignore + if isinstance(self.writer, LogWriter): + self.writer.stop() def _store_code_state(self) -> list[str]: """Store the current git diff of the code repositories involved in the experiment.""" diff --git a/rsl_rl/utils/neptune_utils.py b/rsl_rl/utils/neptune_log_writer.py similarity index 79% rename from rsl_rl/utils/neptune_utils.py rename to rsl_rl/utils/neptune_log_writer.py index fb23b564..edd1c194 100644 --- a/rsl_rl/utils/neptune_utils.py +++ b/rsl_rl/utils/neptune_log_writer.py @@ -10,27 +10,26 @@ from dataclasses import asdict from torch.utils.tensorboard import SummaryWriter +from rsl_rl.utils.log_writer import LogWriter + try: import neptune except ModuleNotFoundError: - raise ModuleNotFoundError("neptune-client is required to log to Neptune.") from None + neptune = None -class NeptuneSummaryWriter(SummaryWriter): +class 
NeptuneLogWriter(SummaryWriter, LogWriter): """Summary writer for Neptune.""" - def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None: + def __init__(self, log_dir: str, project_name: str) -> None: """Initialize a Neptune run for logging.""" - super().__init__(log_dir, flush_secs=flush_secs) + if neptune is None: + raise ModuleNotFoundError("neptune-client is required to log to Neptune.") + super().__init__(log_dir, flush_secs=10) # Get the run name run_name = os.path.split(log_dir)[-1] - # Get neptune project and entity - try: - project = cfg["neptune_project"] - except KeyError: - raise KeyError("Please specify neptune_project in the runner config, e.g. legged_gym.") from None try: token = os.environ["NEPTUNE_API_TOKEN"] except KeyError: @@ -45,7 +44,7 @@ def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None: ) from None # Initialize neptune - neptune_project = entity + "/" + project + neptune_project = entity + "/" + project_name self.run = neptune.init_run(project=neptune_project, api_token=token) self.run["log_dir"].log(run_name) @@ -55,14 +54,6 @@ def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None: "Train/mean_episode_length/time": "Train/mean_episode_length_time", } - def store_config(self, env_cfg: dict | object, train_cfg: dict) -> None: - """Upload environment and training configuration to Neptune.""" - self.run["train_cfg"] = train_cfg - try: - self.run["env_cfg"] = env_cfg.to_dict() # type: ignore - except Exception: - self.run["env_cfg"] = asdict(env_cfg) # type: ignore - def add_scalar( self, tag: str, @@ -72,18 +63,16 @@ def add_scalar( new_style: bool = False, ) -> None: """Log a scalar to both TensorBoard and Neptune.""" - super().add_scalar( - tag, - scalar_value, - global_step=global_step, - walltime=walltime, - new_style=new_style, - ) + super().add_scalar(tag, scalar_value, global_step=global_step, walltime=walltime, new_style=new_style) self.run[self._map_path(tag)].log(scalar_value, 
step=global_step) - def stop(self) -> None: - """Finish the active Neptune run.""" - self.run.stop() + def store_config(self, env_cfg: dict | object, train_cfg: dict) -> None: + """Upload environment and training configuration to Neptune.""" + self.run["train_cfg"] = train_cfg + try: + self.run["env_cfg"] = env_cfg.to_dict() # type: ignore + except Exception: + self.run["env_cfg"] = asdict(env_cfg) # type: ignore def save_model(self, model_path: str, it: int) -> None: """Upload a model checkpoint artifact to Neptune.""" @@ -94,6 +83,10 @@ def save_file(self, path: str) -> None: name = path.rsplit("/", 1)[-1].split(".")[0] self.run["git_diff/" + name].upload(path) + def stop(self) -> None: + """Finish the active Neptune run.""" + self.run.stop() + def _map_path(self, path: str) -> str: """Map metric names to Neptune-compatible keys.""" if path in self.name_map: diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_log_writer.py similarity index 76% rename from rsl_rl/utils/wandb_utils.py rename to rsl_rl/utils/wandb_log_writer.py index 178eef69..a8c6b551 100644 --- a/rsl_rl/utils/wandb_utils.py +++ b/rsl_rl/utils/wandb_log_writer.py @@ -11,27 +11,26 @@ from dataclasses import asdict from torch.utils.tensorboard import SummaryWriter +from rsl_rl.utils.log_writer import LogWriter + try: import wandb except ModuleNotFoundError: - raise ModuleNotFoundError("wandb package is required to log to Weights and Biases.") from None + wandb = None -class WandbSummaryWriter(SummaryWriter): +class WandbLogWriter(SummaryWriter, LogWriter): """Summary writer for W&B.""" - def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None: + def __init__(self, log_dir: str, project_name: str) -> None: """Initialize a W&B run for logging.""" - super().__init__(log_dir, flush_secs=flush_secs) + if wandb is None: + raise ModuleNotFoundError("wandb package is required to log to Weights and Biases.") + super().__init__(log_dir, flush_secs=10) # Get the run name run_name = 
os.path.split(log_dir)[-1] - # Get wandb project and entity - try: - project = cfg["wandb_project"] - except KeyError: - raise KeyError("Please specify wandb_project in the runner config, e.g. legged_gym.") from None try: entity = os.environ["WANDB_USERNAME"] except KeyError: @@ -39,7 +38,7 @@ def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None: # Initialize wandb wandb.init( - project=project, + project=project_name, entity=entity, name=run_name, config={"log_dir": log_dir}, @@ -49,14 +48,6 @@ def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None: # Initialize set to keep track of logged videos self.logged_videos: set[str] = set() - def store_config(self, env_cfg: dict | object, train_cfg: dict) -> None: - """Upload environment and training configuration to W&B.""" - wandb.config.update({"train_cfg": train_cfg}) - try: - wandb.config.update({"env_cfg": env_cfg.to_dict()}) # type: ignore - except Exception: - wandb.config.update({"env_cfg": asdict(env_cfg)}) # type: ignore - def add_scalar( self, tag: str, @@ -66,18 +57,16 @@ def add_scalar( new_style: bool = False, ) -> None: """Log a scalar to both TensorBoard and W&B.""" - super().add_scalar( - tag, - scalar_value, - global_step=global_step, - walltime=walltime, - new_style=new_style, - ) + super().add_scalar(tag, scalar_value, global_step=global_step, walltime=walltime, new_style=new_style) wandb.log({tag: scalar_value}, step=global_step) - def stop(self) -> None: - """Finish the active W&B run.""" - wandb.finish() + def store_config(self, env_cfg: dict | object, train_cfg: dict) -> None: + """Upload environment and training configuration to W&B.""" + wandb.config.update({"train_cfg": train_cfg}) + try: + wandb.config.update({"env_cfg": env_cfg.to_dict()}) # type: ignore + except Exception: + wandb.config.update({"env_cfg": asdict(env_cfg)}) # type: ignore def save_model(self, model_path: str, it: int) -> None: """Upload a model checkpoint artifact to W&B.""" @@ -92,3 +81,7 @@ 
def save_video(self, video: pathlib.Path, it: int) -> None: if video.name not in self.logged_videos: wandb.log({"video": wandb.Video(str(video), format="mp4")}, step=it) self.logged_videos.add(video.name) + + def stop(self) -> None: + """Finish the active W&B run.""" + wandb.finish() diff --git a/ruff.toml b/ruff.toml index 2fe491b7..08277635 100644 --- a/ruff.toml +++ b/ruff.toml @@ -32,9 +32,10 @@ select = [ "RUF", ] ignore = ["B006", - "B007", - "B028", - "ANN401", + "B007", + "B027", + "B028", + "ANN401", "D100", "D203", "D213",