Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Comment thread
Scienfitz marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
- `TypeSelector` and `NameSelector` classes for parameter selection in kernel factories
- `parameter_names` attribute to basic kernels for controlling the considered parameters
- `ParameterKind` flag enum for classifying parameters by their role and automatic
parameter kind validation in kernel factories
- `IndexKernel` and `PositiveIndexKernel` classes
- Interpoint constraints for continuous search spaces
Expand Down
8 changes: 8 additions & 0 deletions baybe/parameters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from baybe.utils.metadata import MeasurableMetadata, to_metadata

if TYPE_CHECKING:
from baybe.parameters.enum import _ParameterKind
from baybe.searchspace.continuous import SubspaceContinuous
from baybe.searchspace.core import SearchSpace
from baybe.searchspace.discrete import SubspaceDiscrete
Expand Down Expand Up @@ -77,6 +78,13 @@ def is_discrete(self) -> bool:
"""Boolean indicating if this is a discrete parameter."""
return isinstance(self, DiscreteParameter)

@property
def _kind(self) -> _ParameterKind:
    """The :class:`_ParameterKind` classification of this parameter."""
    # Deferred import: the enum module refers back to this module at
    # type-check time, so importing eagerly here would create a cycle.
    from baybe.parameters.enum import _ParameterKind as kind_enum

    return kind_enum.from_parameter(self)

@property
@abstractmethod
def comp_rep_columns(self) -> tuple[str, ...]:
Expand Down
34 changes: 33 additions & 1 deletion baybe/parameters/enum.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,38 @@
"""Parameter-related enumerations."""

from enum import Enum
from __future__ import annotations

from enum import Enum, Flag, auto
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from baybe.parameters.base import Parameter


class _ParameterKind(Flag):
"""Flag enum encoding the kind of a parameter.

Can be used to express compatibility (e.g. Gaussian process kernel factories)
with different parameter types via bitwise combination of flags.
Comment thread
AVHopp marked this conversation as resolved.
"""

REGULAR = auto()
"""Regular parameter undergoing no special treatment."""

TASK = auto()
"""Task parameter for transfer learning."""

FIDELITY = auto()
"""Fidelity parameter for multi-fidelity modelling."""

@staticmethod
def from_parameter(parameter: Parameter) -> _ParameterKind:
"""Determine the kind of a parameter from its type."""
from baybe.parameters.categorical import TaskParameter

if isinstance(parameter, TaskParameter):
return _ParameterKind.TASK
return _ParameterKind.REGULAR


class ParameterEncoding(Enum):
Expand Down
38 changes: 1 addition & 37 deletions baybe/parameters/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
import re
from abc import ABC, abstractmethod
from collections.abc import Collection
from typing import ClassVar, Protocol
from typing import Protocol

from attrs import Converter, define, field
from attrs.converters import optional
from attrs.validators import deep_iterable, instance_of, min_len
from typing_extensions import override

from baybe.parameters.base import Parameter
from baybe.searchspace.core import SearchSpace
from baybe.utils.basic import to_tuple
from baybe.utils.conversion import nonstring_to_tuple

Expand Down Expand Up @@ -131,37 +129,3 @@ def to_parameter_selector(
return TypeSelector(items)

raise TypeError(f"Cannot convert {x!r} to a parameter selector.")


@define
class _ParameterSelectorMixin:
    """A mixin class to enable parameter selection."""

    # For internal use only: sanity check mechanism to remind developers of new
    # subclasses to actually use the parameter selector when it is provided
    # TODO: Perhaps we can find a more elegant way to enforce this by design
    _uses_parameter_names: ClassVar[bool] = False

    parameter_selector: ParameterSelectorProtocol | None = field(
        default=None, converter=optional(to_parameter_selector), kw_only=True
    )
    """An optional selector to specify which parameters are to be considered."""

    def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...] | None:
        """Get the names of the parameters to be considered.

        Args:
            searchspace: The search space providing the candidate parameters.

        Returns:
            The names of the parameters accepted by the selector, or ``None``
            if no selector is configured (meaning: consider all parameters).
        """
        if self.parameter_selector is None:
            return None

        return tuple(
            p.name for p in searchspace.parameters if self.parameter_selector(p)
        )

    def __attrs_post_init__(self):
        """Verify that subclasses receiving a selector have opted in via the flag."""
        # Fail fast at construction time: a selector that is silently ignored
        # by a subclass would be a subtle correctness bug for the caller.
        if self.parameter_selector is not None and not self._uses_parameter_names:
            raise AssertionError(
                f"A `parameter_selector` was provided to "
                f"`{type(self).__name__}`, but the class does not set "
                f"`_uses_parameter_names = True`. Subclasses that accept a "
                f"parameter selector must explicitly set this flag to confirm "
                f"they actually use the selected parameter names."
            )
107 changes: 103 additions & 4 deletions baybe/surrogates/gaussian_process/components/kernel.py
Comment thread
Scienfitz marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable
from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar

from attrs import define, field
from attrs.converters import optional
from attrs.validators import is_callable
from typing_extensions import override

from baybe.exceptions import IncompatibleSearchSpaceError
from baybe.kernels.base import Kernel
from baybe.kernels.composite import ProductKernel
from baybe.parameters.categorical import TaskParameter
from baybe.parameters.enum import _ParameterKind
from baybe.parameters.selectors import (
ParameterSelectorProtocol,
TypeSelector,
to_parameter_selector,
)
from baybe.searchspace.core import SearchSpace
from baybe.surrogates.gaussian_process.components.generic import (
Expand All @@ -27,6 +33,8 @@
from gpytorch.kernels import Kernel as GPyTorchKernel
from torch import Tensor

from baybe.parameters.base import Parameter

KernelFactoryProtocol = GPComponentFactoryProtocol[Kernel | GPyTorchKernel]
PlainKernelFactory = PlainGPComponentFactory[Kernel | GPyTorchKernel]
else:
Expand All @@ -36,7 +44,94 @@


@define
class _PureKernelFactory(KernelFactoryProtocol, ABC):
    """Base class for pure kernel factories.

    A pure factory constructs a kernel directly from (a subset of) the search
    space parameters and validates upfront that every considered parameter is
    of a kind supported by the factory.
    """

    # For internal use only: sanity check mechanism to remind developers of new
    # factories to actually use the parameter selector when it is provided
    # TODO: Perhaps we can find a more elegant way to enforce this by design
    _uses_parameter_names: ClassVar[bool] = False

    _supported_parameter_kinds: ClassVar[_ParameterKind] = _ParameterKind.REGULAR
    """The parameter kinds supported by the kernel factory."""

    parameter_selector: ParameterSelectorProtocol | None = field(
        default=None, converter=optional(to_parameter_selector)
    )
    """An optional selector to specify which parameters are considered by the kernel."""

    def __attrs_post_init__(self):
        """Verify that subclasses receiving a selector have opted in via the flag."""
        # Fail fast at construction time: a selector that is silently ignored
        # by a subclass would be a subtle correctness bug for the caller.
        if self.parameter_selector is not None and not self._uses_parameter_names:
            raise AssertionError(
                f"A `parameter_selector` was provided to "
                f"`{type(self).__name__}`, but the class does not set "
                f"`_uses_parameter_names = True`. Subclasses that accept a "
                f"parameter selector must explicitly set this flag to confirm "
                f"they actually use the selected parameter names."
            )

    def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...] | None:
        """Get the names of the parameters to be considered by the kernel.

        Args:
            searchspace: The search space providing the candidate parameters.

        Returns:
            The names of the selected parameters, or ``None`` if no selector
            is configured (meaning: consider all parameters).
        """
        if self.parameter_selector is None:
            return None

        return tuple(
            p.name for p in searchspace.parameters if self.parameter_selector(p)
        )

    def _validate_parameter_kinds(self, parameters: Iterable[Parameter]) -> None:
        """Validate that the given parameters are supported by the factory.

        Args:
            parameters: The parameters to validate.

        Raises:
            IncompatibleSearchSpaceError: If unsupported parameter kinds are found.
        """
        if unsupported := [
            p.name
            for p in parameters
            if not (p._kind & self._supported_parameter_kinds)
        ]:
            raise IncompatibleSearchSpaceError(
                f"'{type(self).__name__}' does not support parameter kind(s) for "
                f"parameter(s) {unsupported}. Supported kinds: "
                f"{self._supported_parameter_kinds}."
            )

    @override
    def __call__(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> Kernel:
        """Construct the kernel, validating parameter kinds before construction."""
        # Validation is restricted to the parameters the kernel will actually
        # consider: all of them when no selector is configured.
        if self.parameter_selector is not None:
            params = [p for p in searchspace.parameters if self.parameter_selector(p)]
        else:
            params = list(searchspace.parameters)
        self._validate_parameter_kinds(params)

        return self._make(searchspace, train_x, train_y)

    @abstractmethod
    def _make(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> Kernel:
        """Construct the kernel."""


@define
class _MetaKernelFactory(KernelFactoryProtocol, ABC):
    """Base class for meta kernel factories that orchestrate other kernel factories."""

    @override
    @abstractmethod
    def __call__(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> Kernel:
        """Construct the kernel by combining the orchestrated factories' outputs."""


@define
class ICMKernelFactory(_MetaKernelFactory):
"""A kernel factory that constructs an ICM kernel for transfer learning.

ICM: Intrinsic Coregionalization Model :cite:p:`NIPS2007_66368270`
Expand Down Expand Up @@ -78,4 +173,8 @@ def __call__(
) -> Kernel:
base_kernel = self.base_kernel_factory(searchspace, train_x, train_y)
task_kernel = self.task_kernel_factory(searchspace, train_x, train_y)
return ProductKernel([base_kernel, task_kernel])
if isinstance(base_kernel, Kernel):
base_kernel = base_kernel.to_gpytorch(searchspace)
if isinstance(task_kernel, Kernel):
task_kernel = task_kernel.to_gpytorch(searchspace)
return base_kernel * task_kernel
Comment thread
AdrianSosic marked this conversation as resolved.
20 changes: 14 additions & 6 deletions baybe/surrogates/gaussian_process/presets/baybe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from baybe.kernels.base import Kernel
from baybe.kernels.basic import IndexKernel
from baybe.parameters.categorical import TaskParameter
from baybe.parameters.enum import _ParameterKind
from baybe.parameters.selectors import (
ParameterSelectorProtocol,
TypeSelector,
_ParameterSelectorMixin,
to_parameter_selector,
)
from baybe.searchspace.core import SearchSpace
from baybe.surrogates.gaussian_process.components.kernel import KernelFactoryProtocol
from baybe.surrogates.gaussian_process.components.kernel import _PureKernelFactory
from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory
from baybe.surrogates.gaussian_process.presets.edbo_smoothed import (
SmoothedEDBOKernelFactory,
Expand All @@ -29,11 +29,16 @@


@define
class BayBEKernelFactory(KernelFactoryProtocol):
class BayBEKernelFactory(_PureKernelFactory):
"""The default kernel factory for Gaussian process surrogates."""

_supported_parameter_kinds: ClassVar[_ParameterKind] = (
_ParameterKind.REGULAR | _ParameterKind.TASK
)
# See base class.

@override
def __call__(
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
from baybe.surrogates.gaussian_process.components.kernel import ICMKernelFactory
Expand All @@ -48,20 +53,23 @@ def __call__(


@define
class BayBETaskKernelFactory(KernelFactoryProtocol, _ParameterSelectorMixin):
class BayBETaskKernelFactory(_PureKernelFactory):
"""The factory providing the default task kernel for Gaussian process surrogates."""

_uses_parameter_names: ClassVar[bool] = True
# See base class.

_supported_parameter_kinds: ClassVar[_ParameterKind] = _ParameterKind.TASK
# See base class.

parameter_selector: ParameterSelectorProtocol | None = field(
factory=lambda: TypeSelector([TaskParameter]),
converter=to_parameter_selector,
)
# TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429)

@override
def __call__(
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
return IndexKernel(
Expand Down
13 changes: 6 additions & 7 deletions baybe/surrogates/gaussian_process/presets/edbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
from baybe.parameters.selectors import (
ParameterSelectorProtocol,
TypeSelector,
_ParameterSelectorMixin,
to_parameter_selector,
)
from baybe.parameters.substance import SubstanceParameter
from baybe.priors.basic import GammaPrior
from baybe.searchspace.discrete import SubspaceDiscrete
from baybe.surrogates.gaussian_process.components.kernel import KernelFactoryProtocol
from baybe.surrogates.gaussian_process.components.kernel import (
_PureKernelFactory,
)
from baybe.surrogates.gaussian_process.components.likelihood import (
LikelihoodFactoryProtocol,
)
Expand Down Expand Up @@ -56,7 +57,7 @@ def _contains_encoding(


@define
class EDBOKernelFactory(KernelFactoryProtocol, _ParameterSelectorMixin):
class EDBOKernelFactory(_PureKernelFactory):
"""A factory providing EDBO kernels.

References:
Expand All @@ -74,12 +75,10 @@ class EDBOKernelFactory(KernelFactoryProtocol, _ParameterSelectorMixin):
# TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429)

@override
def __call__(
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
effective_dims = train_x.shape[-1] - len(
[p for p in searchspace.parameters if isinstance(p, TaskParameter)]
)
effective_dims = train_x.shape[-1]

switching_condition = _contains_encoding(
searchspace.discrete, _EDBO_ENCODINGS
Expand Down
Loading
Loading