Skip to content
Open
57 changes: 53 additions & 4 deletions baybe/surrogates/gaussian_process/components/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import functools
from abc import ABC, abstractmethod
from collections.abc import Iterable
from functools import partial
Expand Down Expand Up @@ -119,6 +120,47 @@ def _make(
"""Construct the kernel."""


def _enable_transfer_learning(
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Scienfitz: in principle ready and working. However, I have to admit that this was significantly more painful than anticipated, with many footguns along the way. So I'm open to a very harsh review and a complete change of direction, if you prefer and have an alternative/simpler idea.

But I hope that you get my intent for this: I think we need some mechanism that lets us say fill this preset with our default approach for a certain aspect that the preset does not specify, and the filling should be very much done without copying code since the BayBE defaults are expected to move. So we need something like a single source of truth. That said: maybe you have some smarter idea.

cls: type[_PureKernelFactory], /
) -> type[_PureKernelFactory]:
"""Class decorator enabling BayBE's default transfer learning mechanism.

When the search space contains a task parameter, the decorated factory
automatically composes its kernel with BayBE's default task kernel.
Otherwise, the factory behaves unchanged.

Args:
cls: The kernel factory class to decorate.

Raises:
TypeError: If the factory already supports task parameters.

Returns:
The decorated kernel factory class with transfer learning enabled.
"""
if cls._supported_parameter_kinds & _ParameterKind.TASK:
raise TypeError(f"'{cls.__name__}' already supports task parameters.")

# Create a subclass so the original class remains unmodified
new_cls = type(cls.__name__, (cls,), {"__doc__": cls.__doc__})

original_call = cls.__call__

@functools.wraps(original_call)
def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor):
base_kernel = original_call(self, searchspace, train_x, train_y)
if searchspace.task_idx is not None:
icm = ICMKernelFactory(base_kernel_or_factory=base_kernel)
return icm(searchspace, train_x, train_y)
return base_kernel
Comment thread
AdrianSosic marked this conversation as resolved.

new_cls.__call__ = __call__ # type: ignore[method-assign]
new_cls._supported_parameter_kinds = (
cls._supported_parameter_kinds | _ParameterKind.TASK
)
return new_cls


@define
class _MetaKernelFactory(KernelFactoryProtocol, ABC):
"""Base class for meta kernel factories that orchestrate other kernel factories."""
Expand Down Expand Up @@ -154,18 +196,25 @@ class ICMKernelFactory(_MetaKernelFactory):
@base_kernel_factory.default
def _default_base_kernel_factory(self) -> KernelFactoryProtocol:
from baybe.surrogates.gaussian_process.presets.baybe import (
BayBENumericalKernelFactory,
_BayBENumericalKernelFactory,
)

return BayBENumericalKernelFactory(TypeSelector((TaskParameter,), exclude=True))
assert (
_BayBENumericalKernelFactory._supported_parameter_kinds
is _ParameterKind.REGULAR
)
return _BayBENumericalKernelFactory(
TypeSelector((TaskParameter,), exclude=True)
)

@task_kernel_factory.default
def _default_task_kernel_factory(self) -> KernelFactoryProtocol:
from baybe.surrogates.gaussian_process.presets.baybe import (
BayBETaskKernelFactory,
_BayBETaskKernelFactory,
)

return BayBETaskKernelFactory(TypeSelector((TaskParameter,)))
assert _BayBETaskKernelFactory._supported_parameter_kinds is _ParameterKind.TASK
return _BayBETaskKernelFactory()

@override
def __call__(
Expand Down
27 changes: 6 additions & 21 deletions baybe/surrogates/gaussian_process/presets/baybe.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,23 @@
from baybe.surrogates.gaussian_process.presets.edbo_smoothed import (
SmoothedEDBOKernelFactory,
SmoothedEDBOLikelihoodFactory,
_SmoothedEDBONumericalKernelFactory,
)

if TYPE_CHECKING:
from torch import Tensor


@define
class BayBEKernelFactory(_PureKernelFactory):
"""The default kernel factory for Gaussian process surrogates."""

_supported_parameter_kinds: ClassVar[_ParameterKind] = (
_ParameterKind.REGULAR | _ParameterKind.TASK
)
# See base class.

@override
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
from baybe.surrogates.gaussian_process.components.kernel import ICMKernelFactory

is_multitask = searchspace.task_idx is not None
factory = ICMKernelFactory if is_multitask else BayBENumericalKernelFactory
return factory()(searchspace, train_x, train_y)
_BayBENumericalKernelFactory = _SmoothedEDBONumericalKernelFactory
"""The factory providing the default numerical kernel for Gaussian process surrogates.""" # noqa: E501


BayBENumericalKernelFactory = SmoothedEDBOKernelFactory
"""The factory providing the default numerical kernel for Gaussian process surrogates.""" # noqa: E501
BayBEKernelFactory = SmoothedEDBOKernelFactory
"""The default kernel factory for Gaussian process surrogates."""


Comment thread
AdrianSosic marked this conversation as resolved.
Outdated
@define
class BayBETaskKernelFactory(_PureKernelFactory):
class _BayBETaskKernelFactory(_PureKernelFactory):
"""The factory providing the default task kernel for Gaussian process surrogates."""

_uses_parameter_names: ClassVar[bool] = True
Expand Down
16 changes: 3 additions & 13 deletions baybe/surrogates/gaussian_process/presets/chen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,14 @@
import math
from typing import TYPE_CHECKING, ClassVar

from attrs import define, field
from attrs import define
from typing_extensions import override

from baybe.kernels.basic import MaternKernel
from baybe.kernels.composite import ScaleKernel
from baybe.parameters.categorical import TaskParameter
from baybe.parameters.selectors import (
ParameterSelectorProtocol,
TypeSelector,
to_parameter_selector,
)
from baybe.priors.basic import GammaPrior
from baybe.surrogates.gaussian_process.components.kernel import (
_enable_transfer_learning,
_PureKernelFactory,
)
from baybe.surrogates.gaussian_process.components.likelihood import (
Expand All @@ -33,19 +28,14 @@
from baybe.searchspace.core import SearchSpace


@_enable_transfer_learning
@define
class CHENKernelFactory(_PureKernelFactory):
"""A factory providing adaptive hyperprior kernels as proposed by :cite:p:`Chen2026`.""" # noqa: E501

_uses_parameter_names: ClassVar[bool] = True
# See base class.
Comment thread
AdrianSosic marked this conversation as resolved.

parameter_selector: ParameterSelectorProtocol | None = field(
factory=lambda: TypeSelector([TaskParameter], exclude=True),
converter=to_parameter_selector,
)
# TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429)

@override
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
Expand Down
15 changes: 3 additions & 12 deletions baybe/surrogates/gaussian_process/presets/edbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,18 @@
from collections.abc import Collection
from typing import TYPE_CHECKING, ClassVar

from attrs import define, field
from attrs import define
from typing_extensions import override

from baybe.kernels.basic import MaternKernel
from baybe.kernels.composite import ScaleKernel
from baybe.parameters import TaskParameter
from baybe.parameters.enum import SubstanceEncoding
from baybe.parameters.selectors import (
ParameterSelectorProtocol,
TypeSelector,
to_parameter_selector,
)
from baybe.parameters.substance import SubstanceParameter
from baybe.priors.basic import GammaPrior
from baybe.searchspace.discrete import SubspaceDiscrete
from baybe.surrogates.gaussian_process.components.kernel import (
_enable_transfer_learning,
_PureKernelFactory,
)
from baybe.surrogates.gaussian_process.components.likelihood import (
Expand Down Expand Up @@ -56,6 +52,7 @@ def _contains_encoding(
"""Encodings relevant to EDBO logic."""


@_enable_transfer_learning
@define
class EDBOKernelFactory(_PureKernelFactory):
"""A factory providing EDBO kernels, as proposed by :cite:p:`Shields2021`.
Expand All @@ -67,12 +64,6 @@ class EDBOKernelFactory(_PureKernelFactory):
_uses_parameter_names: ClassVar[bool] = True
# See base class.

Comment thread
AdrianSosic marked this conversation as resolved.
parameter_selector: ParameterSelectorProtocol | None = field(
factory=lambda: TypeSelector([TaskParameter], exclude=True),
converter=to_parameter_selector,
)
# TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429)

@override
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
Expand Down
33 changes: 14 additions & 19 deletions baybe/surrogates/gaussian_process/presets/edbo_smoothed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,15 @@
from typing import TYPE_CHECKING, ClassVar

import numpy as np
from attrs import define, field
from attrs import define
from typing_extensions import override

from baybe.kernels.basic import MaternKernel
from baybe.kernels.composite import ScaleKernel
from baybe.parameters import TaskParameter
from baybe.parameters.selectors import (
ParameterSelectorProtocol,
TypeSelector,
to_parameter_selector,
)
from baybe.priors.basic import GammaPrior
from baybe.surrogates.gaussian_process.components.kernel import (
_enable_transfer_learning,
_PureKernelFactory,
)
from baybe.surrogates.gaussian_process.components.likelihood import (
Expand All @@ -39,23 +35,12 @@


@define
class SmoothedEDBOKernelFactory(_PureKernelFactory):
"""A factory providing smoothed versions of EDBO kernels (adapted from :cite:p:`Shields2021`).

Takes the low and high dimensional limits of
:class:`baybe.surrogates.gaussian_process.presets.edbo.EDBOKernelFactory`
and interpolates the prior moments linearly in between.
""" # noqa: E501
class _SmoothedEDBONumericalKernelFactory(_PureKernelFactory):
"""A factory providing the core numerical kernel for the smoothed EDBO preset."""

_uses_parameter_names: ClassVar[bool] = True
# See base class.

parameter_selector: ParameterSelectorProtocol | None = field(
factory=lambda: TypeSelector([TaskParameter], exclude=True),
converter=to_parameter_selector,
)
# TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429)

@override
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
Expand Down Expand Up @@ -88,6 +73,16 @@ def _make(
)


SmoothedEDBOKernelFactory = _enable_transfer_learning(
_SmoothedEDBONumericalKernelFactory
)
"""A factory providing smoothed versions of EDBO kernels (adapted from :cite:p:`Shields2021`).

Takes the low and high dimensional limits of
:class:`baybe.surrogates.gaussian_process.presets.edbo.EDBOKernelFactory`
and interpolates the prior moments linearly in between.
""" # noqa: E501

Comment thread
AdrianSosic marked this conversation as resolved.
SmoothedEDBOMeanFactory = LazyConstantMeanFactory
"""A factory providing mean functions for the smoothed EDBO preset."""

Expand Down
12 changes: 6 additions & 6 deletions tests/test_kernel_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from baybe.searchspace.core import SearchSpace
from baybe.surrogates.gaussian_process.presets.baybe import (
BayBEKernelFactory,
BayBENumericalKernelFactory,
BayBETaskKernelFactory,
_BayBENumericalKernelFactory,
_BayBETaskKernelFactory,
)

# A selector that accepts all parameters
Expand All @@ -27,25 +27,25 @@
("factory", "parameters", "error"),
[
param(
BayBENumericalKernelFactory(parameter_selector=_SELECT_ALL),
_BayBENumericalKernelFactory(parameter_selector=_SELECT_ALL),
[TaskParameter("task", ["t1", "t2"])],
IncompatibleSearchSpaceError,
id="regular_rejects_task",
),
Comment thread
AdrianSosic marked this conversation as resolved.
param(
BayBETaskKernelFactory(parameter_selector=_SELECT_ALL),
_BayBETaskKernelFactory(parameter_selector=_SELECT_ALL),
[CategoricalParameter("cat", ["a", "b"])],
IncompatibleSearchSpaceError,
id="task_rejects_categorical",
),
param(
BayBETaskKernelFactory(parameter_selector=_SELECT_ALL),
_BayBETaskKernelFactory(parameter_selector=_SELECT_ALL),
[NumericalDiscreteParameter("num", [1, 2, 3])],
IncompatibleSearchSpaceError,
id="task_rejects_numerical_discrete",
),
param(
BayBETaskKernelFactory(parameter_selector=_SELECT_ALL),
_BayBETaskKernelFactory(parameter_selector=_SELECT_ALL),
[NumericalContinuousParameter("cont", (0, 1))],
IncompatibleSearchSpaceError,
id="task_rejects_numerical_continuous",
Expand Down