diff --git a/bofire/data_models/surrogates/api.py b/bofire/data_models/surrogates/api.py index 9344944b7..c3191ebe6 100644 --- a/bofire/data_models/surrogates/api.py +++ b/bofire/data_models/surrogates/api.py @@ -15,7 +15,10 @@ FullyBayesianSingleTaskGPSurrogate, ) from bofire.data_models.surrogates.linear import LinearSurrogate -from bofire.data_models.surrogates.map_saas import AdditiveMapSaasSingleTaskGPSurrogate +from bofire.data_models.surrogates.map_saas import ( + AdditiveMapSaasSingleTaskGPSurrogate, + EnsembleMapSaasSingleTaskGPSurrogate, +) from bofire.data_models.surrogates.mixed_single_task_gp import ( MixedSingleTaskGPHyperconfig, MixedSingleTaskGPSurrogate, @@ -72,6 +75,7 @@ SingleTaskIBNNSurrogate, PiecewiseLinearGPSurrogate, AdditiveMapSaasSingleTaskGPSurrogate, + EnsembleMapSaasSingleTaskGPSurrogate, ] AnyTrainableSurrogate = Union[ @@ -88,6 +92,7 @@ TanimotoGPSurrogate, PiecewiseLinearGPSurrogate, AdditiveMapSaasSingleTaskGPSurrogate, + EnsembleMapSaasSingleTaskGPSurrogate, ] AnyRegressionSurrogate = Union[ @@ -106,6 +111,7 @@ SingleTaskIBNNSurrogate, PiecewiseLinearGPSurrogate, AdditiveMapSaasSingleTaskGPSurrogate, + EnsembleMapSaasSingleTaskGPSurrogate, ] AnyClassificationSurrogate = ClassificationMLPEnsemble diff --git a/bofire/data_models/surrogates/botorch_surrogates.py b/bofire/data_models/surrogates/botorch_surrogates.py index d9263ae37..bffb016e1 100644 --- a/bofire/data_models/surrogates/botorch_surrogates.py +++ b/bofire/data_models/surrogates/botorch_surrogates.py @@ -14,7 +14,10 @@ FullyBayesianSingleTaskGPSurrogate, ) from bofire.data_models.surrogates.linear import LinearSurrogate -from bofire.data_models.surrogates.map_saas import AdditiveMapSaasSingleTaskGPSurrogate +from bofire.data_models.surrogates.map_saas import ( + AdditiveMapSaasSingleTaskGPSurrogate, + EnsembleMapSaasSingleTaskGPSurrogate, +) from bofire.data_models.surrogates.mixed_single_task_gp import ( MixedSingleTaskGPSurrogate, ) @@ -47,6 +50,7 @@ 
class EnsembleMapSaasSingleTaskGPSurrogate(TrainableBotorchSurrogate):
    """Data model for the ensemble MAP SAAS single-task GP surrogate.

    Describes a batched ensemble of ``SingleTaskGP``s with the Matern-5/2
    kernel and a SAAS prior.

    Attributes:
        type: Literal tag identifying this data model.
        n_taus (PositiveInt): Number of taus used by the SAAS ensemble
            (defaults to 4).
    """

    type: Literal["EnsembleMapSaasSingleTaskGPSurrogate"] = (  # type: ignore
        "EnsembleMapSaasSingleTaskGPSurrogate"
    )
    n_taus: PositiveInt = 4

    @classmethod
    def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool:
        """Check whether an output feature type is supported by this surrogate.

        Args:
            my_type: Output feature class to check.

        Returns:
            bool: True only if ``my_type`` is ``ContinuousOutput`` or a
            subclass of it, False otherwise.
        """
        # ``isinstance(my_type, type(ContinuousOutput))`` would only compare
        # metaclasses and therefore accept any class built with the same
        # metaclass; ``issubclass`` checks the class itself. The leading
        # ``isinstance(..., type)`` guard keeps non-class arguments from
        # raising ``TypeError`` inside ``issubclass``.
        return isinstance(my_type, type) and issubclass(my_type, ContinuousOutput)
class EnsembleMapSaasSingleTaskGPSurrogate(TrainableBotorchSurrogate):
    """Trainable surrogate backed by botorch's ``EnsembleMapSaasSingleTaskGP``.

    The number of taus, the input scaler and the output scaler are copied
    from the data model before the base class is initialised.
    """

    def __init__(
        self,
        data_model: DataModel,
        **kwargs,
    ):
        """Create the surrogate from its data model.

        Args:
            data_model: Data model providing ``n_taus``, ``scaler`` and
                ``output_scaler``.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        # NOTE(review): ``DataModel`` is the module-level alias; in the visible
        # import it aliases the Additive MAP SAAS data model - confirm whether
        # the annotation should refer to the Ensemble data model instead.
        self.n_taus = data_model.n_taus
        self.scaler = data_model.scaler
        self.output_scaler = data_model.output_scaler
        super().__init__(data_model=data_model, **kwargs)

    # Fitted botorch model; stays ``None`` until ``_fit_botorch`` has run.
    model: Optional[EnsembleMapSaasSingleTaskGP] = None
    _output_filtering: OutputFilteringEnum = OutputFilteringEnum.ALL
    # Passed as ``options`` to ``fit_gpytorch_mll``.
    training_specs: Dict = {}

    def _fit_botorch(
        self,
        tX: torch.Tensor,
        tY: torch.Tensor,
        input_transform: Optional[InputTransform] = None,
        outcome_transform: Optional[OutcomeTransform] = None,
        **kwargs,
    ):
        """Build the ensemble model and fit it by maximising the exact MLL.

        Args:
            tX: Training inputs.
            tY: Training targets; the size of its last dimension is used as
                the number of outputs ``m`` of the outcome transform.
            input_transform: Optional botorch input transform handed to the
                model.
            outcome_transform: Optional botorch outcome transform. A
                ``Standardize`` instance is replaced by a new one whose
                ``batch_shape`` matches the ensemble.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        # EnsembleMapSaasSingleTaskGP repeats the data to create a batch
        # dimension, so the outcome transform needs the matching batch_shape.
        if isinstance(outcome_transform, Standardize):
            outcome_transform = Standardize(
                m=tY.shape[-1],
                batch_shape=torch.Size([self.n_taus]),
            )
        self.model = EnsembleMapSaasSingleTaskGP(
            train_X=tX,
            train_Y=tY,
            outcome_transform=outcome_transform,
            input_transform=input_transform,
            num_taus=self.n_taus,
        )
        mll = ExactMarginalLogLikelihood(self.model.likelihood, self.model)
        fit_gpytorch_mll(mll, options=self.training_specs, max_attempts=50)

    def _predict(self, transformed_X: pd.DataFrame):
        """Predict mixture mean and standard deviation for transformed inputs.

        Args:
            transformed_X: Already transformed input data; its values are
                converted to a tensor using ``tkwargs``.

        Returns:
            Tuple ``(preds, stds)`` of numpy arrays holding the posterior
            mixture mean and the square root of the posterior mixture
            variance, computed with observation noise included.
        """
        # transform to tensor
        X = torch.from_numpy(transformed_X.values).to(**tkwargs)
        with torch.no_grad():
            posterior = self.model.posterior(X=X, observation_noise=True)  # type: ignore

        # ``Tensor.numpy()`` only works for CPU tensors; ``tkwargs`` may place
        # the inputs (and hence the posterior) on another device, so move the
        # results back to the CPU first. This is a no-op for CPU tensors.
        preds = posterior.mixture_mean.detach().cpu().numpy()
        stds = np.sqrt(posterior.mixture_variance.detach().cpu().numpy())
        return preds, stds
None, + }, +) + specs.add_valid( models.FullyBayesianSingleTaskGPSurrogate, lambda: { diff --git a/tests/bofire/surrogates/test_map_saas.py b/tests/bofire/surrogates/test_map_saas.py index 40e1c3b33..2b7a31f8f 100644 --- a/tests/bofire/surrogates/test_map_saas.py +++ b/tests/bofire/surrogates/test_map_saas.py @@ -1,9 +1,15 @@ -from botorch.models.map_saas import AdditiveMapSaasSingleTaskGP +from botorch.models.map_saas import ( + AdditiveMapSaasSingleTaskGP, + EnsembleMapSaasSingleTaskGP, +) from pandas.testing import assert_frame_equal import bofire.surrogates.api as surrogates from bofire.benchmarks.single import Himmelblau -from bofire.data_models.surrogates.api import AdditiveMapSaasSingleTaskGPSurrogate +from bofire.data_models.surrogates.api import ( + AdditiveMapSaasSingleTaskGPSurrogate, + EnsembleMapSaasSingleTaskGPSurrogate, +) def test_AdditiveMapSaasSingleTaskGPSurrogate(): @@ -24,3 +30,23 @@ def test_AdditiveMapSaasSingleTaskGPSurrogate(): assert preds.shape == (10, 2) preds2 = gp.predict(experiments) assert_frame_equal(preds, preds2) + + +def test_EnsembleMapSaasSingleTaskGPSurrogate(): + bench = Himmelblau() + samples = bench.domain.inputs.sample(10) + experiments = bench.f(samples, return_complete=True) + data_model = EnsembleMapSaasSingleTaskGPSurrogate( + inputs=bench.domain.inputs, + outputs=bench.domain.outputs, + ) + gp = surrogates.map(data_model) + gp.fit(experiments=experiments) + assert isinstance(gp.model, EnsembleMapSaasSingleTaskGP) + dump = gp.dumps() + gp2 = surrogates.map(data_model=data_model) + gp2.loads(dump) + preds = gp.predict(experiments) + assert preds.shape == (10, 2) + preds2 = gp.predict(experiments) + assert_frame_equal(preds, preds2)