Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/pytorch/loggers.info
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# all supported loggers. this list is here as a reference, but they are not installed in CI

litlogger >= 0.1.7
litlogger >= 2026.03.17
neptune >=1.0.0
comet-ml >=3.31.0
mlflow >=1.0.0
Expand Down
150 changes: 74 additions & 76 deletions src/lightning/pytorch/loggers/litlogger.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from argparse import Namespace
from collections.abc import Mapping
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union, cast

from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
Expand Down Expand Up @@ -108,6 +108,7 @@ def training_step(self, batch, batch_idx: int):
self._sub_dir = None
self._prefix = ""
self._fs = get_filesystem(self._root_dir)
self._experiment: Optional[Experiment] = None
self._step = -1
self._metadata = metadata or {}
self._is_ready = False
Expand Down Expand Up @@ -163,14 +164,40 @@ def sub_dir(self) -> Optional[str]:
"""Gets the sub directory where the TensorBoard experiments are saved."""
return self._sub_dir

@property
def _experiment_name(self) -> str:
if self.version is None:
return self.name
return f"{self.name}-{self.version}"

@staticmethod
def _default_artifact_key(path: str) -> str:
try:
rel = os.path.relpath(path)
except ValueError:
rel = None
key = rel if rel is not None and not rel.startswith("..") else os.path.basename(path)
return key.replace("\\", "/")

def _model_key(self) -> str:
return self._experiment_name

@staticmethod
def _model_version(version: Optional[str], step: Optional[int]) -> Optional[str]:
if version is not None:
return version
if step is not None and step >= 0:
return str(step)
return None

@property
@rank_zero_experiment
def experiment(self) -> Optional["Experiment"]:
"""Returns the underlying litlogger Experiment object."""
import litlogger

if litlogger.experiment is not None:
return litlogger.experiment
if self._experiment is not None:
return self._experiment

if not self._is_ready:
self._is_ready = True
Expand All @@ -182,24 +209,31 @@ def experiment(self) -> Optional["Experiment"]:
if self.version is None:
# Generate version as proper RFC 3339 timestamp with Z suffix (required by protobuf)
timestamp = datetime.now(timezone.utc).isoformat(timespec="milliseconds")
self._version = timestamp.replace("+00:00", "Z")
self._version = timestamp.replace(":", "-").replace("+00:00", "Z")

litlogger.init(
name=self._name,
root_dir=self._root_dir,
self._experiment = litlogger.Experiment(
name=self._experiment_name,
teamspace=self._teamspace,
metadata={k: str(v) for k, v in self._metadata.items()},
store_step=True,
store_created_at=True,
log_dir=self.log_dir,
save_logs=self._save_logs,
)
self._experiment.print_url()

return litlogger.experiment
return self._experiment

def _require_experiment(self) -> "Experiment":
experiment = self.experiment
if experiment is None:
raise RuntimeError("Experiment is not initialized")
return experiment

@property
@rank_zero_only
def url(self) -> str:
return self.experiment.url
return self._require_experiment().url

# ──────────────────────────────────────────────────────────────────────────────
# Override methods from Logger
Expand All @@ -208,18 +242,17 @@ def url(self) -> str:
@override
@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
import litlogger

assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

# Ensure experiment is initialized
_ = self.experiment
experiment = self._require_experiment()

self._step = self._step + 1 if step is None else step

metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
metrics = {k: v.item() if isinstance(v, Tensor) else v for k, v in metrics.items()}
litlogger.log_metrics(metrics, step=self._step)
for key, value in metrics.items():
experiment[key].append(value, step=self._step)

@override
@rank_zero_only
Expand All @@ -231,8 +264,9 @@ def log_hyperparams(
"""Log hyperparams."""
if isinstance(params, Namespace):
params = params.__dict__
params.update(self._metadata or {})
self._metadata = params
experiment = self._require_experiment()
for key, value in params.items():
experiment[key] = str(value)

@override
@rank_zero_only
Expand All @@ -247,13 +281,11 @@ def save(self) -> None:
@override
@rank_zero_only
def finalize(self, status: Optional[str] = None) -> None:
import litlogger

if litlogger.experiment is not None:
if self._experiment is not None:
# log checkpoints as artifacts before finalizing
if self._checkpoint_callback:
self._scan_and_log_checkpoints(self._checkpoint_callback)
litlogger.finalize(status)
self._experiment.finalize(status)

# ──────────────────────────────────────────────────────────────────────────────
# Public methods
Expand All @@ -267,8 +299,9 @@ def log_metadata(
"""Log hyperparams."""
if isinstance(params, Namespace):
params = params.__dict__
params.update(self._metadata or {})
self._metadata = params
experiment = self._require_experiment()
for key, value in params.items():
experiment[key] = str(value)

@rank_zero_only
def log_model(
Expand All @@ -289,10 +322,14 @@ def log_model(
metadata: Optional metadata dictionary to store with the model.

"""
import litlogger
from litlogger import Model

_ = self.experiment
litlogger.log_model(model, staging_dir, verbose, version, metadata)
self._require_experiment()[self._model_key()] = Model(
model,
version=self._model_version(version, self._step),
metadata=cast(Optional[dict[str, str]], metadata),
staging_dir=staging_dir,
)

@rank_zero_only
def log_model_artifact(
Expand All @@ -309,10 +346,9 @@ def log_model_artifact(
version: Optional version string for the model. Defaults to the experiment version.

"""
import litlogger
from litlogger import Model

_ = self.experiment
litlogger.log_model_artifact(path, verbose, version)
self._require_experiment()[self._model_key()] = Model(path, version=self._model_version(version, self._step))

@rank_zero_only
def get_file(self, path: str, verbose: bool = True) -> str:
Expand All @@ -326,46 +362,8 @@ def get_file(self, path: str, verbose: bool = True) -> str:
str: The local path where the file was saved.

"""
import litlogger

_ = self.experiment
return litlogger.get_file(path, verbose=verbose)

@rank_zero_only
def get_model(self, staging_dir: Optional[str] = None, verbose: bool = False, version: Optional[str] = None) -> Any:
"""Download and load a model object using litmodels.

Args:
staging_dir: Optional directory where the model will be downloaded.
verbose: Whether to show progress bar.
version: Optional version string for the model.

Returns:
The loaded model object.

"""
import litlogger

_ = self.experiment
return litlogger.get_model(staging_dir, verbose, version)

@rank_zero_only
def get_model_artifact(self, path: str, verbose: bool = False, version: Optional[str] = None) -> str:
"""Download a model artifact file or directory from cloud storage using litmodels.

Args:
path: Path where the model should be saved locally.
verbose: Whether to show progress bar during download.
version: Optional version string for the model.

Returns:
str: The local path where the model was saved.

"""
import litlogger

_ = self.experiment
return litlogger.get_model_artifact(path, verbose, version)
file = cast(Any, self._require_experiment()[self._default_artifact_key(path)])
return file.save(path)

@rank_zero_only
def log_file(self, path: str) -> None:
Expand All @@ -382,10 +380,9 @@ def log_file(self, path: str) -> None:
logger.log_file('config.yaml')

"""
import litlogger
from litlogger import File

_ = self.experiment
litlogger.log_file(path)
self._require_experiment()[self._default_artifact_key(path)] = File(path)

# ──────────────────────────────────────────────────────────────────────────────
# Callback methods
Expand All @@ -412,11 +409,12 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
"""Find new checkpoints from the callback and log them as model artifacts."""
checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

for timestamp, path_ckpt, _score, tag in checkpoints:
if not self._checkpoint_name:
self._checkpoint_name = self.experiment.name
# Ensure the version tag is unique by appending a timestamp
unique_tag = f"{tag}-{int(datetime.now(timezone.utc).timestamp())}"
self.log_model_artifact(path_ckpt, verbose=True, version=unique_tag)
for timestamp, path_ckpt, _score, _tag in checkpoints:
experiment = self._require_experiment()
checkpoint_key = self._checkpoint_name or experiment.name
checkpoint_step = getattr(checkpoint_callback, "_last_global_step_saved", None)
from litlogger import Model

experiment[checkpoint_key] = Model(path_ckpt, version=self._model_version(None, checkpoint_step))
# remember logged models - timestamp needed in case filename didn't change
self._logged_model_time[path_ckpt] = timestamp
29 changes: 14 additions & 15 deletions tests/tests_pytorch/loggers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,24 +164,23 @@ def litlogger_mock(monkeypatch):
experiment_mock.url = "https://lightning.ai/test/experiments/test-experiment"
experiment_mock.name = "test-experiment"
experiment_mock.version = "2024-01-01T00:00:00.000Z"
experiment_mock.get_file.return_value = "/path/to/file"
experiment_mock.get_model.return_value = MagicMock()
experiment_mock.get_model_artifact.return_value = "/path/to/artifact"
experiment_mock.series_mocks = {}

def get_series(key):
if key not in experiment_mock.series_mocks:
experiment_mock.series_mocks[key] = MagicMock()
return experiment_mock.series_mocks[key]

experiment_mock.__getitem__.side_effect = get_series

litlogger = ModuleType("litlogger")
litlogger.experiment = None
litlogger.Experiment = MagicMock

def mock_init(**kwargs):
litlogger.experiment = experiment_mock
return experiment_mock

litlogger.init = Mock(side_effect=mock_init)
litlogger.log_metrics = Mock()
litlogger.log_file = Mock()
litlogger.get_file = Mock(return_value="/path/to/file")
litlogger.log_model = Mock()
litlogger.get_model = Mock(return_value=MagicMock())
litlogger.log_model_artifact = Mock()
litlogger.get_model_artifact = Mock(return_value="/path/to/artifact")
litlogger.finalize = Mock()
litlogger.Experiment = Mock(return_value=experiment_mock)
litlogger.File = Mock()
litlogger.Model = Mock()
monkeypatch.setitem(sys.modules, "litlogger", litlogger)

# Create generator submodule
Expand Down
Loading
Loading