Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,10 @@ def mark_staged(self, unsafe: bool = False) -> Self:
return self

def mark_running(
self, no_runner_required: bool = False, unsafe: bool = False
self,
no_runner_required: bool = False,
unsafe: bool = False,
started_time: datetime | None = None,
) -> Self:
"""Mark the trial as having started running.

Expand All @@ -566,6 +569,10 @@ def mark_running(
no_runner_required: Whether to skip the check for presence of a
``Runner`` on the experiment.
unsafe: Ignore sanity checks on state transitions.
started_time: When the trial actually started running. Defaults
to ``datetime.now()`` if not provided. Useful for runners
that confirm deployment asynchronously and know the real
start time.

Returns:
The trial instance.
Expand All @@ -586,7 +593,9 @@ def mark_running(
f"Can only mark this trial as running when {prev_step_str}."
)
self._status = TrialStatus.RUNNING
self._time_run_started = datetime.now()
self._time_run_started = (
started_time if started_time is not None else datetime.now()
)
return self

def mark_completed(
Expand Down Expand Up @@ -749,7 +758,12 @@ def mark_as(self, status: TrialStatus, unsafe: bool = False, **kwargs: Any) -> S
self.mark_staged(unsafe=unsafe)
elif status == TrialStatus.RUNNING:
no_runner_required = kwargs.get("no_runner_required", False)
self.mark_running(no_runner_required=no_runner_required, unsafe=unsafe)
started_time = kwargs.get("started_time")
self.mark_running(
no_runner_required=no_runner_required,
unsafe=unsafe,
started_time=started_time,
)
elif status == TrialStatus.ABANDONED:
self.mark_abandoned(reason=kwargs.get("reason"), unsafe=unsafe)
elif status == TrialStatus.FAILED:
Expand Down
12 changes: 12 additions & 0 deletions ax/core/tests/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,18 @@ def test_update_trial_status_on_clone(self) -> None:
test_trial._time_run_started = self.trial._time_run_started
self.assertEqual(self.trial, test_trial)

def test_mark_running_custom_started_time(self) -> None:
    """An explicit ``started_time`` is kept as the trial's run start time."""
    expected_start = datetime(
        year=2026, month=5, day=30, hour=4, minute=54, second=38
    )
    self.trial.mark_running(started_time=expected_start, no_runner_required=True)
    # The status flips to RUNNING and the supplied timestamp is stored as-is.
    self.assertEqual(self.trial.status, TrialStatus.RUNNING)
    self.assertEqual(self.trial.time_run_started, expected_start)

def test_mark_running_default_started_time(self) -> None:
    """Without ``started_time``, a ``datetime`` start time is still recorded."""
    trial = self.trial
    trial.mark_running(no_runner_required=True)
    self.assertEqual(trial.status, TrialStatus.RUNNING)
    # Only presence and type are checked; the exact value is not pinned here.
    self.assertIsNotNone(trial.time_run_started)
    self.assertIsInstance(trial.time_run_started, datetime)

def test_mark_complete_custom_date(self) -> None:
self.trial.mark_running(no_runner_required=True)
with self.subTest("custom completion time"):
Expand Down
43 changes: 39 additions & 4 deletions ax/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,10 +1437,18 @@ def _apply_trial_statuses(
"""
updated_trial_indices = set()
for status, trial_idcs in polled_status_to_trial_idcs.items():
if status.is_candidate or status.is_deployed:
# No need to consider candidate, staged or running trials here (none of
# these trials should actually be candidates, but we can filter on that)
if status.is_candidate:
continue
if status.is_deployed:
# Only process trials whose status actually changed
# (e.g. STAGED -> RUNNING), skip no-ops.
trial_idcs = {
idx
for idx in trial_idcs
if self.experiment.trials[idx].status != status
}
if not trial_idcs:
continue

if len(trial_idcs) > 0:
idcs = make_indices_str(indices=trial_idcs)
Expand All @@ -1459,7 +1467,34 @@ def _apply_trial_statuses(
# we fall back to marking the trial without a reason.
trial.mark_as(status=status, unsafe=True)
else:
trial.mark_as(status=status, unsafe=True)
kwargs: dict[str, datetime] = {}
if status == TrialStatus.RUNNING:
started_time_raw = trial.run_metadata.get(Keys.START_TIME_STR)
if isinstance(started_time_raw, str):
try:
kwargs["started_time"] = datetime.strptime(
started_time_raw, "%Y-%m-%d %H:%M:%S"
)
except ValueError:
self.logger.warning(
f"Could not parse {Keys.START_TIME_STR} "
f"{started_time_raw!r} for trial "
f"{trial.index}; defaulting to now."
)
elif isinstance(started_time_raw, datetime):
kwargs["started_time"] = started_time_raw
elif isinstance(started_time_raw, (int, float)):
kwargs["started_time"] = datetime.fromtimestamp(
started_time_raw
)
elif started_time_raw is not None:
self.logger.warning(
f"Unexpected type for {Keys.START_TIME_STR} "
f"in trial {trial.index} run_metadata: "
f"{type(started_time_raw).__name__}; "
f"defaulting to now."
)
trial.mark_as(status=status, unsafe=True, **kwargs)
return updated_trial_indices

def _identify_trial_indices_to_fetch(
Expand Down
115 changes: 114 additions & 1 deletion ax/orchestration/tests/test_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import re
import time
from collections.abc import Callable, Iterable
from datetime import timedelta
from datetime import datetime, timedelta
from math import ceil
from tempfile import NamedTemporaryFile
from typing import Any, cast
Expand Down Expand Up @@ -1686,6 +1686,119 @@ def test_poll_and_process_results_with_reasons(self) -> None:
)
self.assertIsNone(orchestrator.experiment.trials[completed_idx].status_reason)

def test_apply_trial_statuses_staged_to_running(self) -> None:
    """Check that a polled RUNNING status moves a STAGED trial to RUNNING
    and that the start-time string stored in run_metadata under
    ``Keys.START_TIME_STR`` becomes the trial's ``time_run_started``."""
    generation_strategy = self.sobol_GS_no_parallelism
    experiment = self.branin_experiment
    options = OrchestratorOptions(
        init_seconds_between_polls=0,
        **self.orchestrator_options_kwargs,
    )
    orchestrator = Orchestrator(
        experiment=experiment,
        generation_strategy=generation_strategy,
        options=options,
        db_settings=self.db_settings_if_always_needed,
    )

    # Create one trial and put it into the STAGED state.
    generator_runs = generation_strategy.gen(experiment=orchestrator.experiment)
    trial = orchestrator.experiment.new_trial(generator_run=generator_runs[0][0])
    trial.mark_staged(unsafe=True)
    self.assertEqual(trial.status, TrialStatus.STAGED)

    # Record the start time in run_metadata as a formatted string.
    start_time_str = "2026-05-30 04:54:38"
    expected_start = datetime(2026, 5, 30, 4, 54, 38)
    trial.update_run_metadata({Keys.START_TIME_STR: start_time_str})

    polled_statuses = {TrialStatus.RUNNING: {trial.index}}
    updated_indices = orchestrator._apply_trial_statuses(polled_statuses)

    self.assertIn(trial.index, updated_indices)
    self.assertEqual(trial.status, TrialStatus.RUNNING)
    self.assertEqual(trial.time_run_started, expected_start)

def test_apply_trial_statuses_skips_running_noop(self) -> None:
    """A polled RUNNING status for an already-RUNNING trial is a no-op:
    no index is reported as updated and the start time is left untouched."""
    generation_strategy = self.sobol_GS_no_parallelism
    experiment = self.branin_experiment
    options = OrchestratorOptions(
        init_seconds_between_polls=0,
        **self.orchestrator_options_kwargs,
    )
    orchestrator = Orchestrator(
        experiment=experiment,
        generation_strategy=generation_strategy,
        options=options,
        db_settings=self.db_settings_if_always_needed,
    )

    # Create one trial that is already RUNNING and remember its start time.
    generator_runs = generation_strategy.gen(experiment=orchestrator.experiment)
    trial = orchestrator.experiment.new_trial(generator_run=generator_runs[0][0])
    trial.mark_running(no_runner_required=True)
    start_before_poll = trial.time_run_started

    polled_statuses = {TrialStatus.RUNNING: {trial.index}}
    updated_indices = orchestrator._apply_trial_statuses(polled_statuses)

    self.assertEqual(len(updated_indices), 0)
    self.assertEqual(trial.time_run_started, start_before_poll)

def test_apply_trial_statuses_staged_to_running_no_time_started(self) -> None:
    """When run_metadata carries no start time, a STAGED->RUNNING transition
    records a start time that falls within the duration of the call."""
    generation_strategy = self.sobol_GS_no_parallelism
    experiment = self.branin_experiment
    options = OrchestratorOptions(
        init_seconds_between_polls=0,
        **self.orchestrator_options_kwargs,
    )
    orchestrator = Orchestrator(
        experiment=experiment,
        generation_strategy=generation_strategy,
        options=options,
        db_settings=self.db_settings_if_always_needed,
    )

    # Create one STAGED trial; nothing is written to its run_metadata.
    generator_runs = generation_strategy.gen(experiment=orchestrator.experiment)
    trial = orchestrator.experiment.new_trial(generator_run=generator_runs[0][0])
    trial.mark_staged(unsafe=True)

    # Bracket the call with wall-clock readings to bound the recorded time.
    lower_bound = datetime.now()
    polled_statuses = {TrialStatus.RUNNING: {trial.index}}
    updated_indices = orchestrator._apply_trial_statuses(polled_statuses)
    upper_bound = datetime.now()

    self.assertIn(trial.index, updated_indices)
    self.assertEqual(trial.status, TrialStatus.RUNNING)
    recorded_start = trial.time_run_started
    self.assertIsNotNone(recorded_start)
    # Plain assert kept so static type checkers narrow away ``None``.
    assert recorded_start is not None
    self.assertGreaterEqual(recorded_start, lower_bound)
    self.assertLessEqual(recorded_start, upper_bound)

def test_apply_trial_statuses_staged_to_running_int_timestamp(self) -> None:
    """An integer start time in run_metadata is treated as epoch seconds
    and converted with ``datetime.fromtimestamp`` on STAGED->RUNNING."""
    generation_strategy = self.sobol_GS_no_parallelism
    experiment = self.branin_experiment
    options = OrchestratorOptions(
        init_seconds_between_polls=0,
        **self.orchestrator_options_kwargs,
    )
    orchestrator = Orchestrator(
        experiment=experiment,
        generation_strategy=generation_strategy,
        options=options,
        db_settings=self.db_settings_if_always_needed,
    )

    # Create one STAGED trial whose run_metadata holds an integer timestamp.
    generator_runs = generation_strategy.gen(experiment=orchestrator.experiment)
    trial = orchestrator.experiment.new_trial(generator_run=generator_runs[0][0])
    trial.mark_staged(unsafe=True)

    epoch_seconds = 1748580878
    trial.update_run_metadata({Keys.START_TIME_STR: epoch_seconds})

    polled_statuses = {TrialStatus.RUNNING: {trial.index}}
    updated_indices = orchestrator._apply_trial_statuses(polled_statuses)

    self.assertIn(trial.index, updated_indices)
    self.assertEqual(trial.status, TrialStatus.RUNNING)
    # Expected value uses the same local-time conversion as the assertion
    # in the original test.
    expected_start = datetime.fromtimestamp(epoch_seconds)
    self.assertEqual(trial.time_run_started, expected_start)

def test_poll_trial_status_fallback_to_individual_polling(self) -> None:
"""Test that poll_trial_status falls back to individual polling when
batch polling fails, and successfully completes trials."""
Expand Down
1 change: 1 addition & 0 deletions ax/utils/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class Keys(StrEnum):
TRIAL_COMPLETION_TIMESTAMP = "trial_completion_timestamp"
UNKNOWN_GENERATION_NODE = "unknown_gen_node"
UNNAMED_ARM = "unnamed_arm"
START_TIME_STR = "start_time"
WARM_START_REFITTING = "warm_start_refitting"
WARMSTART_TRIAL_MODEL_KEY = "generation_model_key"
X_BASELINE = "X_baseline"
Expand Down
Loading