Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 194 additions & 2 deletions ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def __init__(
log_scale: bool = False,
logit_scale: bool = False,
digits: int | None = None,
step_size: float | None = None,
is_fidelity: bool = False,
target_value: TParamValue = None,
backfill_value: TParamValue = None,
Expand All @@ -359,6 +360,25 @@ def __init__(
logit_scale: Whether to sample in logit space when drawing
random values of the parameter.
digits: Number of digits to round values to for float type.
Deprecated in favor of ``step_size``; cannot be set together
with ``step_size``.
step_size: If set, the parameter's feasible values are the grid
``{lower + k * step_size : k in N}`` intersected with
``[lower, upper]``. ``cast()`` snaps values to the nearest grid
point (anchored at ``lower``) without clamping to the bounds, so
an out-of-bounds input snaps to an out-of-bounds grid point --
mirroring the non-``step_size`` ``cast()``, which also leaves
out-of-bounds values in place. ``step_size`` must be strictly
positive, and the range must be an exact multiple of it:
``(upper - lower)`` must be an integer multiple of ``step_size``
(within ``EPS``), so that both bounds lie on the grid. For INT
parameters, ``step_size`` must itself be integer-valued.

``step_size`` defines a discrete grid but does not, by itself,
force discrete acquisition optimization. How the optimizer
treats the parameter depends on the grid cardinality
``floor((upper - lower) / step_size) + 1``, and is determined
at the generator level.
is_fidelity: Whether this parameter is a fidelity parameter.
target_value: Target value of this parameter if it is a fidelity.
backfill_value: For parameters added to experiments that have already run
Expand All @@ -378,6 +398,10 @@ def __init__(
raise UserInputError("RangeParameter type must be int or float.")
self._parameter_type = parameter_type
self._digits = digits
# ``_step_size`` must be set before casting ``lower`` / ``upper`` below,
# since ``cast()`` reads it to snap values to the grid.
self._step_size: float | None = None
self._validate_and_set_step_size(step_size=step_size)
self._lower: TNumeric = self.cast(lower)
self._upper: TNumeric = self.cast(upper)
self._log_scale = log_scale
Expand All @@ -393,15 +417,32 @@ def __init__(
self.cast(default_value) if default_value is not None else None
)

# Validate the raw inputs: this rejects invalid user input (e.g. a
# non-integer bound for an INT parameter) before ``cast()`` silently
# truncates it. For the non-deprecated paths ``cast()`` does not move a
# bound that would otherwise pass validation -- FLOAT casting is a no-op
# on the value, and ``step_size`` snapping is skipped for bounds -- so
# validating the raw inputs also guarantees the stored bounds are valid.
self._validate_range_param(
parameter_type=parameter_type,
lower=lower,
upper=upper,
log_scale=log_scale,
logit_scale=logit_scale,
)
# ``upper`` must additionally lie on the ``step_size`` grid (the grid is
# anchored at ``lower``).
self._validate_step_size_on_grid()

def cardinality(self) -> TNumeric:
if self._step_size is not None:
# Values are snapped to the grid {lower + k * step_size}
# intersected with [lower, upper]. Both bounds lie on the grid
# (enforced at construction), so the number of grid points is
# (upper - lower) / step_size + 1.
step_size = none_throws(self._step_size)
return round((float(self.upper) - float(self.lower)) / step_size) + 1

if self.parameter_type == ParameterType.FLOAT:
return inf

Expand Down Expand Up @@ -493,6 +534,19 @@ def digits(self) -> int | None:
"""
return self._digits

@property
def step_size(self) -> float | None:
"""Grid spacing that values are snapped to in ``cast()``.

If set, the parameter's feasible values are the grid
``{lower + k * step_size : k in N}`` intersected with ``[lower, upper]``,
and ``cast()`` snaps values to the nearest grid point (without clamping
to the bounds). Both bounds are guaranteed to be on the grid (the
constructor requires ``(upper - lower)`` to be an integer multiple of
``step_size``). ``None`` means no snapping.
"""
return self._step_size

@property
def log_scale(self) -> bool:
"""Whether the parameter's values should be sampled from log space."""
Expand All @@ -519,14 +573,25 @@ def update_range(
if upper is None:
upper = self._upper

cast_lower = self.cast(lower)
cast_upper = self.cast(upper)
# When ``step_size`` is set, cast the bounds without snapping to the
# (old) grid: bounds anchor the grid and must not be silently moved onto
# it. ``super().cast()`` applies only the type cast. The digits path
# (deprecated) keeps its historical rounding behavior via ``self.cast``.
if self._step_size is not None:
cast_lower = assert_is_instance(super().cast(lower), TNumeric)
cast_upper = assert_is_instance(super().cast(upper), TNumeric)
else:
cast_lower = self.cast(lower)
cast_upper = self.cast(upper)
self._validate_range_param(
lower=cast_lower,
upper=cast_upper,
log_scale=self.log_scale,
logit_scale=self.logit_scale,
)
# The new bounds must lie on the ``step_size`` grid, if one is set.
# Validate before committing so a failed update leaves bounds unchanged.
self._validate_step_size_on_grid(lower=cast_lower, upper=cast_upper)
self._lower = cast_lower
self._upper = cast_upper
return self
Expand All @@ -546,6 +611,95 @@ def set_digits(self, digits: int | None) -> RangeParameter:
self._upper = cast_upper
return self

def set_step_size(self, step_size: float | None) -> RangeParameter:
"""Set the grid spacing that values are snapped to in ``cast()``.

The existing bounds are kept as-is (they anchor the grid and define the
feasible range); they are not snapped onto the new grid. Instead we
require that they already lie on it: ``(upper - lower)`` must be an
integer multiple of the new ``step_size``.

Raises:
UserInputError: If the current bounds do not lie on the new grid.
"""
previous_step_size = self._step_size
self._validate_and_set_step_size(step_size=step_size)
try:
# The current (unchanged) bounds must lie on the new grid.
self._validate_step_size_on_grid()
except UserInputError:
# Leave the parameter unchanged if the new grid is invalid.
self._step_size = previous_step_size
raise
return self

def _validate_and_set_step_size(self, step_size: float | None) -> None:
"""Validate ``step_size`` and store it on ``self._step_size``.

Raises:
UserInputError: If ``step_size`` is non-positive, if it is set
together with ``digits``, or if it is not integer-valued for an
INT parameter.
"""
if step_size is None:
self._step_size = None
return
if self._digits is not None:
raise UserInputError(
f"Cannot set both `digits` and `step_size` on parameter "
f"{self._name}. `digits` is deprecated; use `step_size` only."
)
if step_size <= 0:
raise UserInputError(
f"`step_size` must be strictly positive for parameter "
f"{self._name}. Got: {step_size}."
)
if (
self._parameter_type is ParameterType.INT
and not float(step_size).is_integer()
):
raise UserInputError(
f"`step_size` must be integer-valued for INT parameter "
f"{self._name}. Got: {step_size}."
)
self._step_size = float(step_size)

def _validate_step_size_on_grid(
self, lower: TNumeric | None = None, upper: TNumeric | None = None
) -> None:
"""Validate that both bounds lie on the ``step_size`` grid.

The grid is anchored at ``lower``, so ``lower`` is always on it. This
additionally requires ``upper`` to be on the grid, i.e. that
``(upper - lower)`` is an integer multiple of ``step_size`` (within
``EPS``). This guarantees ``upper`` is itself a feasible value, so a
value near the upper bound snaps to ``upper`` rather than to a grid
point short of it.

Args:
lower: Lower bound to validate against. Defaults to ``self._lower``.
upper: Upper bound to validate against. Defaults to ``self._upper``.
These overrides let callers validate prospective bounds before
committing them.

Raises:
UserInputError: If ``upper`` does not lie on the grid.
"""
if self._step_size is None:
return
lower = self._lower if lower is None else lower
upper = self._upper if upper is None else upper
step_size = none_throws(self._step_size)
width = float(upper) - float(lower)
n = width / step_size
if abs(n - round(n)) * step_size > EPS:
raise UserInputError(
f"`step_size` must evenly divide the range of parameter "
f"{self._name}: (upper - lower) = {width} is not an integer "
f"multiple of step_size = {step_size}. Adjust the bounds or "
f"step_size so that both bounds lie on the grid."
)

def set_log_scale(self, log_scale: bool) -> RangeParameter:
self._log_scale = log_scale
return self
Expand Down Expand Up @@ -647,6 +801,7 @@ def clone(self) -> RangeParameter:
log_scale=self._log_scale,
logit_scale=self._logit_scale,
digits=self._digits,
step_size=self._step_size,
is_fidelity=self._is_fidelity,
target_value=self._target_value,
backfill_value=self._backfill_value,
Expand All @@ -657,13 +812,50 @@ def cast(self, value: TParamValue) -> TNumeric:
value = super().cast(value=value)
if self.parameter_type is ParameterType.FLOAT and self._digits is not None:
return round(float(value), none_throws(self._digits))
# Skip snapping while the constructor is still casting the bounds
# themselves (before both ``self._lower`` and ``self._upper`` are set):
# the bounds anchor the grid and must not be snapped (``upper`` is only
# validated to be on the grid after both are assigned). ``_snap_to_grid``
# needs ``self._lower``; gating on ``self._upper`` too is what excludes
# the ``upper`` cast at construction.
if (
self._step_size is not None
and getattr(self, "_lower", None) is not None
and getattr(self, "_upper", None) is not None
):
value = self._snap_to_grid(value=float(value))
return assert_is_instance(value, TNumeric)

def _snap_to_grid(self, value: float) -> TNumeric:
"""Snap ``value`` to the nearest grid point.

The grid is ``{lower + k * step_size : k in Z}``. The nearest grid point
is found by rounding ``(value - lower) / step_size`` to the nearest
integer. The result is *not* clamped to ``[lower, upper]``: an
out-of-bounds input (e.g. historical observations recorded outside the
current bounds) snaps to the nearest grid point, which may itself lie
outside the bounds. This mirrors the non-``step_size`` ``cast()``, which
leaves out-of-bounds values untouched rather than silently moving them
into range -- range validity is enforced by ``validate()``, not by
``cast()``. For INT parameters the snapped value is integer-valued
(``step_size`` is validated to be an integer), so it is returned as an
``int``.
"""
step_size = none_throws(self._step_size)
lower = float(self._lower)
n = round((value - lower) / step_size)
snapped = lower + n * step_size
if self.parameter_type is ParameterType.INT:
return int(round(snapped))
return snapped

def __repr__(self) -> str:
ret_val = self._base_repr()

if self._digits is not None:
ret_val += f", digits={self._digits}"
if self._step_size is not None:
ret_val += f", step_size={self._step_size}"

return ret_val + ")"

Expand Down
Loading
Loading