Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 24 additions & 21 deletions ignite/handlers/tqdm_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,12 @@ class ProgressBar(BaseLogger):
]

def __init__(
self,
persist: bool = False,
bar_format: (
str | None
) = "{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]",
**tqdm_kwargs: Any,
):
self,
persist: bool = False,
bar_format: str | None = "{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]",
show_epoch: bool = True,
**tqdm_kwargs: Any,
):
try:
from tqdm.autonotebook import tqdm
except ImportError:
Expand All @@ -136,6 +135,7 @@ def __init__(
self.pbar = None
self.persist = persist
self.bar_format = bar_format
self.show_epoch = show_epoch
self.tqdm_kwargs = tqdm_kwargs

def _reset(self, pbar_total: int | None) -> None:
Expand Down Expand Up @@ -212,12 +212,13 @@ def attach( # type: ignore[override]
raise ValueError(f"Logging event {event_name} should be called before closing event {closing_event_name}")

log_handler = _OutputHandler(
desc,
metric_names,
output_transform,
closing_event_name=closing_event_name,
state_attributes=state_attributes,
)
desc,
metric_names,
output_transform,
closing_event_name=closing_event_name,
state_attributes=state_attributes,
show_epoch=self.show_epoch,
)

super(ProgressBar, self).attach(engine, log_handler, event_name)
engine.add_event_handler(closing_event_name, self._close)
Expand Down Expand Up @@ -261,20 +262,22 @@ class _OutputHandler(BaseOutputHandler):
"""

def __init__(
self,
description: str,
metric_names: str | list[str] | None = None,
output_transform: Callable | None = None,
closing_event_name: Events | CallableEventWithFilter = Events.EPOCH_COMPLETED,
state_attributes: list[str] | None = None,
):
self,
description: str,
metric_names: str | list[str] | None = None,
output_transform: Callable | None = None,
closing_event_name: Events | CallableEventWithFilter = Events.EPOCH_COMPLETED,
state_attributes: list[str] | None = None,
show_epoch: bool = True,
):
if metric_names is None and output_transform is None:
# This helps to avoid 'Either metric_names or output_transform should be defined' of BaseOutputHandler
metric_names = []
super().__init__(
description, metric_names, output_transform, global_step_transform=None, state_attributes=state_attributes
)
self.closing_event_name = closing_event_name
self.show_epoch = show_epoch

@staticmethod
def get_max_number_events(event_name: str | Events | CallableEventWithFilter, engine: Engine) -> int | None:
Expand All @@ -294,7 +297,7 @@ def __call__(self, engine: Engine, logger: ProgressBar, event_name: str | Events

desc = self.tag or default_desc
max_num_of_closing_events = self.get_max_number_events(self.closing_event_name, engine)
if max_num_of_closing_events and max_num_of_closing_events > 1:
if self.show_epoch and max_num_of_closing_events and max_num_of_closing_events > 1:
global_step = engine.state.get_event_attrib_value(self.closing_event_name)
desc += f" [{global_step}/{max_num_of_closing_events}]"
logger.pbar.set_description(desc) # type: ignore[attr-defined]
Expand Down