diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index 8b6ddff280..6b2a52eb97 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -2,6 +2,7 @@ from __future__ import annotations +import functools from abc import ABC, abstractmethod from collections.abc import Iterable from functools import partial @@ -22,6 +23,7 @@ to_parameter_selector, ) from baybe.searchspace.core import SearchSpace +from baybe.serialization.mixin import SerialMixin from baybe.surrogates.gaussian_process.components.generic import ( GPComponentFactoryProtocol, GPComponentType, @@ -44,7 +46,7 @@ @define -class _PureKernelFactory(KernelFactoryProtocol, ABC): +class _PureKernelFactory(KernelFactoryProtocol, SerialMixin, ABC): """Base class for pure kernel factories.""" # For internal use only: sanity check mechanism to remind developers of new @@ -79,6 +81,15 @@ def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...] | Non p.name for p in searchspace.parameters if self.parameter_selector(p) ) + def _get_effective_dimensionality(self, searchspace: SearchSpace) -> int: + """Get the number of computational columns for the selected parameters.""" + names = self.get_parameter_names(searchspace) + if names is None: + return len(searchspace.comp_rep_columns) + return sum( + len(searchspace.get_comp_rep_parameter_indices(name)) for name in names + ) + def _validate_parameter_kinds(self, parameters: Iterable[Parameter]) -> None: """Validate that the given parameters are supported by the factory. @@ -119,6 +130,85 @@ def _make( """Construct the kernel.""" +def _enable_transfer_learning( + cls: type[_PureKernelFactory], name: str | None = None, / +) -> type[_PureKernelFactory]: + """Class decorator enabling BayBE's default transfer learning mechanism. + + When the search space contains a task parameter, the decorated factory + automatically composes its kernel with BayBE's default task kernel. + Otherwise, the factory behaves unchanged. + + When used as a decorator (without ``name``), the class is modified in-place. + When called with a ``name`` argument, a new subclass is created so that the + original class remains unmodified. The latter form is intended for cases where + the original class is reused independently elsewhere. + + Args: + cls: The kernel factory class to decorate. + name: Optional name for the created class. If provided, a new subclass is + created instead of modifying ``cls`` in-place. + + Raises: + TypeError: If the factory already supports task parameters. + + Returns: + The decorated kernel factory class with transfer learning enabled. + """ + if cls._supported_parameter_kinds & _ParameterKind.TASK: + raise TypeError(f"'{cls.__name__}' already supports task parameters.") + + # This distinction is important for serialization so that the classes can be + # correctly identified by their names in the subclass registry + if name is not None: + # Create a subclass so the original class remains unmodified. + # __module__ must be set explicitly because the Protocol metaclass + # would otherwise default it to "abc". + target_cls = type( + name, (cls,), {"__doc__": cls.__doc__, "__module__": cls.__module__} + ) + else: + # Modify the class in-place (avoids name collision in subclass registry) + target_cls = cls + + original_call = cls.__call__ + original_supported_kinds = cls._supported_parameter_kinds + _task_exclude_selector = TypeSelector((TaskParameter,), exclude=True) + + @functools.wraps(original_call) + def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor): + # Temporarily narrow the supported parameter kinds to those of the original + # class. If the decorator logic is correct, the original factory should never + # see the extended scope, but this acts as a sanity check to prevent regressions + broadened_kinds = target_cls._supported_parameter_kinds # type: ignore[attr-defined] + target_cls._supported_parameter_kinds = original_supported_kinds # type: ignore[attr-defined] + + # Split off the task parameters + original_selector = self.parameter_selector + if original_selector is None: + self.parameter_selector = _task_exclude_selector + else: + self.parameter_selector = lambda p: ( + _task_exclude_selector(p) and original_selector(p) + ) + try: + base_kernel = original_call(self, searchspace, train_x, train_y) + finally: + target_cls._supported_parameter_kinds = broadened_kinds # type: ignore[attr-defined] + self.parameter_selector = original_selector + + if searchspace.task_idx is not None: + icm = ICMKernelFactory(base_kernel_or_factory=base_kernel) + return icm(searchspace, train_x, train_y) + return base_kernel + + target_cls.__call__ = __call__ # type: ignore[method-assign] + target_cls._supported_parameter_kinds = ( # type: ignore[attr-defined] + cls._supported_parameter_kinds | _ParameterKind.TASK + ) + return target_cls + + @define class _MetaKernelFactory(KernelFactoryProtocol, ABC): """Base class for meta kernel factories that orchestrate other kernel factories.""" @@ -154,18 +244,25 @@ class ICMKernelFactory(_MetaKernelFactory): @base_kernel_factory.default def _default_base_kernel_factory(self) -> KernelFactoryProtocol: from baybe.surrogates.gaussian_process.presets.baybe import ( - BayBENumericalKernelFactory, + _BayBENumericalKernelFactory, ) - return BayBENumericalKernelFactory(TypeSelector((TaskParameter,), exclude=True)) + assert ( + _BayBENumericalKernelFactory._supported_parameter_kinds + is _ParameterKind.REGULAR + ) + return _BayBENumericalKernelFactory( + TypeSelector((TaskParameter,), exclude=True) + ) @task_kernel_factory.default def _default_task_kernel_factory(self) -> KernelFactoryProtocol: from baybe.surrogates.gaussian_process.presets.baybe import ( - BayBETaskKernelFactory, + _BayBETaskKernelFactory, ) - return BayBETaskKernelFactory(TypeSelector((TaskParameter,))) + assert _BayBETaskKernelFactory._supported_parameter_kinds is _ParameterKind.TASK + return _BayBETaskKernelFactory() @override def __call__( diff --git a/baybe/surrogates/gaussian_process/presets/baybe.py b/baybe/surrogates/gaussian_process/presets/baybe.py index 690fc318cd..f5cfe65941 100644 --- a/baybe/surrogates/gaussian_process/presets/baybe.py +++ b/baybe/surrogates/gaussian_process/presets/baybe.py @@ -22,39 +22,24 @@ from baybe.surrogates.gaussian_process.presets.edbo_smoothed import ( SmoothedEDBOKernelFactory, SmoothedEDBOLikelihoodFactory, + _SmoothedEDBONumericalKernelFactory, ) if TYPE_CHECKING: from torch import Tensor -@define -class BayBEKernelFactory(_PureKernelFactory): - """The default kernel factory for Gaussian process surrogates.""" - - _supported_parameter_kinds: ClassVar[_ParameterKind] = ( - _ParameterKind.REGULAR | _ParameterKind.TASK - ) - # See base class. - - @override - def _make( - self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor - ) -> Kernel: - from baybe.surrogates.gaussian_process.components.kernel import ICMKernelFactory +class _BayBENumericalKernelFactory(_SmoothedEDBONumericalKernelFactory): + """The default numerical kernel factory for GP surrogates.""" - is_multitask = searchspace.task_idx is not None - factory = ICMKernelFactory if is_multitask else BayBENumericalKernelFactory - return factory()(searchspace, train_x, train_y) - -BayBENumericalKernelFactory = SmoothedEDBOKernelFactory -"""The factory providing the default numerical kernel for Gaussian process surrogates.""" # noqa: E501 +class BayBEKernelFactory(SmoothedEDBOKernelFactory): # type: ignore[valid-type, misc] + """The default kernel factory for GP surrogates.""" @define -class BayBETaskKernelFactory(_PureKernelFactory): - """The factory providing the default task kernel for Gaussian process surrogates.""" +class _BayBETaskKernelFactory(_PureKernelFactory): + """The default task kernel factory for GP surrogates.""" _uses_parameter_names: ClassVar[bool] = True # See base class. @@ -79,11 +64,13 @@ def _make( ) -BayBEMeanFactory = LazyConstantMeanFactory -"""The factory providing the default mean function for Gaussian process surrogates.""" +class BayBEMeanFactory(LazyConstantMeanFactory): + """The default mean factory for GP surrogates.""" + + +class BayBELikelihoodFactory(SmoothedEDBOLikelihoodFactory): + """The default likelihood factory for GP surrogates.""" -BayBELikelihoodFactory = SmoothedEDBOLikelihoodFactory -"""The factory providing the default likelihood for Gaussian process surrogates.""" # Aliases for generic preset imports PresetKernelFactory = BayBEKernelFactory diff --git a/baybe/surrogates/gaussian_process/presets/chen.py b/baybe/surrogates/gaussian_process/presets/chen.py index 5e90c4aa5c..b130a3f3a7 100644 --- a/baybe/surrogates/gaussian_process/presets/chen.py +++ b/baybe/surrogates/gaussian_process/presets/chen.py @@ -6,19 +6,14 @@ import math from typing import TYPE_CHECKING, ClassVar -from attrs import define, field +from attrs import define from typing_extensions import override from baybe.kernels.basic import MaternKernel from baybe.kernels.composite import ScaleKernel -from baybe.parameters.categorical import TaskParameter -from baybe.parameters.selectors import ( - ParameterSelectorProtocol, - TypeSelector, - to_parameter_selector, -) from baybe.priors.basic import GammaPrior from baybe.surrogates.gaussian_process.components.kernel import ( + _enable_transfer_learning, _PureKernelFactory, ) from baybe.surrogates.gaussian_process.components.likelihood import ( @@ -33,6 +28,7 @@ from baybe.searchspace.core import SearchSpace +@_enable_transfer_learning @define class CHENKernelFactory(_PureKernelFactory): """A factory providing adaptive hyperprior kernels as proposed by :cite:p:`Chen2026`.""" # noqa: E501 @@ -40,17 +36,12 @@ class CHENKernelFactory(_PureKernelFactory): _uses_parameter_names: ClassVar[bool] = True # See base class. - parameter_selector: ParameterSelectorProtocol | None = field( - factory=lambda: TypeSelector([TaskParameter], exclude=True), - converter=to_parameter_selector, - ) - # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429) - @override def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: - lengthscale = 0.4 * math.sqrt(train_x.shape[-1]) + 4.0 + n_dimensions = self._get_effective_dimensionality(searchspace) + lengthscale = 0.4 * math.sqrt(n_dimensions) + 4.0 lengthscale_prior = GammaPrior(2.0 * lengthscale, 2.0) lengthscale_initial_value = lengthscale outputscale_prior = GammaPrior(1.0 * lengthscale, 1.0) diff --git a/baybe/surrogates/gaussian_process/presets/edbo.py b/baybe/surrogates/gaussian_process/presets/edbo.py index 539e2ef72c..7a3f569220 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo.py +++ b/baybe/surrogates/gaussian_process/presets/edbo.py @@ -6,22 +6,17 @@ from collections.abc import Collection from typing import TYPE_CHECKING, ClassVar -from attrs import define, field +from attrs import define from typing_extensions import override from baybe.kernels.basic import MaternKernel from baybe.kernels.composite import ScaleKernel -from baybe.parameters import TaskParameter -from baybe.parameters.enum import SubstanceEncoding -from baybe.parameters.selectors import ( - ParameterSelectorProtocol, - TypeSelector, - to_parameter_selector, -) +from baybe.parameters.enum import SubstanceEncoding, _ParameterKind from baybe.parameters.substance import SubstanceParameter from baybe.priors.basic import GammaPrior from baybe.searchspace.discrete import SubspaceDiscrete from baybe.surrogates.gaussian_process.components.kernel import ( + _enable_transfer_learning, _PureKernelFactory, ) from baybe.surrogates.gaussian_process.components.likelihood import ( @@ -56,6 +51,7 @@ def _contains_encoding( """Encodings relevant to EDBO logic.""" +@_enable_transfer_learning @define class EDBOKernelFactory(_PureKernelFactory): """A factory providing EDBO kernels, as proposed by :cite:p:`Shields2021`. @@ -67,17 +63,11 @@ class EDBOKernelFactory(_PureKernelFactory): _uses_parameter_names: ClassVar[bool] = True # See base class. - parameter_selector: ParameterSelectorProtocol | None = field( - factory=lambda: TypeSelector([TaskParameter], exclude=True), - converter=to_parameter_selector, - ) - # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429) - @override def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: - effective_dims = train_x.shape[-1] + effective_dims = self._get_effective_dimensionality(searchspace) switching_condition = _contains_encoding( searchspace.discrete, _EDBO_ENCODINGS @@ -123,8 +113,8 @@ def _make( ) -EDBOMeanFactory = LazyConstantMeanFactory -"""A factory providing mean functions for the EDBO preset.""" +class EDBOMeanFactory(LazyConstantMeanFactory): + """A factory providing mean functions for the EDBO preset.""" @define @@ -142,8 +132,10 @@ def __call__( import torch from gpytorch.likelihoods import GaussianLikelihood - effective_dims = train_x.shape[-1] - len( - [p for p in searchspace.parameters if isinstance(p, TaskParameter)] + effective_dims = sum( + len(searchspace.get_comp_rep_parameter_indices(p.name)) + for p in searchspace.parameters + if p._kind & _ParameterKind.REGULAR ) switching_condition = _contains_encoding( diff --git a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py index 904f711f31..59133f8cc9 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py +++ b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py @@ -6,19 +6,15 @@ from typing import TYPE_CHECKING, ClassVar import numpy as np -from attrs import define, field +from attrs import define from typing_extensions import override from baybe.kernels.basic import MaternKernel from baybe.kernels.composite import ScaleKernel -from baybe.parameters import TaskParameter -from baybe.parameters.selectors import ( - ParameterSelectorProtocol, - TypeSelector, - to_parameter_selector, -) +from baybe.parameters.enum import _ParameterKind from baybe.priors.basic import GammaPrior from baybe.surrogates.gaussian_process.components.kernel import ( + _enable_transfer_learning, _PureKernelFactory, ) from baybe.surrogates.gaussian_process.components.likelihood import ( @@ -39,28 +35,17 @@ @define -class SmoothedEDBOKernelFactory(_PureKernelFactory): - """A factory providing smoothed versions of EDBO kernels (adapted from :cite:p:`Shields2021`). - - Takes the low and high dimensional limits of - :class:`baybe.surrogates.gaussian_process.presets.edbo.EDBOKernelFactory` - and interpolates the prior moments linearly in between. - """ # noqa: E501 +class _SmoothedEDBONumericalKernelFactory(_PureKernelFactory): + """A factory providing the core numerical kernel for the smoothed EDBO preset.""" _uses_parameter_names: ClassVar[bool] = True # See base class. - parameter_selector: ParameterSelectorProtocol | None = field( - factory=lambda: TypeSelector([TaskParameter], exclude=True), - converter=to_parameter_selector, - ) - # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429) - @override def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: - effective_dims = train_x.shape[-1] + effective_dims = self._get_effective_dimensionality(searchspace) # Interpolate prior moments linearly between low D and high D regime. # The high D regime itself is the average of the EDBO OHE and Mordred regime. @@ -88,8 +73,19 @@ def _make( ) -SmoothedEDBOMeanFactory = LazyConstantMeanFactory -"""A factory providing mean functions for the smoothed EDBO preset.""" +SmoothedEDBOKernelFactory = _enable_transfer_learning( + _SmoothedEDBONumericalKernelFactory, "SmoothedEDBOKernelFactory" +) +"""A factory providing smoothed versions of EDBO kernels (adapted from :cite:p:`Shields2021`). + +Takes the low and high dimensional limits of +:class:`baybe.surrogates.gaussian_process.presets.edbo.EDBOKernelFactory` +and interpolates the prior moments linearly in between. +""" # noqa: E501 + + +class SmoothedEDBOMeanFactory(LazyConstantMeanFactory): + """A factory providing mean functions for the smoothed EDBO preset.""" @define @@ -111,8 +107,10 @@ def __call__( # Interpolate prior moments linearly between low D and high D regime. # The high D regime itself is the average of the EDBO OHE and Mordred regime. # Values outside the dimension limits will get the border value assigned. - effective_dims = train_x.shape[-1] - len( - [p for p in searchspace.parameters if isinstance(p, TaskParameter)] + effective_dims = sum( + len(searchspace.get_comp_rep_parameter_indices(p.name)) + for p in searchspace.parameters + if p._kind & _ParameterKind.REGULAR ) prior = GammaPrior( diff --git a/tests/serialization/test_kernel_factory_serialization.py b/tests/serialization/test_kernel_factory_serialization.py new file mode 100644 index 0000000000..0bdc20af93 --- /dev/null +++ b/tests/serialization/test_kernel_factory_serialization.py @@ -0,0 +1,23 @@ +"""Kernel factory serialization tests.""" + +import pytest + +from baybe.surrogates.gaussian_process.components.kernel import _PureKernelFactory +from baybe.surrogates.gaussian_process.presets import * # noqa: F401, F403 +from baybe.utils.basic import get_subclasses +from tests.serialization.utils import assert_roundtrip_consistency + +_KERNEL_FACTORIES = [ + cls + for cls in get_subclasses(_PureKernelFactory) + if not cls.__name__.startswith("_") +] + + +@pytest.mark.parametrize( + "factory", + [pytest.param(cls(), id=cls.__name__) for cls in _KERNEL_FACTORIES], +) +def test_roundtrip(factory): + """A serialization roundtrip yields an equivalent object.""" + assert_roundtrip_consistency(factory) diff --git a/tests/test_kernel_factories.py b/tests/test_kernel_factories.py index 6a8acda6a6..45f3ccc109 100644 --- a/tests/test_kernel_factories.py +++ b/tests/test_kernel_factories.py @@ -15,8 +15,8 @@ from baybe.searchspace.core import SearchSpace from baybe.surrogates.gaussian_process.presets.baybe import ( BayBEKernelFactory, - BayBENumericalKernelFactory, - BayBETaskKernelFactory, + _BayBENumericalKernelFactory, + _BayBETaskKernelFactory, ) # A selector that accepts all parameters @@ -27,25 +27,25 @@ ("factory", "parameters", "error"), [ param( - BayBENumericalKernelFactory(parameter_selector=_SELECT_ALL), + _BayBENumericalKernelFactory(parameter_selector=_SELECT_ALL), [TaskParameter("task", ["t1", "t2"])], IncompatibleSearchSpaceError, id="regular_rejects_task", ), param( - BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), + _BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), [CategoricalParameter("cat", ["a", "b"])], IncompatibleSearchSpaceError, id="task_rejects_categorical", ), param( - BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), + _BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), [NumericalDiscreteParameter("num", [1, 2, 3])], IncompatibleSearchSpaceError, id="task_rejects_numerical_discrete", ), param( - BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), + _BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), [NumericalContinuousParameter("cont", (0, 1))], IncompatibleSearchSpaceError, id="task_rejects_numerical_continuous",