Skip to content
Open
107 changes: 102 additions & 5 deletions baybe/surrogates/gaussian_process/components/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import functools
from abc import ABC, abstractmethod
from collections.abc import Iterable
from functools import partial
Expand All @@ -22,6 +23,7 @@
to_parameter_selector,
)
from baybe.searchspace.core import SearchSpace
from baybe.serialization.mixin import SerialMixin
from baybe.surrogates.gaussian_process.components.generic import (
GPComponentFactoryProtocol,
GPComponentType,
Expand All @@ -44,7 +46,7 @@


@define
class _PureKernelFactory(KernelFactoryProtocol, ABC):
class _PureKernelFactory(KernelFactoryProtocol, SerialMixin, ABC):
"""Base class for pure kernel factories."""

# For internal use only: sanity check mechanism to remind developers of new
Expand Down Expand Up @@ -79,6 +81,15 @@ def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...] | Non
p.name for p in searchspace.parameters if self.parameter_selector(p)
)

def _get_effective_dimensionality(self, searchspace: SearchSpace) -> int:
    """Count the computational-representation columns of the selected parameters.

    Args:
        searchspace: The search space whose computational columns are counted.

    Returns:
        The number of computational columns belonging to the parameters picked
        by the factory's selector, or the total number of computational columns
        when no selection is active.
    """
    selected = self.get_parameter_names(searchspace)
    if selected is None:
        # No selector configured: every computational column counts.
        return len(searchspace.comp_rep_columns)
    column_counts = (
        len(searchspace.get_comp_rep_parameter_indices(n)) for n in selected
    )
    return sum(column_counts)

def _validate_parameter_kinds(self, parameters: Iterable[Parameter]) -> None:
"""Validate that the given parameters are supported by the factory.

Expand Down Expand Up @@ -119,6 +130,85 @@ def _make(
"""Construct the kernel."""


def _enable_transfer_learning(
    cls: type[_PureKernelFactory], name: str | None = None, /
) -> type[_PureKernelFactory]:
    """Class decorator enabling BayBE's default transfer learning mechanism.

    If the search space handed to the decorated factory contains a task
    parameter, the produced kernel is automatically composed with BayBE's
    default task kernel via an ICM construction. Otherwise, the factory
    behaves exactly as before.

    Used as a plain decorator (``name`` omitted), ``cls`` itself is modified
    in-place. Called with a ``name``, a fresh subclass carrying that name is
    returned and ``cls`` stays untouched — intended for classes that are also
    reused independently elsewhere.

    Args:
        cls: The kernel factory class to decorate.
        name: Optional name for the created class. If provided, a new subclass
            is created instead of modifying ``cls`` in-place.

    Raises:
        TypeError: If the factory already supports task parameters.

    Returns:
        The decorated kernel factory class with transfer learning enabled.
    """
    if cls._supported_parameter_kinds & _ParameterKind.TASK:
        raise TypeError(f"'{cls.__name__}' already supports task parameters.")

    # The in-place vs. subclass distinction matters for serialization: classes
    # are identified by name in the subclass registry, so modifying in-place
    # avoids a name collision while subclassing yields a separately named entry.
    if name is None:
        decorated = cls
    else:
        # __module__ must be set explicitly because the Protocol metaclass
        # would otherwise default it to "abc".
        decorated = type(
            name, (cls,), {"__doc__": cls.__doc__, "__module__": cls.__module__}
        )

    wrapped_call = cls.__call__
    base_kinds = cls._supported_parameter_kinds
    drop_tasks = TypeSelector((TaskParameter,), exclude=True)

    @functools.wraps(wrapped_call)
    def _composed_call(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ):
        # Temporarily narrow the supported kinds back to the original scope.
        # The wrapped factory should never see the broadened scope anyway;
        # this acts as a sanity check guarding against regressions.
        widened_kinds = decorated._supported_parameter_kinds  # type: ignore[attr-defined]
        decorated._supported_parameter_kinds = base_kinds  # type: ignore[attr-defined]

        # Hide task parameters from the wrapped factory.
        user_selector = self.parameter_selector
        if user_selector is None:
            self.parameter_selector = drop_tasks
        else:
            self.parameter_selector = lambda p: drop_tasks(p) and user_selector(p)
        try:
            base_kernel = wrapped_call(self, searchspace, train_x, train_y)
        finally:
            decorated._supported_parameter_kinds = widened_kinds  # type: ignore[attr-defined]
            self.parameter_selector = user_selector

        if searchspace.task_idx is None:
            return base_kernel
        # Task parameter present: compose with the default task kernel.
        return ICMKernelFactory(base_kernel_or_factory=base_kernel)(
            searchspace, train_x, train_y
        )

    decorated.__call__ = _composed_call  # type: ignore[method-assign]
    decorated._supported_parameter_kinds = (  # type: ignore[attr-defined]
        base_kinds | _ParameterKind.TASK
    )
    return decorated


@define
class _MetaKernelFactory(KernelFactoryProtocol, ABC):
"""Base class for meta kernel factories that orchestrate other kernel factories."""
Expand Down Expand Up @@ -154,18 +244,25 @@ class ICMKernelFactory(_MetaKernelFactory):
@base_kernel_factory.default
def _default_base_kernel_factory(self) -> KernelFactoryProtocol:
    """Provide the default factory for the base (numerical) kernel.

    Returns:
        BayBE's default numerical kernel factory, restricted to non-task
        parameters.
    """
    # Local import to avoid a circular dependency with the presets module.
    from baybe.surrogates.gaussian_process.presets.baybe import (
        _BayBENumericalKernelFactory,
    )

    # Sanity check: the numerical default must not silently gain task support,
    # since task parameters are handled by the separate task kernel factory.
    assert (
        _BayBENumericalKernelFactory._supported_parameter_kinds
        is _ParameterKind.REGULAR
    )
    return _BayBENumericalKernelFactory(
        TypeSelector((TaskParameter,), exclude=True)
    )

@task_kernel_factory.default
def _default_task_kernel_factory(self) -> KernelFactoryProtocol:
    """Provide the default factory for the task kernel.

    Returns:
        BayBE's default task kernel factory.
    """
    # Local import to avoid a circular dependency with the presets module.
    from baybe.surrogates.gaussian_process.presets.baybe import (
        _BayBETaskKernelFactory,
    )

    # Sanity check: the task default must support exactly the task parameter
    # kind, so no explicit selector needs to be passed here.
    assert _BayBETaskKernelFactory._supported_parameter_kinds is _ParameterKind.TASK
    return _BayBETaskKernelFactory()

@override
def __call__(
Expand Down
39 changes: 13 additions & 26 deletions baybe/surrogates/gaussian_process/presets/baybe.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,39 +22,24 @@
from baybe.surrogates.gaussian_process.presets.edbo_smoothed import (
SmoothedEDBOKernelFactory,
SmoothedEDBOLikelihoodFactory,
_SmoothedEDBONumericalKernelFactory,
)

if TYPE_CHECKING:
from torch import Tensor


@define
class BayBEKernelFactory(_PureKernelFactory):
"""The default kernel factory for Gaussian process surrogates."""

_supported_parameter_kinds: ClassVar[_ParameterKind] = (
_ParameterKind.REGULAR | _ParameterKind.TASK
)
# See base class.

@override
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
from baybe.surrogates.gaussian_process.components.kernel import ICMKernelFactory
class _BayBENumericalKernelFactory(_SmoothedEDBONumericalKernelFactory):
    """The default numerical kernel factory for GP surrogates.

    NOTE(review): a subclass rather than an alias, presumably so the class is
    identifiable by its own name in the serialization subclass registry —
    confirm against the registry mechanism.
    """

is_multitask = searchspace.task_idx is not None
factory = ICMKernelFactory if is_multitask else BayBENumericalKernelFactory
return factory()(searchspace, train_x, train_y)


BayBENumericalKernelFactory = SmoothedEDBOKernelFactory
"""The factory providing the default numerical kernel for Gaussian process surrogates.""" # noqa: E501
class BayBEKernelFactory(SmoothedEDBOKernelFactory): # type: ignore[valid-type, misc]
    """The default kernel factory for GP surrogates.

    Behaves exactly like :class:`SmoothedEDBOKernelFactory`. NOTE(review): a
    subclass rather than an alias, presumably so the class is identifiable by
    its own name in the serialization subclass registry — confirm.
    """


@define
class BayBETaskKernelFactory(_PureKernelFactory):
"""The factory providing the default task kernel for Gaussian process surrogates."""
class _BayBETaskKernelFactory(_PureKernelFactory):
"""The default task kernel factory for GP surrogates."""

_uses_parameter_names: ClassVar[bool] = True
# See base class.
Expand All @@ -79,11 +64,13 @@ def _make(
)


BayBEMeanFactory = LazyConstantMeanFactory
"""The factory providing the default mean function for Gaussian process surrogates."""
class BayBEMeanFactory(LazyConstantMeanFactory):
    """The default mean factory for GP surrogates.

    Behaves exactly like :class:`LazyConstantMeanFactory`. NOTE(review): a
    subclass rather than an alias, presumably so the class is identifiable by
    its own name in the serialization subclass registry — confirm.
    """


class BayBELikelihoodFactory(SmoothedEDBOLikelihoodFactory):
    """The default likelihood factory for GP surrogates.

    Behaves exactly like :class:`SmoothedEDBOLikelihoodFactory`. NOTE(review):
    a subclass rather than an alias, presumably so the class is identifiable
    by its own name in the serialization subclass registry — confirm.
    """

BayBELikelihoodFactory = SmoothedEDBOLikelihoodFactory
"""The factory providing the default likelihood for Gaussian process surrogates."""

# Aliases for generic preset imports
PresetKernelFactory = BayBEKernelFactory
Expand Down
19 changes: 5 additions & 14 deletions baybe/surrogates/gaussian_process/presets/chen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,14 @@
import math
from typing import TYPE_CHECKING, ClassVar

from attrs import define, field
from attrs import define
from typing_extensions import override

from baybe.kernels.basic import MaternKernel
from baybe.kernels.composite import ScaleKernel
from baybe.parameters.categorical import TaskParameter
from baybe.parameters.selectors import (
ParameterSelectorProtocol,
TypeSelector,
to_parameter_selector,
)
from baybe.priors.basic import GammaPrior
from baybe.surrogates.gaussian_process.components.kernel import (
_enable_transfer_learning,
_PureKernelFactory,
)
from baybe.surrogates.gaussian_process.components.likelihood import (
Expand All @@ -33,24 +28,20 @@
from baybe.searchspace.core import SearchSpace


@_enable_transfer_learning
@define
class CHENKernelFactory(_PureKernelFactory):
"""A factory providing adaptive hyperprior kernels as proposed by :cite:p:`Chen2026`.""" # noqa: E501

_uses_parameter_names: ClassVar[bool] = True
# See base class.
Comment thread
AdrianSosic marked this conversation as resolved.

parameter_selector: ParameterSelectorProtocol | None = field(
factory=lambda: TypeSelector([TaskParameter], exclude=True),
converter=to_parameter_selector,
)
# TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429)

@override
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
lengthscale = 0.4 * math.sqrt(train_x.shape[-1]) + 4.0
n_dimensions = self._get_effective_dimensionality(searchspace)
lengthscale = 0.4 * math.sqrt(n_dimensions) + 4.0
lengthscale_prior = GammaPrior(2.0 * lengthscale, 2.0)
lengthscale_initial_value = lengthscale
outputscale_prior = GammaPrior(1.0 * lengthscale, 1.0)
Expand Down
30 changes: 11 additions & 19 deletions baybe/surrogates/gaussian_process/presets/edbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,17 @@
from collections.abc import Collection
from typing import TYPE_CHECKING, ClassVar

from attrs import define, field
from attrs import define
from typing_extensions import override

from baybe.kernels.basic import MaternKernel
from baybe.kernels.composite import ScaleKernel
from baybe.parameters import TaskParameter
from baybe.parameters.enum import SubstanceEncoding
from baybe.parameters.selectors import (
ParameterSelectorProtocol,
TypeSelector,
to_parameter_selector,
)
from baybe.parameters.enum import SubstanceEncoding, _ParameterKind
from baybe.parameters.substance import SubstanceParameter
from baybe.priors.basic import GammaPrior
from baybe.searchspace.discrete import SubspaceDiscrete
from baybe.surrogates.gaussian_process.components.kernel import (
_enable_transfer_learning,
_PureKernelFactory,
)
from baybe.surrogates.gaussian_process.components.likelihood import (
Expand Down Expand Up @@ -56,6 +51,7 @@ def _contains_encoding(
"""Encodings relevant to EDBO logic."""


@_enable_transfer_learning
@define
class EDBOKernelFactory(_PureKernelFactory):
"""A factory providing EDBO kernels, as proposed by :cite:p:`Shields2021`.
Expand All @@ -67,17 +63,11 @@ class EDBOKernelFactory(_PureKernelFactory):
_uses_parameter_names: ClassVar[bool] = True
# See base class.

Comment thread
AdrianSosic marked this conversation as resolved.
parameter_selector: ParameterSelectorProtocol | None = field(
factory=lambda: TypeSelector([TaskParameter], exclude=True),
converter=to_parameter_selector,
)
# TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429)

@override
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
effective_dims = train_x.shape[-1]
effective_dims = self._get_effective_dimensionality(searchspace)

switching_condition = _contains_encoding(
searchspace.discrete, _EDBO_ENCODINGS
Expand Down Expand Up @@ -123,8 +113,8 @@ def _make(
)


EDBOMeanFactory = LazyConstantMeanFactory
"""A factory providing mean functions for the EDBO preset."""
class EDBOMeanFactory(LazyConstantMeanFactory):
    """A factory providing mean functions for the EDBO preset.

    Behaves exactly like :class:`LazyConstantMeanFactory`. NOTE(review): a
    subclass rather than an alias, presumably so the class is identifiable by
    its own name in the serialization subclass registry — confirm.
    """


@define
Expand All @@ -142,8 +132,10 @@ def __call__(
import torch
from gpytorch.likelihoods import GaussianLikelihood

effective_dims = train_x.shape[-1] - len(
[p for p in searchspace.parameters if isinstance(p, TaskParameter)]
effective_dims = sum(
len(searchspace.get_comp_rep_parameter_indices(p.name))
for p in searchspace.parameters
if p._kind & _ParameterKind.REGULAR
)

switching_condition = _contains_encoding(
Expand Down
Loading
Loading