Skip to content
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Gaussian process component factories
- Support for GPyTorch objects (kernels, means, likelihood) as Gaussian process
components, enabling full low-level customization
- Configurable fitting criterion for Gaussian process hyperparameter optimization
- Factories for all Gaussian process components
- `CHEN`, `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
- `TypeSelector` and `NameSelector` classes for parameter selection in kernel factories
Expand Down
9 changes: 9 additions & 0 deletions baybe/surrogates/gaussian_process/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""Gaussian process surrogate components."""

from baybe.surrogates.gaussian_process.components.fit_criterion import (
FitCriterion,
FitCriterionFactoryProtocol,
PlainFitCriterionFactory,
)
from baybe.surrogates.gaussian_process.components.kernel import (
KernelFactoryProtocol,
PlainKernelFactory,
Expand All @@ -15,6 +20,10 @@
)

__all__ = [
# Fit Criterion
"FitCriterion",
"FitCriterionFactoryProtocol",
"PlainFitCriterionFactory",
# Kernel
"KernelFactoryProtocol",
"PlainKernelFactory",
Expand Down
46 changes: 46 additions & 0 deletions baybe/surrogates/gaussian_process/components/fit_criterion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Fitting criteria for the Gaussian process surrogate."""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
from gpytorch.mlls import MarginalLogLikelihood
from gpytorch.models import GP as GPyTorchModel


class FitCriterion(Enum):
"""Available fitting criteria for GP hyperparameter optimization."""

MARGINAL_LOG_LIKELIHOOD = "MARGINAL_LOG_LIKELIHOOD"
Comment thread
AdrianSosic marked this conversation as resolved.
"""Exact marginal log-likelihood."""

LEAVE_ONE_OUT_PSEUDOLIKELIHOOD = "LEAVE_ONE_OUT_PSEUDOLIKELIHOOD"
"""Leave-one-out cross-validation pseudo-likelihood."""

def to_gpytorch(
self, likelihood: GPyTorchLikelihood, model: GPyTorchModel
) -> MarginalLogLikelihood:
"""Create the corresponding GPyTorch MLL object."""
import gpytorch

mll_class = {
FitCriterion.MARGINAL_LOG_LIKELIHOOD: gpytorch.ExactMarginalLogLikelihood,
FitCriterion.LEAVE_ONE_OUT_PSEUDOLIKELIHOOD: gpytorch.mlls.LeaveOneOutPseudoLikelihood, # noqa: E501
}[self]
return mll_class(likelihood, model)


# Delayed import to avoid circular dependency
from baybe.surrogates.gaussian_process.components.generic import ( # noqa: E402
GPComponentFactoryProtocol,
PlainGPComponentFactory,
)

FitCriterionFactoryProtocol = GPComponentFactoryProtocol[FitCriterion]
"""A protocol defining the interface for fit criterion factories."""

PlainFitCriterionFactory = PlainGPComponentFactory[FitCriterion]
"""A trivial factory that returns a fixed fit criterion."""
16 changes: 13 additions & 3 deletions baybe/surrogates/gaussian_process/components/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from baybe.searchspace import SearchSpace
from baybe.serialization.core import block_serialization_hook, converter
from baybe.serialization.mixin import SerialMixin
from baybe.surrogates.gaussian_process.components.fit_criterion import FitCriterion

BayBEGPComponent: TypeAlias = Kernel
BayBEGPComponent: TypeAlias = Kernel | FitCriterion

if TYPE_CHECKING:
from gpytorch.kernels import Kernel as GPyTorchKernel
Expand Down Expand Up @@ -44,15 +45,24 @@ class GPComponentType(Enum):
LIKELIHOOD = "LIKELIHOOD"
"""Gaussian process likelihood."""

CRITERION = "CRITERION"
"""Gaussian process fitting criterion."""

def get_types(self) -> tuple[type, ...]:
"""Get the accepted BayBE and GPyTorch types for this component."""
types = []
types: list[type[GPComponent]] = []

# Add BayBE type if applicable
if self is GPComponentType.KERNEL:
from baybe.kernels.base import Kernel

types.append(Kernel)
elif self is GPComponentType.CRITERION:
from baybe.surrogates.gaussian_process.components.fit_criterion import (
FitCriterion,
)

types.append(FitCriterion)

# Add GPyTorch type if available
if sys.modules.get("gpytorch") is not None:
Expand Down Expand Up @@ -85,7 +95,7 @@ def _is_gpytorch_component_class(obj: Any, /) -> bool:

def _validate_component(instance: Any, attribute: Attribute, value: Any) -> None:
"""Validate that an object is a BayBE or a GPyTorch GP component."""
if isinstance(value, Kernel) or _is_gpytorch_component_class(type(value)):
if isinstance(value, BayBEGPComponent) or _is_gpytorch_component_class(type(value)):
return

raise TypeError(
Expand Down
54 changes: 36 additions & 18 deletions baybe/surrogates/gaussian_process/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
from baybe.parameters.categorical import TaskParameter
from baybe.searchspace.core import SearchSpace
from baybe.surrogates.base import Surrogate
from baybe.surrogates.gaussian_process.components.fit_criterion import (
FitCriterion,
FitCriterionFactoryProtocol,
)
from baybe.surrogates.gaussian_process.components.generic import (
GPComponentType,
to_component_factory,
Expand All @@ -35,6 +39,7 @@
GaussianProcessPreset,
)
from baybe.surrogates.gaussian_process.presets.baybe import (
BayBEFitCriterionFactory,
BayBEKernelFactory,
BayBELikelihoodFactory,
BayBEMeanFactory,
Expand Down Expand Up @@ -178,6 +183,21 @@ class GaussianProcessSurrogate(Surrogate):
* :class:`gpytorch.likelihoods.Likelihood`
"""

criterion_factory: FitCriterionFactoryProtocol = field(
alias="criterion_or_factory",
factory=BayBEFitCriterionFactory,
converter=partial( # type: ignore[misc]
to_component_factory, component_type=GPComponentType.CRITERION
),
validator=is_callable(),
)
"""The fitting criterion for Gaussian process hyperparameter optimization.

Accepts:
* :class:`.components.fit_criterion.FitCriterion`
* :class:`.components.fit_criterion.FitCriterionFactoryProtocol`
"""

# TODO: type should be Optional[botorch.models.SingleTaskGP] but is currently
# omitted due to: https://github.com/python-attrs/cattrs/issues/531
_model = field(init=False, default=None, eq=False)
Expand All @@ -195,6 +215,7 @@ def from_preset(
likelihood_or_factory: LikelihoodFactoryProtocol
| GPyTorchLikelihood
| None = None,
criterion_or_factory: FitCriterion | FitCriterionFactoryProtocol | None = None,
) -> Self:
"""Create a Gaussian process surrogate from one of the defined presets."""
preset = GaussianProcessPreset(preset)
Expand All @@ -204,13 +225,16 @@ def from_preset(
)
module = importlib.import_module(module_name)

kernel = kernel_or_factory or getattr(module, "PresetKernelFactory")()
mean = mean_or_factory or getattr(module, "PresetMeanFactory")()
likelihood = (
likelihood_or_factory or getattr(module, "PresetLikelihoodFactory")()
kernel = kernel_or_factory or getattr(module, "PRESET_KERNEL_FACTORY")
mean = mean_or_factory or getattr(module, "PRESET_MEAN_FACTORY")
likelihood = likelihood_or_factory or getattr(
module, "PRESET_LIKELIHOOD_FACTORY"
)
criterion = criterion_or_factory or getattr(
module, "PRESET_FIT_CRITERION_FACTORY"
)

return cls(kernel, mean, likelihood)
return cls(kernel, mean, likelihood, criterion)

@override
def to_botorch(self) -> GPyTorchModel:
Expand All @@ -237,7 +261,6 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior:
@override
def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
import botorch
import gpytorch
from botorch.models.transforms import Normalize, Standardize

assert self._searchspace is not None # provided by base class
Expand Down Expand Up @@ -281,6 +304,9 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
### Likelihood
likelihood = self.likelihood_factory(context.searchspace, train_x, train_y)

### Criterion
criterion = self.criterion_factory(context.searchspace, train_x, train_y)

### Model construction and fitting
self._model = botorch.models.SingleTaskGP(
train_x,
Expand All @@ -291,18 +317,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
covar_module=kernel,
likelihood=likelihood,
)

# TODO: This is still a temporary workaround to avoid overfitting seen in
# low-dimensional TL cases. More robust settings are being researched.
if context.n_task_dimensions > 0:
mll = gpytorch.mlls.LeaveOneOutPseudoLikelihood(
self._model.likelihood, self._model
)
else:
mll = gpytorch.ExactMarginalLogLikelihood(
self._model.likelihood, self._model
)

mll = criterion.to_gpytorch(self._model.likelihood, self._model)
botorch.fit.fit_gpytorch_mll(mll)

@override
Expand All @@ -311,6 +326,9 @@ def __str__(self) -> str:
to_string("Kernel factory", self.kernel_factory, single_line=True),
to_string("Mean factory", self.mean_factory, single_line=True),
to_string("Likelihood factory", self.likelihood_factory, single_line=True),
to_string(
"Fit criterion factory", self.criterion_factory, single_line=True
),
]
return to_string(super().__str__(), *fields)

Expand Down
16 changes: 15 additions & 1 deletion baybe/surrogates/gaussian_process/presets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,60 @@
"""Gaussian process surrogate presets."""

# Criterion
from baybe.surrogates.gaussian_process.components.fit_criterion import FitCriterion

# Default preset
from baybe.surrogates.gaussian_process.presets.baybe import (
BayBEFitCriterionFactory,
BayBEKernelFactory,
BayBELikelihoodFactory,
BayBEMeanFactory,
)

# Chen preset
from baybe.surrogates.gaussian_process.presets.chen import CHENKernelFactory
from baybe.surrogates.gaussian_process.presets.chen import (
CHENFitCriterionFactory,
CHENKernelFactory,
)

# Core
from baybe.surrogates.gaussian_process.presets.core import GaussianProcessPreset

# EDBO preset
from baybe.surrogates.gaussian_process.presets.edbo import (
EDBOFitCriterionFactory,
EDBOKernelFactory,
EDBOLikelihoodFactory,
EDBOMeanFactory,
)

# Smoothed EDBO preset
from baybe.surrogates.gaussian_process.presets.edbo_smoothed import (
SmoothedEDBOFitCriterionFactory,
SmoothedEDBOKernelFactory,
SmoothedEDBOLikelihoodFactory,
SmoothedEDBOMeanFactory,
)

__all__ = [
# Core
"FitCriterion",
"GaussianProcessPreset",
# Default BayBE preset
"BayBEFitCriterionFactory",
"BayBEKernelFactory",
"BayBELikelihoodFactory",
"BayBEMeanFactory",
# Chen preset
"CHENFitCriterionFactory",
"CHENKernelFactory",
# EDBO preset
"EDBOFitCriterionFactory",
"EDBOKernelFactory",
"EDBOLikelihoodFactory",
"EDBOMeanFactory",
# Smoothed EDBO preset
"SmoothedEDBOFitCriterionFactory",
"SmoothedEDBOKernelFactory",
"SmoothedEDBOLikelihoodFactory",
"SmoothedEDBOMeanFactory",
Expand Down
29 changes: 25 additions & 4 deletions baybe/surrogates/gaussian_process/presets/baybe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
to_parameter_selector,
)
from baybe.searchspace.core import SearchSpace
from baybe.surrogates.gaussian_process.components.fit_criterion import (
FitCriterion,
FitCriterionFactoryProtocol,
)
from baybe.surrogates.gaussian_process.components.kernel import _PureKernelFactory
from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory
from baybe.surrogates.gaussian_process.presets.edbo_smoothed import (
Expand Down Expand Up @@ -85,7 +89,24 @@ def _make(
BayBELikelihoodFactory = SmoothedEDBOLikelihoodFactory
"""The factory providing the default likelihood for Gaussian process surrogates."""

# Aliases for generic preset imports
PresetKernelFactory = BayBEKernelFactory
PresetMeanFactory = BayBEMeanFactory
PresetLikelihoodFactory = BayBELikelihoodFactory

@define
class BayBEFitCriterionFactory(FitCriterionFactoryProtocol):
"""The factory providing the default fitting criterion for Gaussian process surrogates.""" # noqa: E501

@override
def __call__(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> FitCriterion:
return (
FitCriterion.MARGINAL_LOG_LIKELIHOOD
if searchspace.task_idx is None
else FitCriterion.LEAVE_ONE_OUT_PSEUDOLIKELIHOOD
)


# Preset defaults
PRESET_KERNEL_FACTORY = BayBEKernelFactory()
PRESET_MEAN_FACTORY = BayBEMeanFactory()
PRESET_LIKELIHOOD_FACTORY = BayBELikelihoodFactory()
PRESET_FIT_CRITERION_FACTORY = BayBEFitCriterionFactory()
16 changes: 12 additions & 4 deletions baybe/surrogates/gaussian_process/presets/chen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
to_parameter_selector,
)
from baybe.priors.basic import GammaPrior
from baybe.surrogates.gaussian_process.components.fit_criterion import (
FitCriterion,
PlainFitCriterionFactory,
)
from baybe.surrogates.gaussian_process.components.kernel import (
_PureKernelFactory,
)
Expand Down Expand Up @@ -68,10 +72,14 @@ def _make(
)


CHENFitCriterionFactory = PlainFitCriterionFactory(FitCriterion.MARGINAL_LOG_LIKELIHOOD)
"""A factory providing fitting criteria for the CHEN preset."""

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()

# Aliases for generic preset imports
PresetKernelFactory = CHENKernelFactory
PresetMeanFactory = LazyConstantMeanFactory
PresetLikelihoodFactory = LazyGaussianLikelihoodFactory
# Preset defaults
PRESET_KERNEL_FACTORY = CHENKernelFactory()
PRESET_MEAN_FACTORY = LazyConstantMeanFactory()
PRESET_LIKELIHOOD_FACTORY = LazyGaussianLikelihoodFactory()
PRESET_FIT_CRITERION_FACTORY = CHENFitCriterionFactory
Loading
Loading