diff --git a/CHANGELOG.md b/CHANGELOG.md index de3165f616..3979ae4723 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Gaussian process component factories - Support for GPyTorch objects (kernels, means, likelihood) as Gaussian process components, enabling full low-level customization +- Configurable fitting criterion for Gaussian process hyperparameter optimization - Factories for all Gaussian process components - `CHEN`, `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate` - `TypeSelector` and `NameSelector` classes for parameter selection in kernel factories diff --git a/baybe/surrogates/gaussian_process/components/__init__.py b/baybe/surrogates/gaussian_process/components/__init__.py index a9e11f4afe..bb83f995e4 100644 --- a/baybe/surrogates/gaussian_process/components/__init__.py +++ b/baybe/surrogates/gaussian_process/components/__init__.py @@ -1,5 +1,10 @@ """Gaussian process surrogate components.""" +from baybe.surrogates.gaussian_process.components.fit_criterion import ( + FitCriterion, + FitCriterionFactoryProtocol, + PlainFitCriterionFactory, +) from baybe.surrogates.gaussian_process.components.kernel import ( KernelFactoryProtocol, PlainKernelFactory, @@ -15,6 +20,10 @@ ) __all__ = [ + # Fit Criterion + "FitCriterion", + "FitCriterionFactoryProtocol", + "PlainFitCriterionFactory", # Kernel "KernelFactoryProtocol", "PlainKernelFactory", diff --git a/baybe/surrogates/gaussian_process/components/fit_criterion.py b/baybe/surrogates/gaussian_process/components/fit_criterion.py new file mode 100644 index 0000000000..f3d3b364ed --- /dev/null +++ b/baybe/surrogates/gaussian_process/components/fit_criterion.py @@ -0,0 +1,74 @@ +"""Fitting criteria for the Gaussian process surrogate.""" + +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING + +from attrs import define +from typing_extensions import override + +if TYPE_CHECKING: + from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood + from gpytorch.mlls import MarginalLogLikelihood + from gpytorch.models import GP as GPyTorchModel + from torch import Tensor + + from baybe.searchspace.core import SearchSpace + + +class FitCriterion(Enum): + """Available fitting criteria for GP hyperparameter optimization.""" + + MARGINAL_LOG_LIKELIHOOD = "MARGINAL_LOG_LIKELIHOOD" + """Exact marginal log-likelihood.""" + + LEAVE_ONE_OUT_PSEUDOLIKELIHOOD = "LEAVE_ONE_OUT_PSEUDOLIKELIHOOD" + """Leave-one-out cross-validation pseudo-likelihood.""" + + def to_gpytorch( + self, likelihood: GPyTorchLikelihood, model: GPyTorchModel + ) -> MarginalLogLikelihood: + """Create the corresponding GPyTorch MLL object.""" + import gpytorch + + mll_class = { + FitCriterion.MARGINAL_LOG_LIKELIHOOD: gpytorch.ExactMarginalLogLikelihood, + FitCriterion.LEAVE_ONE_OUT_PSEUDOLIKELIHOOD: gpytorch.mlls.LeaveOneOutPseudoLikelihood, # noqa: E501 + }[self] + return mll_class(likelihood, model) + + +# Delayed import to avoid circular dependency +from baybe.surrogates.gaussian_process.components.generic import ( # noqa: E402 + GPComponentFactoryProtocol, + PlainGPComponentFactory, +) + +FitCriterionFactoryProtocol = GPComponentFactoryProtocol[FitCriterion] +"""A protocol defining the interface for fit criterion factories.""" + +PlainFitCriterionFactory = PlainGPComponentFactory[FitCriterion] +"""A trivial factory that returns a fixed fit criterion.""" + + +@define +class _MLLForNonTLFitCriterionFactory(FitCriterionFactoryProtocol): + """A fit criterion factory switching between MLL and BayBE default. + + In transfer learning contexts, delegates to + :class:`baybe.surrogates.gaussian_process.presets.baybe.BayBEFitCriterionFactory`. + """ + + @override + def __call__( + self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor + ) -> FitCriterion: + if searchspace.task_idx is None: + return FitCriterion.MARGINAL_LOG_LIKELIHOOD + + from baybe.surrogates.gaussian_process.presets.baybe import ( + BayBEFitCriterionFactory, + ) + + return BayBEFitCriterionFactory()(searchspace, train_x, train_y) diff --git a/baybe/surrogates/gaussian_process/components/generic.py b/baybe/surrogates/gaussian_process/components/generic.py index c1977146e5..52ed9a66cd 100644 --- a/baybe/surrogates/gaussian_process/components/generic.py +++ b/baybe/surrogates/gaussian_process/components/generic.py @@ -14,8 +14,9 @@ from baybe.searchspace import SearchSpace from baybe.serialization.core import block_serialization_hook, converter from baybe.serialization.mixin import SerialMixin +from baybe.surrogates.gaussian_process.components.fit_criterion import FitCriterion -BayBEGPComponent: TypeAlias = Kernel +BayBEGPComponent: TypeAlias = Kernel | FitCriterion if TYPE_CHECKING: from gpytorch.kernels import Kernel as GPyTorchKernel @@ -44,15 +45,24 @@ class GPComponentType(Enum): LIKELIHOOD = "LIKELIHOOD" """Gaussian process likelihood.""" + CRITERION = "CRITERION" + """Gaussian process fitting criterion.""" + def get_types(self) -> tuple[type, ...]: """Get the accepted BayBE and GPyTorch types for this component.""" - types = [] + types: list[type[GPComponent]] = [] # Add BayBE type if applicable if self is GPComponentType.KERNEL: from baybe.kernels.base import Kernel types.append(Kernel) + elif self is GPComponentType.CRITERION: + from baybe.surrogates.gaussian_process.components.fit_criterion import ( + FitCriterion, + ) + + types.append(FitCriterion) # Add GPyTorch type if available if sys.modules.get("gpytorch") is not None: @@ -85,7 +95,7 @@ def _is_gpytorch_component_class(obj: Any, /) -> bool: def _validate_component(instance: Any, attribute: Attribute, value: Any) -> None: """Validate that an object is a BayBE or a GPyTorch GP component.""" - if isinstance(value, Kernel) or _is_gpytorch_component_class(type(value)): + if isinstance(value, BayBEGPComponent) or _is_gpytorch_component_class(type(value)): return raise TypeError( diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index 2b1a3f361e..3a8d8c9adc 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -19,6 +19,10 @@ from baybe.parameters.categorical import TaskParameter from baybe.searchspace.core import SearchSpace from baybe.surrogates.base import Surrogate +from baybe.surrogates.gaussian_process.components.fit_criterion import ( + FitCriterion, + FitCriterionFactoryProtocol, +) from baybe.surrogates.gaussian_process.components.generic import ( GPComponentType, to_component_factory, @@ -35,6 +39,7 @@ GaussianProcessPreset, ) from baybe.surrogates.gaussian_process.presets.baybe import ( + BayBEFitCriterionFactory, BayBEKernelFactory, BayBELikelihoodFactory, BayBEMeanFactory, @@ -178,6 +183,21 @@ class GaussianProcessSurrogate(Surrogate): * :class:`gpytorch.likelihoods.Likelihood` """ + criterion_factory: FitCriterionFactoryProtocol = field( + alias="criterion_or_factory", + factory=BayBEFitCriterionFactory, + converter=partial( # type: ignore[misc] + to_component_factory, component_type=GPComponentType.CRITERION + ), + validator=is_callable(), + ) + """The fitting criterion for Gaussian process hyperparameter optimization. + + Accepts: + * :class:`.components.fit_criterion.FitCriterion` + * :class:`.components.fit_criterion.FitCriterionFactoryProtocol` + """ + # TODO: type should be Optional[botorch.models.SingleTaskGP] but is currently # omitted due to: https://github.com/python-attrs/cattrs/issues/531 _model = field(init=False, default=None, eq=False) @@ -195,6 +215,7 @@ def from_preset( likelihood_or_factory: LikelihoodFactoryProtocol | GPyTorchLikelihood | None = None, + criterion_or_factory: FitCriterion | FitCriterionFactoryProtocol | None = None, ) -> Self: """Create a Gaussian process surrogate from one of the defined presets.""" preset = GaussianProcessPreset(preset) @@ -204,13 +225,12 @@ def from_preset( ) module = importlib.import_module(module_name) - kernel = kernel_or_factory or getattr(module, "PresetKernelFactory")() - mean = mean_or_factory or getattr(module, "PresetMeanFactory")() - likelihood = ( - likelihood_or_factory or getattr(module, "PresetLikelihoodFactory")() - ) + kernel = kernel_or_factory or getattr(module, "KERNEL_FACTORY") + mean = mean_or_factory or getattr(module, "MEAN_FACTORY") + likelihood = likelihood_or_factory or getattr(module, "LIKELIHOOD_FACTORY") + criterion = criterion_or_factory or getattr(module, "FIT_CRITERION_FACTORY") - return cls(kernel, mean, likelihood) + return cls(kernel, mean, likelihood, criterion) @override def to_botorch(self) -> GPyTorchModel: @@ -237,7 +257,6 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior: @override def _fit(self, train_x: Tensor, train_y: Tensor) -> None: import botorch - import gpytorch from botorch.models.transforms import Normalize, Standardize assert self._searchspace is not None # provided by base class @@ -281,6 +300,9 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None: ### Likelihood likelihood = self.likelihood_factory(context.searchspace, train_x, train_y) + ### Criterion + criterion = self.criterion_factory(context.searchspace, train_x, train_y) + ### Model construction and fitting self._model = botorch.models.SingleTaskGP( train_x, @@ -291,18 +313,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None: covar_module=kernel, likelihood=likelihood, ) - - # TODO: This is still a temporary workaround to avoid overfitting seen in - # low-dimensional TL cases. More robust settings are being researched. - if context.n_task_dimensions > 0: - mll = gpytorch.mlls.LeaveOneOutPseudoLikelihood( - self._model.likelihood, self._model - ) - else: - mll = gpytorch.ExactMarginalLogLikelihood( - self._model.likelihood, self._model - ) - + mll = criterion.to_gpytorch(self._model.likelihood, self._model) botorch.fit.fit_gpytorch_mll(mll) @override @@ -311,6 +322,9 @@ def __str__(self) -> str: to_string("Kernel factory", self.kernel_factory, single_line=True), to_string("Mean factory", self.mean_factory, single_line=True), to_string("Likelihood factory", self.likelihood_factory, single_line=True), + to_string( + "Fit criterion factory", self.criterion_factory, single_line=True + ), ] return to_string(super().__str__(), *fields) diff --git a/baybe/surrogates/gaussian_process/presets/__init__.py b/baybe/surrogates/gaussian_process/presets/__init__.py index 434fbf560f..97d73aaae6 100644 --- a/baybe/surrogates/gaussian_process/presets/__init__.py +++ b/baybe/surrogates/gaussian_process/presets/__init__.py @@ -1,20 +1,28 @@ """Gaussian process surrogate presets.""" +# Criterion +from baybe.surrogates.gaussian_process.components.fit_criterion import FitCriterion + # Default preset from baybe.surrogates.gaussian_process.presets.baybe import ( + BayBEFitCriterionFactory, BayBEKernelFactory, BayBELikelihoodFactory, BayBEMeanFactory, ) # Chen preset -from baybe.surrogates.gaussian_process.presets.chen import CHENKernelFactory +from baybe.surrogates.gaussian_process.presets.chen import ( + CHENFitCriterionFactory, + CHENKernelFactory, +) # Core from baybe.surrogates.gaussian_process.presets.core import GaussianProcessPreset # EDBO preset from baybe.surrogates.gaussian_process.presets.edbo import ( + EDBOFitCriterionFactory, EDBOKernelFactory, EDBOLikelihoodFactory, EDBOMeanFactory, @@ -22,6 +30,7 @@ # Smoothed EDBO preset from baybe.surrogates.gaussian_process.presets.edbo_smoothed import ( + SmoothedEDBOFitCriterionFactory, SmoothedEDBOKernelFactory, SmoothedEDBOLikelihoodFactory, SmoothedEDBOMeanFactory, @@ -29,18 +38,23 @@ __all__ = [ # Core + "FitCriterion", "GaussianProcessPreset", # Default BayBE preset + "BayBEFitCriterionFactory", "BayBEKernelFactory", "BayBELikelihoodFactory", "BayBEMeanFactory", # Chen preset + "CHENFitCriterionFactory", "CHENKernelFactory", # EDBO preset + "EDBOFitCriterionFactory", "EDBOKernelFactory", "EDBOLikelihoodFactory", "EDBOMeanFactory", # Smoothed EDBO preset + "SmoothedEDBOFitCriterionFactory", "SmoothedEDBOKernelFactory", "SmoothedEDBOLikelihoodFactory", "SmoothedEDBOMeanFactory", diff --git a/baybe/surrogates/gaussian_process/presets/baybe.py b/baybe/surrogates/gaussian_process/presets/baybe.py index 690fc318cd..0cf6686064 100644 --- a/baybe/surrogates/gaussian_process/presets/baybe.py +++ b/baybe/surrogates/gaussian_process/presets/baybe.py @@ -17,6 +17,10 @@ to_parameter_selector, ) from baybe.searchspace.core import SearchSpace +from baybe.surrogates.gaussian_process.components.fit_criterion import ( + FitCriterion, + FitCriterionFactoryProtocol, +) from baybe.surrogates.gaussian_process.components.kernel import _PureKernelFactory from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory from baybe.surrogates.gaussian_process.presets.edbo_smoothed import ( @@ -85,7 +89,24 @@ def _make( BayBELikelihoodFactory = SmoothedEDBOLikelihoodFactory """The factory providing the default likelihood for Gaussian process surrogates.""" -# Aliases for generic preset imports -PresetKernelFactory = BayBEKernelFactory -PresetMeanFactory = BayBEMeanFactory -PresetLikelihoodFactory = BayBELikelihoodFactory + +@define +class BayBEFitCriterionFactory(FitCriterionFactoryProtocol): + """The factory providing the default fitting criterion for Gaussian process surrogates.""" # noqa: E501 + + @override + def __call__( + self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor + ) -> FitCriterion: + return ( + FitCriterion.MARGINAL_LOG_LIKELIHOOD + if searchspace.task_idx is None + else FitCriterion.LEAVE_ONE_OUT_PSEUDOLIKELIHOOD + ) + + +# Preset defaults +KERNEL_FACTORY = BayBEKernelFactory() +MEAN_FACTORY = BayBEMeanFactory() +LIKELIHOOD_FACTORY = BayBELikelihoodFactory() +FIT_CRITERION_FACTORY = BayBEFitCriterionFactory() diff --git a/baybe/surrogates/gaussian_process/presets/chen.py b/baybe/surrogates/gaussian_process/presets/chen.py index 5e90c4aa5c..e461bc0750 100644 --- a/baybe/surrogates/gaussian_process/presets/chen.py +++ b/baybe/surrogates/gaussian_process/presets/chen.py @@ -18,6 +18,9 @@ to_parameter_selector, ) from baybe.priors.basic import GammaPrior +from baybe.surrogates.gaussian_process.components.fit_criterion import ( + _MLLForNonTLFitCriterionFactory, +) from baybe.surrogates.gaussian_process.components.kernel import ( _PureKernelFactory, ) @@ -68,10 +71,14 @@ def _make( ) +CHENFitCriterionFactory = _MLLForNonTLFitCriterionFactory() +"""A factory providing fitting criteria for the CHEN preset.""" + # Collect leftover original slotted classes processed by `attrs.define` gc.collect() -# Aliases for generic preset imports -PresetKernelFactory = CHENKernelFactory -PresetMeanFactory = LazyConstantMeanFactory -PresetLikelihoodFactory = LazyGaussianLikelihoodFactory +# Preset defaults +KERNEL_FACTORY = CHENKernelFactory() +MEAN_FACTORY = LazyConstantMeanFactory() +LIKELIHOOD_FACTORY = LazyGaussianLikelihoodFactory() +FIT_CRITERION_FACTORY = CHENFitCriterionFactory diff --git a/baybe/surrogates/gaussian_process/presets/edbo.py b/baybe/surrogates/gaussian_process/presets/edbo.py index 539e2ef72c..1ff8d2bf80 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo.py +++ b/baybe/surrogates/gaussian_process/presets/edbo.py @@ -21,6 +21,9 @@ from baybe.parameters.substance import SubstanceParameter from baybe.priors.basic import GammaPrior from baybe.searchspace.discrete import SubspaceDiscrete +from baybe.surrogates.gaussian_process.components.fit_criterion import ( + _MLLForNonTLFitCriterionFactory, +) from baybe.surrogates.gaussian_process.components.kernel import ( _PureKernelFactory, ) @@ -175,10 +178,14 @@ def __call__( return likelihood +EDBOFitCriterionFactory = _MLLForNonTLFitCriterionFactory() +"""A factory providing fitting criteria for the EDBO preset.""" + # Collect leftover original slotted classes processed by `attrs.define` gc.collect() -# Aliases for generic preset imports -PresetKernelFactory = EDBOKernelFactory -PresetMeanFactory = EDBOMeanFactory -PresetLikelihoodFactory = EDBOLikelihoodFactory +# Preset defaults +KERNEL_FACTORY = EDBOKernelFactory() +MEAN_FACTORY = EDBOMeanFactory() +LIKELIHOOD_FACTORY = EDBOLikelihoodFactory() +FIT_CRITERION_FACTORY = EDBOFitCriterionFactory diff --git a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py index 904f711f31..519713f9dc 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py +++ b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py @@ -18,6 +18,9 @@ to_parameter_selector, ) from baybe.priors.basic import GammaPrior +from baybe.surrogates.gaussian_process.components.fit_criterion import ( + _MLLForNonTLFitCriterionFactory, +) from baybe.surrogates.gaussian_process.components.kernel import ( _PureKernelFactory, ) @@ -126,10 +129,14 @@ def __call__( return likelihood +SmoothedEDBOFitCriterionFactory = _MLLForNonTLFitCriterionFactory() +"""A factory providing fitting criteria for the smoothed EDBO preset.""" + # Collect leftover original slotted classes processed by `attrs.define` gc.collect() -# Aliases for generic preset imports -PresetKernelFactory = SmoothedEDBOKernelFactory -PresetMeanFactory = SmoothedEDBOMeanFactory -PresetLikelihoodFactory = SmoothedEDBOLikelihoodFactory +# Preset defaults +KERNEL_FACTORY = SmoothedEDBOKernelFactory() +MEAN_FACTORY = SmoothedEDBOMeanFactory() +LIKELIHOOD_FACTORY = SmoothedEDBOLikelihoodFactory() +FIT_CRITERION_FACTORY = SmoothedEDBOFitCriterionFactory diff --git a/tests/test_gp.py b/tests/test_gp.py index 30ca69bc0e..e5671984dd 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -14,6 +14,7 @@ from baybe.kernels.basic import MaternKernel, RBFKernel from baybe.kernels.composite import ScaleKernel from baybe.parameters.numerical import NumericalContinuousParameter +from baybe.surrogates.gaussian_process.components.fit_criterion import FitCriterion from baybe.surrogates.gaussian_process.components.generic import PlainGPComponentFactory from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate from baybe.surrogates.gaussian_process.presets import GaussianProcessPreset @@ -89,24 +90,32 @@ def test_presets(preset: GaussianProcessPreset): kernel = GPyTorchMaternKernel() mean = ConstantMean() likelihood = GaussianLikelihood() + criterion = FitCriterion.LEAVE_ONE_OUT_PSEUDOLIKELIHOOD # Works without overrides ... - GaussianProcessSurrogate.from_preset(preset) + gp1 = GaussianProcessSurrogate.from_preset(preset) # ... and with overrides - gp = GaussianProcessSurrogate.from_preset( + gp2 = GaussianProcessSurrogate.from_preset( preset, kernel_or_factory=kernel, mean_or_factory=mean, likelihood_or_factory=likelihood, + criterion_or_factory=criterion, ) - assert isinstance(gp.kernel_factory, PlainGPComponentFactory) - assert gp.kernel_factory.component is kernel - assert isinstance(gp.mean_factory, PlainGPComponentFactory) - assert gp.mean_factory.component is mean - assert isinstance(gp.likelihood_factory, PlainGPComponentFactory) - assert gp.likelihood_factory.component is likelihood - gp.fit(searchspace, objective, measurements) + + # Check that the overrides were applied correctly + assert isinstance(gp2.kernel_factory, PlainGPComponentFactory) + assert gp2.kernel_factory.component is kernel + assert isinstance(gp2.mean_factory, PlainGPComponentFactory) + assert gp2.mean_factory.component is mean + assert isinstance(gp2.likelihood_factory, PlainGPComponentFactory) + assert gp2.likelihood_factory.component is likelihood + assert isinstance(gp2.criterion_factory, PlainGPComponentFactory) + assert gp2.criterion_factory.component == criterion + assert gp2.criterion_factory != gp1.criterion_factory + + gp2.fit(searchspace, objective, measurements) def test_invalid_components(): @@ -116,4 +125,8 @@ def test_invalid_components(): with pytest.raises(TypeError, match="Component must be one of"): GaussianProcessSurrogate(mean_or_factory=GaussianLikelihood()) with pytest.raises(TypeError, match="Component must be one of"): - GaussianProcessSurrogate(likelihood_or_factory=MaternKernel()) + GaussianProcessSurrogate( + likelihood_or_factory=FitCriterion.LEAVE_ONE_OUT_PSEUDOLIKELIHOOD + ) + with pytest.raises(TypeError, match="Component must be one of"): + GaussianProcessSurrogate(criterion_or_factory=MaternKernel())