diff --git a/conf/train/default.yaml b/conf/train/default.yaml index 6b37290..26126c0 100644 --- a/conf/train/default.yaml +++ b/conf/train/default.yaml @@ -51,6 +51,7 @@ logging: logger: _target_: pytorch_lightning.loggers.WandbLogger + name: null project: ${core.project_name} entity: null log_model: ${..upload.checkpoint} diff --git a/src/nn_template/data/datamodule.py b/src/nn_template/data/datamodule.py index e155580..a5d61e5 100644 --- a/src/nn_template/data/datamodule.py +++ b/src/nn_template/data/datamodule.py @@ -33,6 +33,10 @@ def __init__( # example self.val_percentage: float = val_percentage + @property + def name(self) -> str: + return "mnist_data" + def prepare_data(self) -> None: # download only pass diff --git a/src/nn_template/pl_modules/pl_module.py b/src/nn_template/pl_modules/pl_module.py index da6a2fe..9b39cc3 100644 --- a/src/nn_template/pl_modules/pl_module.py +++ b/src/nn_template/pl_modules/pl_module.py @@ -28,6 +28,10 @@ def __init__(self, *args, **kwargs) -> None: self.val_accuracy = metric.clone() self.test_accuracy = metric.clone() + @property + def name(self) -> str: + return "simple_cnn" + def forward(self, x: torch.Tensor) -> torch.Tensor: """Method for the forward pass. diff --git a/src/nn_template/run.py b/src/nn_template/run.py index 662eacb..eec825a 100644 --- a/src/nn_template/run.py +++ b/src/nn_template/run.py @@ -114,6 +114,7 @@ def run(cfg: DictConfig) -> str: storage_dir: str = cfg.core.storage_dir + cfg.train.logging.logger.name = f"{datamodule.name}-{model.name}" logger: NNLogger = NNLogger(logging_cfg=cfg.train.logging, cfg=cfg, resume_id=resume_run_version) pylogger.info("Instantiating the ")