From 0aa7b11293bbd2668fe2f7b6c1bd78c3bd80309a Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 7 May 2026 09:48:53 +0200 Subject: [PATCH 1/9] Extract transfer learning mechanism into a reusable decorator * Provides a single source of truth for defining the TL logic * Enables TL for non-TL presets by applying the decorator --- .../gaussian_process/components/kernel.py | 57 +++++++++++++++++-- .../gaussian_process/presets/baybe.py | 27 ++------- .../gaussian_process/presets/chen.py | 16 +----- .../gaussian_process/presets/edbo.py | 15 +---- .../gaussian_process/presets/edbo_smoothed.py | 33 +++++------ tests/test_kernel_factories.py | 12 ++-- 6 files changed, 85 insertions(+), 75 deletions(-) diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index 8b6ddff280..8d6cec478b 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -2,6 +2,7 @@ from __future__ import annotations +import functools from abc import ABC, abstractmethod from collections.abc import Iterable from functools import partial @@ -119,6 +120,47 @@ def _make( """Construct the kernel.""" +def _enable_transfer_learning( + cls: type[_PureKernelFactory], / +) -> type[_PureKernelFactory]: + """Class decorator enabling BayBE's default transfer learning mechanism. + + When the search space contains a task parameter, the decorated factory + automatically composes its kernel with BayBE's default task kernel. + Otherwise, the factory behaves unchanged. + + Args: + cls: The kernel factory class to decorate. + + Raises: + TypeError: If the factory already supports task parameters. + + Returns: + The decorated kernel factory class with transfer learning enabled. + """ + if cls._supported_parameter_kinds & _ParameterKind.TASK: + raise TypeError(f"'{cls.__name__}' already supports task parameters.") + + # Create a subclass so the original class remains unmodified + new_cls = type(cls.__name__, (cls,), {"__doc__": cls.__doc__}) + + original_call = cls.__call__ + + @functools.wraps(original_call) + def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor): + base_kernel = original_call(self, searchspace, train_x, train_y) + if searchspace.task_idx is not None: + icm = ICMKernelFactory(base_kernel_or_factory=base_kernel) + return icm(searchspace, train_x, train_y) + return base_kernel + + new_cls.__call__ = __call__ # type: ignore[method-assign] + new_cls._supported_parameter_kinds = ( + cls._supported_parameter_kinds | _ParameterKind.TASK + ) + return new_cls + + @define class _MetaKernelFactory(KernelFactoryProtocol, ABC): """Base class for meta kernel factories that orchestrate other kernel factories.""" @@ -154,18 +196,25 @@ class ICMKernelFactory(_MetaKernelFactory): @base_kernel_factory.default def _default_base_kernel_factory(self) -> KernelFactoryProtocol: from baybe.surrogates.gaussian_process.presets.baybe import ( - BayBENumericalKernelFactory, + _BayBENumericalKernelFactory, ) - return BayBENumericalKernelFactory(TypeSelector((TaskParameter,), exclude=True)) + assert ( + _BayBENumericalKernelFactory._supported_parameter_kinds + is _ParameterKind.REGULAR + ) + return _BayBENumericalKernelFactory( + TypeSelector((TaskParameter,), exclude=True) + ) @task_kernel_factory.default def _default_task_kernel_factory(self) -> KernelFactoryProtocol: from baybe.surrogates.gaussian_process.presets.baybe import ( - BayBETaskKernelFactory, + _BayBETaskKernelFactory, ) - return BayBETaskKernelFactory(TypeSelector((TaskParameter,))) + assert _BayBETaskKernelFactory._supported_parameter_kinds is _ParameterKind.TASK + return _BayBETaskKernelFactory() @override def __call__( diff --git a/baybe/surrogates/gaussian_process/presets/baybe.py b/baybe/surrogates/gaussian_process/presets/baybe.py index 690fc318cd..0fb3ba281f 100644 --- a/baybe/surrogates/gaussian_process/presets/baybe.py +++ b/baybe/surrogates/gaussian_process/presets/baybe.py @@ -22,38 +22,23 @@ from baybe.surrogates.gaussian_process.presets.edbo_smoothed import ( SmoothedEDBOKernelFactory, SmoothedEDBOLikelihoodFactory, + _SmoothedEDBONumericalKernelFactory, ) if TYPE_CHECKING: from torch import Tensor -@define -class BayBEKernelFactory(_PureKernelFactory): - """The default kernel factory for Gaussian process surrogates.""" - - _supported_parameter_kinds: ClassVar[_ParameterKind] = ( - _ParameterKind.REGULAR | _ParameterKind.TASK - ) - # See base class. - - @override - def _make( - self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor - ) -> Kernel: - from baybe.surrogates.gaussian_process.components.kernel import ICMKernelFactory - - is_multitask = searchspace.task_idx is not None - factory = ICMKernelFactory if is_multitask else BayBENumericalKernelFactory - return factory()(searchspace, train_x, train_y) +_BayBENumericalKernelFactory = _SmoothedEDBONumericalKernelFactory +"""The factory providing the default numerical kernel for Gaussian process surrogates.""" # noqa: E501 -BayBENumericalKernelFactory = SmoothedEDBOKernelFactory -"""The factory providing the default numerical kernel for Gaussian process surrogates.""" # noqa: E501 +BayBEKernelFactory = SmoothedEDBOKernelFactory +"""The default kernel factory for Gaussian process surrogates.""" @define -class BayBETaskKernelFactory(_PureKernelFactory): +class _BayBETaskKernelFactory(_PureKernelFactory): """The factory providing the default task kernel for Gaussian process surrogates.""" _uses_parameter_names: ClassVar[bool] = True diff --git a/baybe/surrogates/gaussian_process/presets/chen.py b/baybe/surrogates/gaussian_process/presets/chen.py index 5e90c4aa5c..91dbc82f12 100644 --- a/baybe/surrogates/gaussian_process/presets/chen.py +++ b/baybe/surrogates/gaussian_process/presets/chen.py @@ -6,19 +6,14 @@ import math from typing import TYPE_CHECKING, ClassVar -from attrs import define, field +from attrs import define from typing_extensions import override from baybe.kernels.basic import MaternKernel from baybe.kernels.composite import ScaleKernel -from baybe.parameters.categorical import TaskParameter -from baybe.parameters.selectors import ( - ParameterSelectorProtocol, - TypeSelector, - to_parameter_selector, -) from baybe.priors.basic import GammaPrior from baybe.surrogates.gaussian_process.components.kernel import ( + _enable_transfer_learning, _PureKernelFactory, ) from baybe.surrogates.gaussian_process.components.likelihood import ( @@ -33,6 +28,7 @@ from baybe.searchspace.core import SearchSpace +@_enable_transfer_learning @define class CHENKernelFactory(_PureKernelFactory): """A factory providing adaptive hyperprior kernels as proposed by :cite:p:`Chen2026`.""" # noqa: E501 @@ -40,12 +36,6 @@ class CHENKernelFactory(_PureKernelFactory): _uses_parameter_names: ClassVar[bool] = True # See base class. - parameter_selector: ParameterSelectorProtocol | None = field( - factory=lambda: TypeSelector([TaskParameter], exclude=True), - converter=to_parameter_selector, - ) - # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429) - @override def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor diff --git a/baybe/surrogates/gaussian_process/presets/edbo.py b/baybe/surrogates/gaussian_process/presets/edbo.py index 539e2ef72c..5569ea3630 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo.py +++ b/baybe/surrogates/gaussian_process/presets/edbo.py @@ -6,22 +6,18 @@ from collections.abc import Collection from typing import TYPE_CHECKING, ClassVar -from attrs import define, field +from attrs import define from typing_extensions import override from baybe.kernels.basic import MaternKernel from baybe.kernels.composite import ScaleKernel from baybe.parameters import TaskParameter from baybe.parameters.enum import SubstanceEncoding -from baybe.parameters.selectors import ( - ParameterSelectorProtocol, - TypeSelector, - to_parameter_selector, -) from baybe.parameters.substance import SubstanceParameter from baybe.priors.basic import GammaPrior from baybe.searchspace.discrete import SubspaceDiscrete from baybe.surrogates.gaussian_process.components.kernel import ( + _enable_transfer_learning, _PureKernelFactory, ) from baybe.surrogates.gaussian_process.components.likelihood import ( @@ -56,6 +52,7 @@ def _contains_encoding( """Encodings relevant to EDBO logic.""" +@_enable_transfer_learning @define class EDBOKernelFactory(_PureKernelFactory): """A factory providing EDBO kernels, as proposed by :cite:p:`Shields2021`. @@ -67,12 +64,6 @@ class EDBOKernelFactory(_PureKernelFactory): _uses_parameter_names: ClassVar[bool] = True # See base class. - parameter_selector: ParameterSelectorProtocol | None = field( - factory=lambda: TypeSelector([TaskParameter], exclude=True), - converter=to_parameter_selector, - ) - # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429) - @override def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor diff --git a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py index 904f711f31..785ecb868a 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py +++ b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py @@ -6,19 +6,15 @@ from typing import TYPE_CHECKING, ClassVar import numpy as np -from attrs import define, field +from attrs import define from typing_extensions import override from baybe.kernels.basic import MaternKernel from baybe.kernels.composite import ScaleKernel from baybe.parameters import TaskParameter -from baybe.parameters.selectors import ( - ParameterSelectorProtocol, - TypeSelector, - to_parameter_selector, -) from baybe.priors.basic import GammaPrior from baybe.surrogates.gaussian_process.components.kernel import ( + _enable_transfer_learning, _PureKernelFactory, ) from baybe.surrogates.gaussian_process.components.likelihood import ( @@ -39,23 +35,12 @@ @define -class SmoothedEDBOKernelFactory(_PureKernelFactory): - """A factory providing smoothed versions of EDBO kernels (adapted from :cite:p:`Shields2021`). - - Takes the low and high dimensional limits of - :class:`baybe.surrogates.gaussian_process.presets.edbo.EDBOKernelFactory` - and interpolates the prior moments linearly in between. - """ # noqa: E501 +class _SmoothedEDBONumericalKernelFactory(_PureKernelFactory): + """A factory providing the core numerical kernel for the smoothed EDBO preset.""" _uses_parameter_names: ClassVar[bool] = True # See base class. - parameter_selector: ParameterSelectorProtocol | None = field( - factory=lambda: TypeSelector([TaskParameter], exclude=True), - converter=to_parameter_selector, - ) - # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429) - @override def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor @@ -88,6 +73,16 @@ def _make( ) +SmoothedEDBOKernelFactory = _enable_transfer_learning( + _SmoothedEDBONumericalKernelFactory +) +"""A factory providing smoothed versions of EDBO kernels (adapted from :cite:p:`Shields2021`). + +Takes the low and high dimensional limits of +:class:`baybe.surrogates.gaussian_process.presets.edbo.EDBOKernelFactory` +and interpolates the prior moments linearly in between. +""" # noqa: E501 + SmoothedEDBOMeanFactory = LazyConstantMeanFactory """A factory providing mean functions for the smoothed EDBO preset.""" diff --git a/tests/test_kernel_factories.py b/tests/test_kernel_factories.py index 6a8acda6a6..45f3ccc109 100644 --- a/tests/test_kernel_factories.py +++ b/tests/test_kernel_factories.py @@ -15,8 +15,8 @@ from baybe.searchspace.core import SearchSpace from baybe.surrogates.gaussian_process.presets.baybe import ( BayBEKernelFactory, - BayBENumericalKernelFactory, - BayBETaskKernelFactory, + _BayBENumericalKernelFactory, + _BayBETaskKernelFactory, ) # A selector that accepts all parameters @@ -27,25 +27,25 @@ ("factory", "parameters", "error"), [ param( - BayBENumericalKernelFactory(parameter_selector=_SELECT_ALL), + _BayBENumericalKernelFactory(parameter_selector=_SELECT_ALL), [TaskParameter("task", ["t1", "t2"])], IncompatibleSearchSpaceError, id="regular_rejects_task", ), param( - BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), + _BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), [CategoricalParameter("cat", ["a", "b"])], IncompatibleSearchSpaceError, id="task_rejects_categorical", ), param( - BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), + _BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), [NumericalDiscreteParameter("num", [1, 2, 3])], IncompatibleSearchSpaceError, id="task_rejects_numerical_discrete", ), param( - BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), + _BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), [NumericalContinuousParameter("cont", (0, 1))], IncompatibleSearchSpaceError, id="task_rejects_numerical_continuous", From 0209de41f790d93d6d7661647a266f468e3c7a61 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 7 May 2026 16:01:01 +0200 Subject: [PATCH 2/9] Scope inner factory to non-task parameters in transfer learning decorator --- .../gaussian_process/components/kernel.py | 32 ++++++++++++++++++- .../gaussian_process/presets/chen.py | 3 +- .../gaussian_process/presets/edbo.py | 11 ++++--- .../gaussian_process/presets/edbo_smoothed.py | 10 +++--- 4 files changed, 45 insertions(+), 11 deletions(-) diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index 8d6cec478b..817f91f3ee 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -80,6 +80,15 @@ def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...] | Non p.name for p in searchspace.parameters if self.parameter_selector(p) ) + def _get_effective_dimensionality(self, searchspace: SearchSpace) -> int: + """Get the number of computational columns for the selected parameters.""" + names = self.get_parameter_names(searchspace) + if names is None: + return len(searchspace.comp_rep_columns) + return sum( + len(searchspace.get_comp_rep_parameter_indices(name)) for name in names + ) + def _validate_parameter_kinds(self, parameters: Iterable[Parameter]) -> None: """Validate that the given parameters are supported by the factory. @@ -145,10 +154,31 @@ def _enable_transfer_learning( new_cls = type(cls.__name__, (cls,), {"__doc__": cls.__doc__}) original_call = cls.__call__ + original_supported_kinds = cls._supported_parameter_kinds + _task_exclude_selector = TypeSelector((TaskParameter,), exclude=True) @functools.wraps(original_call) def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor): - base_kernel = original_call(self, searchspace, train_x, train_y) + # Temporarily narrow the supported parameter kinds to those of the original + # class. If the decorator logic is correct, the original factory should never + # see the extended scope, but this acts as a sanity check to prevent regressions + broadened_kinds = new_cls._supported_parameter_kinds + new_cls._supported_parameter_kinds = original_supported_kinds + + # Split off the task parameters + original_selector = self.parameter_selector + if original_selector is None: + self.parameter_selector = _task_exclude_selector + else: + self.parameter_selector = lambda p: ( + _task_exclude_selector(p) and original_selector(p) + ) + try: + base_kernel = original_call(self, searchspace, train_x, train_y) + finally: + new_cls._supported_parameter_kinds = broadened_kinds + self.parameter_selector = original_selector + if searchspace.task_idx is not None: icm = ICMKernelFactory(base_kernel_or_factory=base_kernel) return icm(searchspace, train_x, train_y) diff --git a/baybe/surrogates/gaussian_process/presets/chen.py b/baybe/surrogates/gaussian_process/presets/chen.py index 91dbc82f12..b130a3f3a7 100644 --- a/baybe/surrogates/gaussian_process/presets/chen.py +++ b/baybe/surrogates/gaussian_process/presets/chen.py @@ -40,7 +40,8 @@ class CHENKernelFactory(_PureKernelFactory): def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: - lengthscale = 0.4 * math.sqrt(train_x.shape[-1]) + 4.0 + n_dimensions = self._get_effective_dimensionality(searchspace) + lengthscale = 0.4 * math.sqrt(n_dimensions) + 4.0 lengthscale_prior = GammaPrior(2.0 * lengthscale, 2.0) lengthscale_initial_value = lengthscale outputscale_prior = GammaPrior(1.0 * lengthscale, 1.0) diff --git a/baybe/surrogates/gaussian_process/presets/edbo.py b/baybe/surrogates/gaussian_process/presets/edbo.py index 5569ea3630..b2870d4e46 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo.py +++ b/baybe/surrogates/gaussian_process/presets/edbo.py @@ -11,8 +11,7 @@ from baybe.kernels.basic import MaternKernel from baybe.kernels.composite import ScaleKernel -from baybe.parameters import TaskParameter -from baybe.parameters.enum import SubstanceEncoding +from baybe.parameters.enum import SubstanceEncoding, _ParameterKind from baybe.parameters.substance import SubstanceParameter from baybe.priors.basic import GammaPrior from baybe.searchspace.discrete import SubspaceDiscrete @@ -68,7 +67,7 @@ class EDBOKernelFactory(_PureKernelFactory): def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: - effective_dims = train_x.shape[-1] + effective_dims = self._get_effective_dimensionality(searchspace) switching_condition = _contains_encoding( searchspace.discrete, _EDBO_ENCODINGS @@ -133,8 +132,10 @@ def __call__( import torch from gpytorch.likelihoods import GaussianLikelihood - effective_dims = train_x.shape[-1] - len( - [p for p in searchspace.parameters if isinstance(p, TaskParameter)] + effective_dims = sum( + len(searchspace.get_comp_rep_parameter_indices(p.name)) + for p in searchspace.parameters + if p._kind & _ParameterKind.REGULAR ) switching_condition = _contains_encoding( diff --git a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py index 785ecb868a..e735ed96ab 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py +++ b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py @@ -11,7 +11,7 @@ from baybe.kernels.basic import MaternKernel from baybe.kernels.composite import ScaleKernel -from baybe.parameters import TaskParameter +from baybe.parameters.enum import _ParameterKind from baybe.priors.basic import GammaPrior from baybe.surrogates.gaussian_process.components.kernel import ( _enable_transfer_learning, @@ -45,7 +45,7 @@ class _SmoothedEDBONumericalKernelFactory(_PureKernelFactory): def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: - effective_dims = train_x.shape[-1] + effective_dims = self._get_effective_dimensionality(searchspace) # Interpolate prior moments linearly between low D and high D regime. # The high D regime itself is the average of the EDBO OHE and Mordred regime. @@ -106,8 +106,10 @@ def __call__( # Interpolate prior moments linearly between low D and high D regime. # The high D regime itself is the average of the EDBO OHE and Mordred regime. # Values outside the dimension limits will get the border value assigned. - effective_dims = train_x.shape[-1] - len( - [p for p in searchspace.parameters if isinstance(p, TaskParameter)] + effective_dims = sum( + len(searchspace.get_comp_rep_parameter_indices(p.name)) + for p in searchspace.parameters + if p._kind & _ParameterKind.REGULAR ) prior = GammaPrior( From 6b086793674380ecf9dfdec87f4c8beaacc520ac Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 7 May 2026 16:44:34 +0200 Subject: [PATCH 3/9] Fix class name of dynamically created kernel factory `_enable_transfer_learning` now accepts an optional `name` parameter so that the dynamically created class can have the correct `__name__` when the function is called directly (rather than used as a decorator). This fixes serialization for `SmoothedEDBOKernelFactory`, which was previously serialized as `_SmoothedEDBONumericalKernelFactory`. --- baybe/surrogates/gaussian_process/components/kernel.py | 7 +++++-- baybe/surrogates/gaussian_process/presets/edbo_smoothed.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index 817f91f3ee..b0fd698277 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -130,7 +130,7 @@ def _make( def _enable_transfer_learning( - cls: type[_PureKernelFactory], / + cls: type[_PureKernelFactory], name: str | None = None, / ) -> type[_PureKernelFactory]: """Class decorator enabling BayBE's default transfer learning mechanism. @@ -140,6 +140,9 @@ def _enable_transfer_learning( Args: cls: The kernel factory class to decorate. + name: Optional name for the created class. Defaults to ``cls.__name__``. + Useful when calling the function directly (as opposed to using it as a + decorator) and assigning the result to a different name. Raises: TypeError: If the factory already supports task parameters. @@ -151,7 +154,7 @@ def _enable_transfer_learning( raise TypeError(f"'{cls.__name__}' already supports task parameters.") # Create a subclass so the original class remains unmodified - new_cls = type(cls.__name__, (cls,), {"__doc__": cls.__doc__}) + new_cls = type(name or cls.__name__, (cls,), {"__doc__": cls.__doc__}) original_call = cls.__call__ original_supported_kinds = cls._supported_parameter_kinds diff --git a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py index e735ed96ab..8915e63db5 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py +++ b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py @@ -74,7 +74,7 @@ def _make( SmoothedEDBOKernelFactory = _enable_transfer_learning( - _SmoothedEDBONumericalKernelFactory + _SmoothedEDBONumericalKernelFactory, "SmoothedEDBOKernelFactory" ) """A factory providing smoothed versions of EDBO kernels (adapted from :cite:p:`Shields2021`). From a713bfe8d9c09da4c110bca4fdb97bb378083bbf Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 7 May 2026 17:41:55 +0200 Subject: [PATCH 4/9] Replace factory aliases with thin subclasses Simple aliases like `BayBEKernelFactory = SmoothedEDBOKernelFactory` cause the serialized type name to be that of the underlying class, which means the identity is lost on deserialization. Using thin subclasses ensures each factory has its own stable `__name__`. --- .../gaussian_process/presets/baybe.py | 20 ++++++++++--------- .../gaussian_process/presets/edbo.py | 4 ++-- .../gaussian_process/presets/edbo_smoothed.py | 5 +++-- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/baybe/surrogates/gaussian_process/presets/baybe.py b/baybe/surrogates/gaussian_process/presets/baybe.py index 0fb3ba281f..484c3152da 100644 --- a/baybe/surrogates/gaussian_process/presets/baybe.py +++ b/baybe/surrogates/gaussian_process/presets/baybe.py @@ -29,17 +29,17 @@ from torch import Tensor -_BayBENumericalKernelFactory = _SmoothedEDBONumericalKernelFactory -"""The factory providing the default numerical kernel for Gaussian process surrogates.""" # noqa: E501 +class _BayBENumericalKernelFactory(_SmoothedEDBONumericalKernelFactory): + """The default numerical kernel factory for GP surrogates.""" -BayBEKernelFactory = SmoothedEDBOKernelFactory -"""The default kernel factory for Gaussian process surrogates.""" +class BayBEKernelFactory(SmoothedEDBOKernelFactory): + """The default kernel factory for GP surrogates.""" @define class _BayBETaskKernelFactory(_PureKernelFactory): - """The factory providing the default task kernel for Gaussian process surrogates.""" + """The default task kernel factory for GP surrogates.""" _uses_parameter_names: ClassVar[bool] = True # See base class. @@ -64,11 +64,13 @@ def _make( ) -BayBEMeanFactory = LazyConstantMeanFactory -"""The factory providing the default mean function for Gaussian process surrogates.""" +class BayBEMeanFactory(LazyConstantMeanFactory): + """The default mean factory for GP surrogates.""" + + +class BayBELikelihoodFactory(SmoothedEDBOLikelihoodFactory): + """The default likelihood factory for GP surrogates.""" -BayBELikelihoodFactory = SmoothedEDBOLikelihoodFactory -"""The factory providing the default likelihood for Gaussian process surrogates.""" # Aliases for generic preset imports PresetKernelFactory = BayBEKernelFactory diff --git a/baybe/surrogates/gaussian_process/presets/edbo.py b/baybe/surrogates/gaussian_process/presets/edbo.py index b2870d4e46..7a3f569220 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo.py +++ b/baybe/surrogates/gaussian_process/presets/edbo.py @@ -113,8 +113,8 @@ def _make( ) -EDBOMeanFactory = LazyConstantMeanFactory -"""A factory providing mean functions for the EDBO preset.""" +class EDBOMeanFactory(LazyConstantMeanFactory): + """A factory providing mean functions for the EDBO preset.""" @define diff --git a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py index 8915e63db5..59133f8cc9 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py +++ b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py @@ -83,8 +83,9 @@ def _make( and interpolates the prior moments linearly in between. """ # noqa: E501 -SmoothedEDBOMeanFactory = LazyConstantMeanFactory -"""A factory providing mean functions for the smoothed EDBO preset.""" + +class SmoothedEDBOMeanFactory(LazyConstantMeanFactory): + """A factory providing mean functions for the smoothed EDBO preset.""" @define From 1358b3a3dcefb28c64093b93ba2dd1ddc62ac4a7 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 7 May 2026 18:42:48 +0200 Subject: [PATCH 5/9] Add serialization roundtrip tests for kernel factories --- tests/test_kernel_factories.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_kernel_factories.py b/tests/test_kernel_factories.py index 45f3ccc109..6c87e63c9f 100644 --- a/tests/test_kernel_factories.py +++ b/tests/test_kernel_factories.py @@ -13,11 +13,17 @@ NumericalDiscreteParameter, ) from baybe.searchspace.core import SearchSpace +from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate from baybe.surrogates.gaussian_process.presets.baybe import ( BayBEKernelFactory, _BayBENumericalKernelFactory, _BayBETaskKernelFactory, ) +from baybe.surrogates.gaussian_process.presets.chen import CHENKernelFactory +from baybe.surrogates.gaussian_process.presets.edbo import EDBOKernelFactory +from baybe.surrogates.gaussian_process.presets.edbo_smoothed import ( + SmoothedEDBOKernelFactory, +) # A selector that accepts all parameters _SELECT_ALL = lambda parameter: True # noqa: E731 @@ -73,3 +79,21 @@ def test_factory_parameter_kind_validation(factory, parameters, error): else pytest.raises(error, match="does not support") ): factory(ss, train_x, train_y) + + +@pytest.mark.parametrize( + "factory", + [ + param(BayBEKernelFactory(), id="BayBEKernelFactory"), + param(SmoothedEDBOKernelFactory(), id="SmoothedEDBOKernelFactory"), + param(EDBOKernelFactory(), id="EDBOKernelFactory"), + param(CHENKernelFactory(), id="CHENKernelFactory"), + ], +) +def test_kernel_factory_serialization_roundtrip(factory): + """Kernel factories survive a serialization roundtrip via a GP surrogate.""" + gp = GaussianProcessSurrogate(kernel_or_factory=factory) + json_str = gp.to_json() + gp_roundtrip = GaussianProcessSurrogate.from_json(json_str) + assert type(gp.kernel_factory) is type(gp_roundtrip.kernel_factory) + assert gp == gp_roundtrip From a0fe0e31e02dbb138096574d5fe7ae979be7f640 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 7 May 2026 20:09:43 +0200 Subject: [PATCH 6/9] Fix serialization of transfer-learning-decorated kernel factories When used as a decorator (@_enable_transfer_learning), modify the class in-place instead of creating a subclass with the same __name__. The previous approach left two concrete classes with identical names in the subclass registry, causing find_subclass to resolve to the @define- processed intermediate (without the TL wrapper) during deserialization. When called with an explicit name argument (for cases like SmoothedEDBOKernelFactory where the original class is reused elsewhere), the subclass approach is preserved since the distinct name avoids any collision. --- .../gaussian_process/components/kernel.py | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index b0fd698277..29012903f5 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -138,11 +138,15 @@ def _enable_transfer_learning( automatically composes its kernel with BayBE's default task kernel. Otherwise, the factory behaves unchanged. + When used as a decorator (without ``name``), the class is modified in-place. + When called with a ``name`` argument, a new subclass is created so that the + original class remains unmodified. The latter form is intended for cases where + the original class is reused independently elsewhere. + Args: cls: The kernel factory class to decorate. - name: Optional name for the created class. Defaults to ``cls.__name__``. - Useful when calling the function directly (as opposed to using it as a - decorator) and assigning the result to a different name. + name: Optional name for the created class. If provided, a new subclass is + created instead of modifying ``cls`` in-place. Raises: TypeError: If the factory already supports task parameters. @@ -153,8 +157,14 @@ def _enable_transfer_learning( if cls._supported_parameter_kinds & _ParameterKind.TASK: raise TypeError(f"'{cls.__name__}' already supports task parameters.") - # Create a subclass so the original class remains unmodified - new_cls = type(name or cls.__name__, (cls,), {"__doc__": cls.__doc__}) + # This distinction is important for serialization so that the classes can be + # correctly identified by their names in the subclass registry + if name is not None: + # Create a subclass so the original class remains unmodified + target_cls = type(name, (cls,), {"__doc__": cls.__doc__}) + else: + # Modify the class in-place (avoids name collision in subclass registry) + target_cls = cls original_call = cls.__call__ original_supported_kinds = cls._supported_parameter_kinds @@ -165,8 +175,8 @@ def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor): # Temporarily narrow the supported parameter kinds to those of the original # class. If the decorator logic is correct, the original factory should never # see the extended scope, but this acts as a sanity check to prevent regressions - broadened_kinds = new_cls._supported_parameter_kinds - new_cls._supported_parameter_kinds = original_supported_kinds + broadened_kinds = target_cls._supported_parameter_kinds + target_cls._supported_parameter_kinds = original_supported_kinds # Split off the task parameters original_selector = self.parameter_selector @@ -179,7 +189,7 @@ def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor): try: base_kernel = original_call(self, searchspace, train_x, train_y) finally: - new_cls._supported_parameter_kinds = broadened_kinds + target_cls._supported_parameter_kinds = broadened_kinds self.parameter_selector = original_selector if searchspace.task_idx is not None: @@ -187,11 +197,11 @@ def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor): return icm(searchspace, train_x, train_y) return base_kernel - new_cls.__call__ = __call__ # type: ignore[method-assign] - new_cls._supported_parameter_kinds = ( + target_cls.__call__ = __call__ # type: ignore[method-assign] + target_cls._supported_parameter_kinds = ( cls._supported_parameter_kinds | _ParameterKind.TASK ) - return new_cls + return target_cls @define From 3b719fb27597a84cc32dbdc56bcdd60aaabdcc8e Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 7 May 2026 20:16:05 +0200 Subject: [PATCH 7/9] Add SerialMixin to _PureKernelFactory and move serialization test --- .../gaussian_process/components/kernel.py | 3 ++- .../test_kernel_factory_serialization.py | 23 ++++++++++++++++++ tests/test_kernel_factories.py | 24 ------------------- 3 files changed, 25 insertions(+), 25 deletions(-) create mode 100644 tests/serialization/test_kernel_factory_serialization.py diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index 29012903f5..0b951e5e5c 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -23,6 +23,7 @@ to_parameter_selector, ) from baybe.searchspace.core import SearchSpace +from baybe.serialization.mixin import SerialMixin from baybe.surrogates.gaussian_process.components.generic import ( GPComponentFactoryProtocol, GPComponentType, @@ -45,7 +46,7 @@ @define -class _PureKernelFactory(KernelFactoryProtocol, ABC): +class _PureKernelFactory(KernelFactoryProtocol, SerialMixin, ABC): """Base class for pure kernel factories.""" # For internal use only: sanity check mechanism to remind developers of new diff --git a/tests/serialization/test_kernel_factory_serialization.py b/tests/serialization/test_kernel_factory_serialization.py new file mode 100644 index 0000000000..0bdc20af93 --- /dev/null +++ b/tests/serialization/test_kernel_factory_serialization.py @@ -0,0 +1,23 @@ +"""Kernel factory serialization tests.""" + +import pytest + +from baybe.surrogates.gaussian_process.components.kernel import _PureKernelFactory +from baybe.surrogates.gaussian_process.presets import * # noqa: F401, F403 +from baybe.utils.basic import get_subclasses +from tests.serialization.utils import assert_roundtrip_consistency + +_KERNEL_FACTORIES = [ + cls + for cls in get_subclasses(_PureKernelFactory) + if not cls.__name__.startswith("_") +] + + +@pytest.mark.parametrize( + "factory", + [pytest.param(cls(), id=cls.__name__) for cls in _KERNEL_FACTORIES], +) +def test_roundtrip(factory): + """A serialization roundtrip yields an equivalent object.""" + assert_roundtrip_consistency(factory) diff --git a/tests/test_kernel_factories.py b/tests/test_kernel_factories.py index 6c87e63c9f..45f3ccc109 100644 --- a/tests/test_kernel_factories.py +++ b/tests/test_kernel_factories.py @@ -13,17 +13,11 @@ NumericalDiscreteParameter, ) from baybe.searchspace.core import SearchSpace -from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate from baybe.surrogates.gaussian_process.presets.baybe import ( BayBEKernelFactory, _BayBENumericalKernelFactory, _BayBETaskKernelFactory, ) -from baybe.surrogates.gaussian_process.presets.chen import CHENKernelFactory -from baybe.surrogates.gaussian_process.presets.edbo import EDBOKernelFactory -from baybe.surrogates.gaussian_process.presets.edbo_smoothed import ( - SmoothedEDBOKernelFactory, -) # A selector that accepts all parameters _SELECT_ALL = lambda parameter: True # noqa: E731 @@ -79,21 +73,3 @@ def test_factory_parameter_kind_validation(factory, parameters, error): else pytest.raises(error, match="does not support") ): factory(ss, train_x, train_y) - - -@pytest.mark.parametrize( - "factory", - [ - param(BayBEKernelFactory(), id="BayBEKernelFactory"), - param(SmoothedEDBOKernelFactory(), id="SmoothedEDBOKernelFactory"), - param(EDBOKernelFactory(), id="EDBOKernelFactory"), - param(CHENKernelFactory(), id="CHENKernelFactory"), - ], -) -def test_kernel_factory_serialization_roundtrip(factory): - """Kernel factories survive a serialization roundtrip via a GP surrogate.""" - gp = GaussianProcessSurrogate(kernel_or_factory=factory) - json_str = gp.to_json() - gp_roundtrip = GaussianProcessSurrogate.from_json(json_str) - assert type(gp.kernel_factory) is type(gp_roundtrip.kernel_factory) - assert gp == gp_roundtrip From fc6268c1da4d4339a98abf983f00d1e8a23e3e88 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 7 May 2026 20:29:35 +0200 Subject: [PATCH 8/9] Fix __module__ on dynamically-created kernel factory subclass The Protocol metaclass (_ProtocolMeta) defaults __module__ to 'abc' when creating classes via 3-arg type(). Set it explicitly from the parent class so that SmoothedEDBOKernelFactory correctly reports its module as baybe.surrogates.gaussian_process.presets.edbo_smoothed. --- baybe/surrogates/gaussian_process/components/kernel.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index 0b951e5e5c..9e8ceb1f56 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -161,8 +161,12 @@ def _enable_transfer_learning( # This distinction is important for serialization so that the classes can be # correctly identified by their names in the subclass registry if name is not None: - # Create a subclass so the original class remains unmodified - target_cls = type(name, (cls,), {"__doc__": cls.__doc__}) + # Create a subclass so the original class remains unmodified. + # __module__ must be set explicitly because the Protocol metaclass + # would otherwise default it to "abc". + target_cls = type( + name, (cls,), {"__doc__": cls.__doc__, "__module__": cls.__module__} + ) else: # Modify the class in-place (avoids name collision in subclass registry) target_cls = cls From 5229f7f32593c8555a5f58e824d00ac8f9e6ca58 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 7 May 2026 20:35:30 +0200 Subject: [PATCH 9/9] Suppress mypy errors from dynamic class creation in _enable_transfer_learning --- baybe/surrogates/gaussian_process/components/kernel.py | 8 ++++---- baybe/surrogates/gaussian_process/presets/baybe.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index 9e8ceb1f56..6b2a52eb97 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -180,8 +180,8 @@ def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor): # Temporarily narrow the supported parameter kinds to those of the original # class. If the decorator logic is correct, the original factory should never # see the extended scope, but this acts as a sanity check to prevent regressions - broadened_kinds = target_cls._supported_parameter_kinds - target_cls._supported_parameter_kinds = original_supported_kinds + broadened_kinds = target_cls._supported_parameter_kinds # type: ignore[attr-defined] + target_cls._supported_parameter_kinds = original_supported_kinds # type: ignore[attr-defined] # Split off the task parameters original_selector = self.parameter_selector @@ -194,7 +194,7 @@ def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor): try: base_kernel = original_call(self, searchspace, train_x, train_y) finally: - target_cls._supported_parameter_kinds = broadened_kinds + target_cls._supported_parameter_kinds = broadened_kinds # type: ignore[attr-defined] self.parameter_selector = original_selector if searchspace.task_idx is not None: @@ -203,7 +203,7 @@ def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor): return base_kernel target_cls.__call__ = __call__ # type: ignore[method-assign] - target_cls._supported_parameter_kinds = ( + target_cls._supported_parameter_kinds = ( # type: ignore[attr-defined] cls._supported_parameter_kinds | _ParameterKind.TASK ) return target_cls diff --git a/baybe/surrogates/gaussian_process/presets/baybe.py b/baybe/surrogates/gaussian_process/presets/baybe.py index 484c3152da..f5cfe65941 100644 --- a/baybe/surrogates/gaussian_process/presets/baybe.py +++ b/baybe/surrogates/gaussian_process/presets/baybe.py @@ -33,7 +33,7 @@ class _BayBENumericalKernelFactory(_SmoothedEDBONumericalKernelFactory): """The default numerical kernel factory for GP surrogates.""" -class BayBEKernelFactory(SmoothedEDBOKernelFactory): +class BayBEKernelFactory(SmoothedEDBOKernelFactory): # type: ignore[valid-type, misc] """The default kernel factory for GP surrogates."""