diff --git a/CHANGELOG.md b/CHANGELOG.md index 680f11da28..c7715778a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate` - `TypeSelector` and `NameSelector` classes for parameter selection in kernel factories - `parameter_names` attribute to basic kernels for controlling the considered parameters +- `ParameterKind` flag enum for classifying parameters by their role and automatic + parameter kind validation in kernel factories - `IndexKernel` and `PositiveIndexKernel` classes - Interpoint constraints for continuous search spaces - `IndexKernel` and `PositiveIndexKernel` classes diff --git a/baybe/parameters/base.py b/baybe/parameters/base.py index 2d4df2bc77..8fbd489a41 100644 --- a/baybe/parameters/base.py +++ b/baybe/parameters/base.py @@ -21,6 +21,7 @@ from baybe.utils.metadata import MeasurableMetadata, to_metadata if TYPE_CHECKING: + from baybe.parameters.enum import _ParameterKind from baybe.searchspace.continuous import SubspaceContinuous from baybe.searchspace.core import SearchSpace from baybe.searchspace.discrete import SubspaceDiscrete @@ -77,6 +78,13 @@ def is_discrete(self) -> bool: """Boolean indicating if this is a discrete parameter.""" return isinstance(self, DiscreteParameter) + @property + def _kind(self) -> _ParameterKind: + """The kind of the parameter.""" + from baybe.parameters.enum import _ParameterKind + + return _ParameterKind.from_parameter(self) + @property @abstractmethod def comp_rep_columns(self) -> tuple[str, ...]: diff --git a/baybe/parameters/enum.py b/baybe/parameters/enum.py index 3161f67dc7..622fa4af58 100644 --- a/baybe/parameters/enum.py +++ b/baybe/parameters/enum.py @@ -1,6 +1,38 @@ """Parameter-related enumerations.""" -from enum import Enum +from __future__ import annotations + +from enum import Enum, Flag, auto +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from baybe.parameters.base import Parameter + + +class _ParameterKind(Flag): + """Flag enum encoding the kind of a parameter. + + Can be used to express compatibility (e.g. Gaussian process kernel factories) + with different parameter types via bitwise combination of flags. + """ + + REGULAR = auto() + """Regular parameter undergoing no special treatment.""" + + TASK = auto() + """Task parameter for transfer learning.""" + + FIDELITY = auto() + """Fidelity parameter for multi-fidelity modelling.""" + + @staticmethod + def from_parameter(parameter: Parameter) -> _ParameterKind: + """Determine the kind of a parameter from its type.""" + from baybe.parameters.categorical import TaskParameter + + if isinstance(parameter, TaskParameter): + return _ParameterKind.TASK + return _ParameterKind.REGULAR class ParameterEncoding(Enum): diff --git a/baybe/parameters/selectors.py b/baybe/parameters/selectors.py index 6b1a4ae16d..e72b6fe4d9 100644 --- a/baybe/parameters/selectors.py +++ b/baybe/parameters/selectors.py @@ -3,15 +3,13 @@ import re from abc import ABC, abstractmethod from collections.abc import Collection -from typing import ClassVar, Protocol +from typing import Protocol from attrs import Converter, define, field -from attrs.converters import optional from attrs.validators import deep_iterable, instance_of, min_len from typing_extensions import override from baybe.parameters.base import Parameter -from baybe.searchspace.core import SearchSpace from baybe.utils.basic import to_tuple from baybe.utils.conversion import nonstring_to_tuple @@ -131,37 +129,3 @@ def to_parameter_selector( return TypeSelector(items) raise TypeError(f"Cannot convert {x!r} to a parameter selector.") - - -@define -class _ParameterSelectorMixin: - """A mixin class to enable parameter selection.""" - - # For internal use only: sanity check mechanism to remind developers of new - # subclasses to actually use the parameter selector when it is provided - # TODO: Perhaps we can find a more elegant way to enforce this by design - _uses_parameter_names: ClassVar[bool] = False - - parameter_selector: ParameterSelectorProtocol | None = field( - default=None, converter=optional(to_parameter_selector), kw_only=True - ) - """An optional selector to specify which parameters are to be considered.""" - - def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...] | None: - """Get the names of the parameters to be considered.""" - if self.parameter_selector is None: - return None - - return tuple( - p.name for p in searchspace.parameters if self.parameter_selector(p) - ) - - def __attrs_post_init__(self): - if self.parameter_selector is not None and not self._uses_parameter_names: - raise AssertionError( - f"A `parameter_selector` was provided to " - f"`{type(self).__name__}`, but the class does not set " - f"`_uses_parameter_names = True`. Subclasses that accept a " - f"parameter selector must explicitly set this flag to confirm " - f"they actually use the selected parameter names." - ) diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index ca4da066e7..8b6ddff280 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -2,18 +2,24 @@ from __future__ import annotations +from abc import ABC, abstractmethod +from collections.abc import Iterable from functools import partial -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from attrs import define, field +from attrs.converters import optional from attrs.validators import is_callable from typing_extensions import override +from baybe.exceptions import IncompatibleSearchSpaceError from baybe.kernels.base import Kernel -from baybe.kernels.composite import ProductKernel from baybe.parameters.categorical import TaskParameter +from baybe.parameters.enum import _ParameterKind from baybe.parameters.selectors import ( + ParameterSelectorProtocol, TypeSelector, + to_parameter_selector, ) from baybe.searchspace.core import SearchSpace from baybe.surrogates.gaussian_process.components.generic import ( @@ -27,6 +33,8 @@ from gpytorch.kernels import Kernel as GPyTorchKernel from torch import Tensor + from baybe.parameters.base import Parameter + KernelFactoryProtocol = GPComponentFactoryProtocol[Kernel | GPyTorchKernel] PlainKernelFactory = PlainGPComponentFactory[Kernel | GPyTorchKernel] else: @@ -36,7 +44,94 @@ @define -class ICMKernelFactory(KernelFactoryProtocol): +class _PureKernelFactory(KernelFactoryProtocol, ABC): + """Base class for pure kernel factories.""" + + # For internal use only: sanity check mechanism to remind developers of new + # factories to actually use the parameter selector when it is provided + # TODO: Perhaps we can find a more elegant way to enforce this by design + _uses_parameter_names: ClassVar[bool] = False + + _supported_parameter_kinds: ClassVar[_ParameterKind] = _ParameterKind.REGULAR + """The parameter kinds supported by the kernel factory.""" + + parameter_selector: ParameterSelectorProtocol | None = field( + default=None, converter=optional(to_parameter_selector) + ) + """An optional selector to specify which parameters are considered by the kernel.""" + + def __attrs_post_init__(self): + if self.parameter_selector is not None and not self._uses_parameter_names: + raise AssertionError( + f"A `parameter_selector` was provided to " + f"`{type(self).__name__}`, but the class does not set " + f"`_uses_parameter_names = True`. Subclasses that accept a " + f"parameter selector must explicitly set this flag to confirm " + f"they actually use the selected parameter names." + ) + + def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...] | None: + """Get the names of the parameters to be considered by the kernel.""" + if self.parameter_selector is None: + return None + + return tuple( + p.name for p in searchspace.parameters if self.parameter_selector(p) + ) + + def _validate_parameter_kinds(self, parameters: Iterable[Parameter]) -> None: + """Validate that the given parameters are supported by the factory. + + Args: + parameters: The parameters to validate. + + Raises: + IncompatibleSearchSpaceError: If unsupported parameter kinds are found. + """ + if unsupported := [ + p.name + for p in parameters + if not (p._kind & self._supported_parameter_kinds) + ]: + raise IncompatibleSearchSpaceError( + f"'{type(self).__name__}' does not support parameter kind(s) for " + f"parameter(s) {unsupported}. Supported kinds: " + f"{self._supported_parameter_kinds}." + ) + + @override + def __call__( + self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor + ) -> Kernel: + """Construct the kernel, validating parameter kinds before construction.""" + if self.parameter_selector is not None: + params = [p for p in searchspace.parameters if self.parameter_selector(p)] + else: + params = list(searchspace.parameters) + self._validate_parameter_kinds(params) + + return self._make(searchspace, train_x, train_y) + + @abstractmethod + def _make( + self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor + ) -> Kernel: + """Construct the kernel.""" + + +@define +class _MetaKernelFactory(KernelFactoryProtocol, ABC): + """Base class for meta kernel factories that orchestrate other kernel factories.""" + + @override + @abstractmethod + def __call__( + self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor + ) -> Kernel: ... + + +@define +class ICMKernelFactory(_MetaKernelFactory): """A kernel factory that constructs an ICM kernel for transfer learning. ICM: Intrinsic Coregionalization Model :cite:p:`NIPS2007_66368270` @@ -78,4 +173,8 @@ def __call__( ) -> Kernel: base_kernel = self.base_kernel_factory(searchspace, train_x, train_y) task_kernel = self.task_kernel_factory(searchspace, train_x, train_y) - return ProductKernel([base_kernel, task_kernel]) + if isinstance(base_kernel, Kernel): + base_kernel = base_kernel.to_gpytorch(searchspace) + if isinstance(task_kernel, Kernel): + task_kernel = task_kernel.to_gpytorch(searchspace) + return base_kernel * task_kernel diff --git a/baybe/surrogates/gaussian_process/presets/baybe.py b/baybe/surrogates/gaussian_process/presets/baybe.py index bc820ae52b..690fc318cd 100644 --- a/baybe/surrogates/gaussian_process/presets/baybe.py +++ b/baybe/surrogates/gaussian_process/presets/baybe.py @@ -10,14 +10,14 @@ from baybe.kernels.base import Kernel from baybe.kernels.basic import IndexKernel from baybe.parameters.categorical import TaskParameter +from baybe.parameters.enum import _ParameterKind from baybe.parameters.selectors import ( ParameterSelectorProtocol, TypeSelector, - _ParameterSelectorMixin, to_parameter_selector, ) from baybe.searchspace.core import SearchSpace -from baybe.surrogates.gaussian_process.components.kernel import KernelFactoryProtocol +from baybe.surrogates.gaussian_process.components.kernel import _PureKernelFactory from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory from baybe.surrogates.gaussian_process.presets.edbo_smoothed import ( SmoothedEDBOKernelFactory, @@ -29,11 +29,16 @@ @define -class BayBEKernelFactory(KernelFactoryProtocol): +class BayBEKernelFactory(_PureKernelFactory): """The default kernel factory for Gaussian process surrogates.""" + _supported_parameter_kinds: ClassVar[_ParameterKind] = ( + _ParameterKind.REGULAR | _ParameterKind.TASK + ) + # See base class. + @override - def __call__( + def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: from baybe.surrogates.gaussian_process.components.kernel import ICMKernelFactory @@ -48,12 +53,15 @@ def __call__( @define -class BayBETaskKernelFactory(KernelFactoryProtocol, _ParameterSelectorMixin): +class BayBETaskKernelFactory(_PureKernelFactory): """The factory providing the default task kernel for Gaussian process surrogates.""" _uses_parameter_names: ClassVar[bool] = True # See base class. + _supported_parameter_kinds: ClassVar[_ParameterKind] = _ParameterKind.TASK + # See base class. + parameter_selector: ParameterSelectorProtocol | None = field( factory=lambda: TypeSelector([TaskParameter]), converter=to_parameter_selector, @@ -61,7 +69,7 @@ class BayBETaskKernelFactory(KernelFactoryProtocol, _ParameterSelectorMixin): # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429) @override - def __call__( + def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: return IndexKernel( diff --git a/baybe/surrogates/gaussian_process/presets/edbo.py b/baybe/surrogates/gaussian_process/presets/edbo.py index 36611c258e..6db0af8a30 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo.py +++ b/baybe/surrogates/gaussian_process/presets/edbo.py @@ -16,13 +16,14 @@ from baybe.parameters.selectors import ( ParameterSelectorProtocol, TypeSelector, - _ParameterSelectorMixin, to_parameter_selector, ) from baybe.parameters.substance import SubstanceParameter from baybe.priors.basic import GammaPrior from baybe.searchspace.discrete import SubspaceDiscrete -from baybe.surrogates.gaussian_process.components.kernel import KernelFactoryProtocol +from baybe.surrogates.gaussian_process.components.kernel import ( + _PureKernelFactory, +) from baybe.surrogates.gaussian_process.components.likelihood import ( LikelihoodFactoryProtocol, ) @@ -56,7 +57,7 @@ def _contains_encoding( @define -class EDBOKernelFactory(KernelFactoryProtocol, _ParameterSelectorMixin): +class EDBOKernelFactory(_PureKernelFactory): """A factory providing EDBO kernels. References: @@ -74,12 +75,10 @@ class EDBOKernelFactory(KernelFactoryProtocol, _ParameterSelectorMixin): # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429) @override - def __call__( + def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: - effective_dims = train_x.shape[-1] - len( - [p for p in searchspace.parameters if isinstance(p, TaskParameter)] - ) + effective_dims = train_x.shape[-1] switching_condition = _contains_encoding( searchspace.discrete, _EDBO_ENCODINGS diff --git a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py index eb8d57491a..1c2718adb6 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py +++ b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py @@ -15,11 +15,12 @@ from baybe.parameters.selectors import ( ParameterSelectorProtocol, TypeSelector, - _ParameterSelectorMixin, to_parameter_selector, ) from baybe.priors.basic import GammaPrior -from baybe.surrogates.gaussian_process.components.kernel import KernelFactoryProtocol +from baybe.surrogates.gaussian_process.components.kernel import ( + _PureKernelFactory, +) from baybe.surrogates.gaussian_process.components.likelihood import ( LikelihoodFactoryProtocol, ) @@ -38,7 +39,7 @@ @define -class SmoothedEDBOKernelFactory(KernelFactoryProtocol, _ParameterSelectorMixin): +class SmoothedEDBOKernelFactory(_PureKernelFactory): """A factory providing smoothed versions of EDBO kernels. Takes the low and high dimensional limits of @@ -56,12 +57,10 @@ class SmoothedEDBOKernelFactory(KernelFactoryProtocol, _ParameterSelectorMixin): # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429) @override - def __call__( + def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: - effective_dims = train_x.shape[-1] - len( - [p for p in searchspace.parameters if isinstance(p, TaskParameter)] - ) + effective_dims = train_x.shape[-1] # Interpolate prior moments linearly between low D and high D regime. # The high D regime itself is the average of the EDBO OHE and Mordred regime. diff --git a/docs/conf.py b/docs/conf.py index 5251bb8654..11ff636a77 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -146,11 +146,9 @@ (r"py:obj", "baybe.utils.boolean.UncertainBool.*"), ("py:obj", "baybe.targets.botorch.*"), ("py:obj", "baybe.objectives.botorch.*"), - ("py:class", "baybe.parameters.base._DiscreteLabelLikeParameter"), - ("py:class", "baybe.acquisition.acqfs._ExpectedHypervolumeImprovement"), - ("py:class", "baybe.settings._SlottedContextDecorator"), ("py:class", "baybe.surrogates.gaussian_process.components.PlainKernelFactory"), - ("py:class", "baybe.parameters.selectors._ParameterSelectorMixin"), + # Private classes + (r"py:class", r"baybe\..*\._.*"), # Deprecation ("py:.*", "baybe.targets._deprecated.*"), ] diff --git a/tests/test_kernel_factories.py b/tests/test_kernel_factories.py new file mode 100644 index 0000000000..6a8acda6a6 --- /dev/null +++ b/tests/test_kernel_factories.py @@ -0,0 +1,75 @@ +"""Tests for kernel factories.""" + +from contextlib import nullcontext + +import pytest +import torch +from pytest import param + +from baybe.exceptions import IncompatibleSearchSpaceError +from baybe.parameters.categorical import CategoricalParameter, TaskParameter +from baybe.parameters.numerical import ( + NumericalContinuousParameter, + NumericalDiscreteParameter, +) +from baybe.searchspace.core import SearchSpace +from baybe.surrogates.gaussian_process.presets.baybe import ( + BayBEKernelFactory, + BayBENumericalKernelFactory, + BayBETaskKernelFactory, +) + +# A selector that accepts all parameters +_SELECT_ALL = lambda parameter: True # noqa: E731 + + +@pytest.mark.parametrize( + ("factory", "parameters", "error"), + [ + param( + BayBENumericalKernelFactory(parameter_selector=_SELECT_ALL), + [TaskParameter("task", ["t1", "t2"])], + IncompatibleSearchSpaceError, + id="regular_rejects_task", + ), + param( + BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), + [CategoricalParameter("cat", ["a", "b"])], + IncompatibleSearchSpaceError, + id="task_rejects_categorical", + ), + param( + BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), + [NumericalDiscreteParameter("num", [1, 2, 3])], + IncompatibleSearchSpaceError, + id="task_rejects_numerical_discrete", + ), + param( + BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), + [NumericalContinuousParameter("cont", (0, 1))], + IncompatibleSearchSpaceError, + id="task_rejects_numerical_continuous", + ), + param( + BayBEKernelFactory(), + [ + NumericalContinuousParameter("cont", (0, 1)), + TaskParameter("task", ["t1", "t2"]), + ], + None, + id="combined_accepts_both", + ), + ], +) +def test_factory_parameter_kind_validation(factory, parameters, error): + """Factories reject unsupported parameter kinds and accept supported ones.""" + ss = SearchSpace.from_product(parameters) + train_x = torch.zeros(2, len(ss.comp_rep_columns)) + train_y = torch.zeros(2, 1) + + with ( + nullcontext() + if error is None + else pytest.raises(error, match="does not support") + ): + factory(ss, train_x, train_y)