diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index faab8ec50..b8cc20b16 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -122,3 +122,30 @@ jobs: - name: Run pip freeze shell: bash -l {0} run: pip freeze + + testing_pfn: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [ '3.10', '3.11' ] + steps: + - name: Check out repo + uses: actions/checkout@v4 + + - name: Setup Conda + uses: conda-incubator/setup-miniconda@v3 + with: + miniconda-version: "latest" + activate-environment: test + python-version: ${{ matrix.python-version }} + + - name: Install Dependencies + shell: bash -l {0} + run: | + conda install -c conda-forge cyipopt=1.5.0 + pip install torch --index-url https://download.pytorch.org/whl/cpu + pip install ".[pfn,tests]" + + - name: Run PFN tests + shell: bash -l {0} + run: pytest tests/bofire/surrogates/test_pfn.py -ra --cov=bofire.surrogates.pfn --cov-report term-missing diff --git a/CHANGELOG.md b/CHANGELOG.md index 99dfde16a..55372bf25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Pragmatic Versioning](https://github.com/experiment ### Added +- `PFNSurrogate` - Prior-data Fitted Networks (PFN) surrogate model for Bayesian optimization using pre-trained transformers from the `pfns4bo` library. Includes support for both univariate and multivariate outputs with custom serialization for outcome transforms. **Note:** The `pfn` extras require Python 3.10 or 3.11 due to dependencies on older versions of scikit-learn (<1.2) that are incompatible with Python 3.12+. - `CloneFeatures` engineered feature, that can be used to create a copy of a set of features, this can be useful if one wants to further process features differently (different scalers, different kernels etc.) - Explicit Interaction features (like `x_1 * x_2`) for botorch based surrogates via the engineered features mechanism. 
- Support for custom formulas including discrete and categorical features in the DoE module. diff --git a/bofire/data_models/surrogates/api.py b/bofire/data_models/surrogates/api.py index 7a98b97af..f53ad467d 100644 --- a/bofire/data_models/surrogates/api.py +++ b/bofire/data_models/surrogates/api.py @@ -32,6 +32,7 @@ MultiTaskGPHyperconfig, MultiTaskGPSurrogate, ) +from bofire.data_models.surrogates.pfn import PFNSurrogate from bofire.data_models.surrogates.polynomial import PolynomialSurrogate from bofire.data_models.surrogates.random_forest import RandomForestSurrogate from bofire.data_models.surrogates.robust_single_task_gp import ( @@ -81,6 +82,7 @@ PiecewiseLinearGPSurrogate, AdditiveMapSaasSingleTaskGPSurrogate, EnsembleMapSaasSingleTaskGPSurrogate, + PFNSurrogate, ] AnyTrainableSurrogate = Union[ @@ -98,6 +100,7 @@ PiecewiseLinearGPSurrogate, AdditiveMapSaasSingleTaskGPSurrogate, EnsembleMapSaasSingleTaskGPSurrogate, + PFNSurrogate, ] AnyRegressionSurrogate = Union[ @@ -117,6 +120,7 @@ PiecewiseLinearGPSurrogate, AdditiveMapSaasSingleTaskGPSurrogate, EnsembleMapSaasSingleTaskGPSurrogate, + PFNSurrogate, ] AnyClassificationSurrogate = ClassificationMLPEnsemble diff --git a/bofire/data_models/surrogates/botorch_surrogates.py b/bofire/data_models/surrogates/botorch_surrogates.py index bffb016e1..92c3ffd81 100644 --- a/bofire/data_models/surrogates/botorch_surrogates.py +++ b/bofire/data_models/surrogates/botorch_surrogates.py @@ -26,6 +26,7 @@ RegressionMLPEnsemble, ) from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPSurrogate +from bofire.data_models.surrogates.pfn import PFNSurrogate from bofire.data_models.surrogates.polynomial import PolynomialSurrogate from bofire.data_models.surrogates.random_forest import RandomForestSurrogate from bofire.data_models.surrogates.shape import PiecewiseLinearGPSurrogate @@ -51,6 +52,7 @@ PiecewiseLinearGPSurrogate, AdditiveMapSaasSingleTaskGPSurrogate, EnsembleMapSaasSingleTaskGPSurrogate, + PFNSurrogate, 
] diff --git a/bofire/data_models/surrogates/pfn.py b/bofire/data_models/surrogates/pfn.py new file mode 100644 index 000000000..0ce78385f --- /dev/null +++ b/bofire/data_models/surrogates/pfn.py @@ -0,0 +1,132 @@ +from typing import Any, Literal, Optional, Type + +from pydantic import Field, field_validator + +from bofire.data_models.enum import CategoricalEncodingEnum +from bofire.data_models.features.api import ( + AnyOutput, + CategoricalDescriptorInput, + CategoricalInput, + CategoricalMolecularInput, + ContinuousOutput, + TaskInput, +) +from bofire.data_models.molfeatures.api import Fingerprints +from bofire.data_models.surrogates.scaler import AnyScaler, Normalize, ScalerEnum +from bofire.data_models.surrogates.trainable_botorch import TrainableBotorchSurrogate + + +class PFNSurrogate(TrainableBotorchSurrogate): + """Prior-data Fitted Network (PFN) surrogate model. + + PFN is a pre-trained neural network that can be used for Bayesian optimization + without requiring training on the specific task. The model is loaded from a + checkpoint URL and makes predictions based on training data context. + + Attributes: + type: Discriminator for the surrogate type. + checkpoint_url: URL or path to the pre-trained PFN model checkpoint. + Defaults to the pfns4bo_hebo model. Can also use ModelPaths enum values: + - "pfns4bo_hebo": HEBO-style model with more budget and unused features + - "pfns4bo_bnn": BNN-style model sampled with warp for HPOB + Or provide a custom URL to a .pt.gz model file. + batch_first: Whether the batch dimension is the first dimension of input tensors. + For batch-first, X has shape `batch x seq_len x features`. + For non-batch-first, X has shape `seq_len x batch x features`. + constant_model_kwargs: Dictionary of constant keyword arguments that will be + passed to the model in each forward pass. Use this to configure model-specific + behavior during inference. 
+ load_training_checkpoint: If True, loads a training checkpoint as produced by + the PFNs training code. If False, loads a pre-trained inference model. + cache_dir: Directory path for caching downloaded models. If None, uses + /tmp/botorch_pfn_models. + multivariate: If True, uses MultivariatePFNModel which returns a joint posterior + over batch inputs. This requires an additional forward pass and approximation. + If False, uses standard PFNModel with independent predictions. + scaler: Scaler to use for input features. + output_scaler: Scaler to use for output targets. + + Note: + PFN models are pre-trained and do not require fitting in the traditional sense. + The "fit" operation simply loads the model and stores the training data as context + for inference. + """ + + type: Literal["PFNSurrogate"] = "PFNSurrogate" + + # Model loading configuration + checkpoint_url: str = "pfns4bo_hebo" + load_training_checkpoint: bool = False + cache_dir: Optional[str] = None + + # Model architecture configuration + batch_first: bool = False + multivariate: bool = False + + # Model inference configuration + constant_model_kwargs: dict[str, Any] = Field(default_factory=dict) + # num_samples: int = Field( + # default=128, + # description="Number of samples to draw from the posterior distribution for prediction.", + # ) + + # Scaling configuration + scaler: AnyScaler = Normalize() + output_scaler: ScalerEnum = ScalerEnum.STANDARDIZE + + @field_validator("output_scaler") + @classmethod + def validate_output_scaler(cls, v: ScalerEnum) -> ScalerEnum: + """Validate that output_scaler is not LOG or CHAINED_LOG_STANDARDIZE. + + PFN models return variance estimates that are incompatible with LOG transforms, + as BoTorch's Log transform does not support untransforming variance. + + Args: + v: The output scaler value to validate. + + Returns: + The validated output scaler value. + + Raises: + ValueError: If output_scaler is LOG or CHAINED_LOG_STANDARDIZE. 
+ """ + if v in (ScalerEnum.LOG, ScalerEnum.CHAINED_LOG_STANDARDIZE): + raise ValueError( + f"PFNSurrogate does not support output_scaler={v.name}. " + "LOG and CHAINED_LOG_STANDARDIZE transforms are incompatible with " + "PFN's variance predictions. Use STANDARDIZE, NORMALIZE, or IDENTITY instead." + ) + return v + + @classmethod + def _default_categorical_encodings( + cls, + ) -> dict[Type[CategoricalInput], CategoricalEncodingEnum | Fingerprints]: + """Override default categorical encodings for PFN models. + + PFN models work better with ordinal encodings instead of one-hot encoding + to avoid exceeding the pretrained model's feature dimension limits. + Pretrained PFN checkpoints have fixed input dimensionality constraints. + + Returns: + Dictionary mapping categorical input types to their encoding strategies. + """ + return { + CategoricalInput: CategoricalEncodingEnum.ORDINAL, + CategoricalMolecularInput: Fingerprints(), + CategoricalDescriptorInput: CategoricalEncodingEnum.DESCRIPTOR, + TaskInput: CategoricalEncodingEnum.ORDINAL, + } + + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Check if the output type is implemented for this surrogate. + + Args: + my_type: The output feature type to check. + + Returns: + True if the output type is ContinuousOutput, False otherwise. 
+ """ + return my_type == ContinuousOutput diff --git a/bofire/surrogates/api.py b/bofire/surrogates/api.py index a9e36d2a0..7dc9e0829 100644 --- a/bofire/surrogates/api.py +++ b/bofire/surrogates/api.py @@ -12,6 +12,7 @@ RegressionMLPEnsemble, ) from bofire.surrogates.multi_task_gp import MultiTaskGPSurrogate +from bofire.surrogates.pfn import PFNSurrogate from bofire.surrogates.random_forest import RandomForestSurrogate from bofire.surrogates.shape import PiecewiseLinearGPSurrogate from bofire.surrogates.single_task_gp import SingleTaskGPSurrogate diff --git a/bofire/surrogates/mapper.py b/bofire/surrogates/mapper.py index 34bddaba7..dd1c71886 100644 --- a/bofire/surrogates/mapper.py +++ b/bofire/surrogates/mapper.py @@ -18,6 +18,7 @@ ) from bofire.surrogates.mlp import ClassificationMLPEnsemble, RegressionMLPEnsemble from bofire.surrogates.multi_task_gp import MultiTaskGPSurrogate +from bofire.surrogates.pfn import PFNSurrogate from bofire.surrogates.random_forest import RandomForestSurrogate from bofire.surrogates.robust_single_task_gp import RobustSingleTaskGPSurrogate from bofire.surrogates.shape import PiecewiseLinearGPSurrogate @@ -115,6 +116,7 @@ def map_TanimotoGPSurrogate( data_models.CategoricalDeterministicSurrogate: CategoricalDeterministicSurrogate, data_models.AdditiveMapSaasSingleTaskGPSurrogate: AdditiveMapSaasSingleTaskGPSurrogate, data_models.EnsembleMapSaasSingleTaskGPSurrogate: EnsembleMapSaasSingleTaskGPSurrogate, + data_models.PFNSurrogate: PFNSurrogate, } diff --git a/bofire/surrogates/pfn.py b/bofire/surrogates/pfn.py new file mode 100644 index 000000000..24d1ba7e2 --- /dev/null +++ b/bofire/surrogates/pfn.py @@ -0,0 +1,223 @@ +import base64 +import io +import warnings +from typing import Optional, Tuple + +import numpy as np +import pandas as pd +import torch +from botorch.models.transforms.input import InputTransform +from botorch.models.transforms.outcome import OutcomeTransform + +from bofire.data_models.enum import OutputFilteringEnum 
+from bofire.data_models.surrogates.api import PFNSurrogate as DataModel +from bofire.surrogates.botorch import TrainableBotorchSurrogate +from bofire.utils.torch_tools import tkwargs + + +# Import PFN models from botorch_community +try: + from botorch_community.models.prior_fitted_network import ( + MultivariatePFNModel, + PFNModel, + ) + from botorch_community.models.utils.prior_fitted_network import ModelPaths +except ImportError: + warnings.warn( + "botorch_community not installed, PFN models cannot be used.", + ImportWarning, + ) + + +class PFNSurrogate(TrainableBotorchSurrogate): + """Prior-data Fitted Network (PFN) surrogate for Bayesian optimization. + + PFN is a pre-trained transformer-based model that can make predictions + in a zero-shot manner by conditioning on training data. Unlike traditional + surrogates, PFN doesn't require gradient-based training on the specific task. + + The model is loaded from a checkpoint and makes predictions by processing + the training data as context along with test points. + + Attributes: + checkpoint_url: URL or path to the pre-trained PFN checkpoint. + batch_first: Whether batch dimension comes first in tensors. + multivariate: If True, uses MultivariatePFNModel for joint posteriors. + constant_model_kwargs: Additional kwargs passed to model during inference. + load_training_checkpoint: Whether to load a training checkpoint format. + cache_dir: Directory for caching downloaded models. + model: The underlying PFN model (PFNModel or MultivariatePFNModel). + """ + + def __init__(self, data_model: DataModel, **kwargs): + """Initialize the PFN surrogate. + + Args: + data_model: The PFNSurrogate data model with configuration. + **kwargs: Additional arguments passed to parent class. 
+ """ + self.checkpoint_url = data_model.checkpoint_url + self.batch_first = data_model.batch_first + self.multivariate = data_model.multivariate + self.constant_model_kwargs = data_model.constant_model_kwargs + self.load_training_checkpoint = data_model.load_training_checkpoint + self.cache_dir = data_model.cache_dir + self.scaler = data_model.scaler + self.output_scaler = data_model.output_scaler + # self.num_samples = data_model.num_samples + super().__init__(data_model, **kwargs) + + _output_filtering: OutputFilteringEnum = OutputFilteringEnum.ALL + model: Optional[PFNModel] = None + + def _fit_botorch( + self, + tX: torch.Tensor, + tY: torch.Tensor, + input_transform: Optional[InputTransform] = None, + outcome_transform: Optional[OutcomeTransform] = None, + **kwargs, + ) -> None: + """Fit the PFN model to training data. + + For PFN, "fitting" means loading the pre-trained model and storing + the training data as context. The model itself is not trained. + + Args: + tX: Training features of shape (n, d). + tY: Training targets of shape (n, 1). + input_transform: Optional input transform to apply. + outcome_transform: Optional outcome transform to apply. + **kwargs: Additional keyword arguments (unused). 
+ """ + # Convert checkpoint_url to ModelPaths enum if it's a known path + checkpoint_path = self.checkpoint_url + if checkpoint_path == "pfns4bo_hebo": + checkpoint_path = ModelPaths.pfns4bo_hebo + elif checkpoint_path == "pfns4bo_bnn": + checkpoint_path = ModelPaths.pfns4bo_bnn + + # Apply outcome transform if provided, but NOT input transform + # Input transform will be handled by the PFN model itself + if outcome_transform is not None: + tY = outcome_transform(tY)[0] + + # Ensure data is on the correct device and has correct dtype + tX = tX.to(**tkwargs) + tY = tY.to(**tkwargs) + + # Ensure tY is 2-dimensional (n, 1) + if tY.dim() == 1: + tY = tY.unsqueeze(-1) + + # Select model class based on multivariate flag + model_class = MultivariatePFNModel if self.multivariate else PFNModel + + # Initialize the PFN model + # Note: We pass model=None to download from checkpoint_url + # The model will automatically download and cache the checkpoint + # We pass input_transform to the model so it can handle transformations + # internally during both initialization and posterior calls + self.model = model_class( + train_X=tX, + train_Y=tY, + model=None, # Will be downloaded from checkpoint_url + checkpoint_url=checkpoint_path, + train_Yvar=None, # PFN doesn't use noise variance + batch_first=self.batch_first, + constant_model_kwargs=self.constant_model_kwargs, + input_transform=input_transform, # Let PFN handle the transform + load_training_checkpoint=self.load_training_checkpoint, + ) + + # Store outcome transform if provided + # PFN doesn't directly support outcome transforms, but we can apply them + # in the prediction step if needed + if outcome_transform is not None: + self._outcome_transform = outcome_transform + + def _predict(self, transformed_X: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: + """Predict using the PFN model. 
+ + The PFN model handles input transformation internally via its + input_transform attribute, so we just convert to tensor and call posterior. + Predictions use the posterior mean and variance, applying the outcome + transformation if configured before returning the statistics. + + Args: + transformed_X: Input features as a pandas DataFrame. + + Returns: + Tuple of predictions and standard deviations as numpy arrays. + """ + # Convert to tensor + X = torch.from_numpy(transformed_X.values).to(**tkwargs) + + # The model will apply input_transform internally in posterior() if needed + with torch.no_grad(): + posterior = self.model.posterior(X=X, observation_noise=True) + + # Get mean and variance from the posterior + preds, var = ( + self._outcome_transform.untransform( + posterior.mean, posterior.variance, None + ) + if hasattr(self, "_outcome_transform") + and self._outcome_transform is not None + else (posterior.mean, posterior.variance) + ) + stds = np.sqrt(var.cpu().detach().numpy()) + preds = preds.cpu().detach().numpy() + return preds, stds + + # Override the _dumps and loads methods to handle the fact that PFN models are not + # standard PyTorch models and may have additional state (outcome transform) that needs + # to be saved and loaded. + # TODO: PFN implementation in botorch_community uses posterior_transform. + # However, this is not currently implemented in the botorch implementation. + # Add this functionality once the botorch implementation is updated to match the + # botorch_community implementation. + def _dumps(self) -> str: + """Dumps the model and outcome transform to a string. + + Overrides the parent method to also save the outcome transform, + which is stored separately from the model.
+ """ + # Serialize the model + self.model.prediction_strategy = None + buffer = io.BytesIO() + torch.save(self.model, buffer) + model_bytes = base64.b64encode(buffer.getvalue()).decode() + + # Serialize the outcome transform if it exists + outcome_transform_bytes = None + if hasattr(self, "_outcome_transform") and self._outcome_transform is not None: + buffer = io.BytesIO() + torch.save(self._outcome_transform, buffer) + outcome_transform_bytes = base64.b64encode(buffer.getvalue()).decode() + + # Combine both into a single string with a delimiter + if outcome_transform_bytes: + return f"{model_bytes}|||OUTCOME_TRANSFORM|||{outcome_transform_bytes}" + return model_bytes + + def loads(self, data: str) -> None: + """Loads the model and outcome transform from a string. + + Overrides the parent method to also restore the outcome transform. + """ + # Check if outcome transform is included + if "|||OUTCOME_TRANSFORM|||" in data: + model_bytes, outcome_transform_bytes = data.split("|||OUTCOME_TRANSFORM|||") + + # Load the model + buffer = io.BytesIO(base64.b64decode(model_bytes.encode())) + self.model = torch.load(buffer, weights_only=False) + + # Load the outcome transform + buffer = io.BytesIO(base64.b64decode(outcome_transform_bytes.encode())) + self._outcome_transform = torch.load(buffer, weights_only=False) + else: + # No outcome transform, just load the model + buffer = io.BytesIO(base64.b64decode(data.encode())) + self.model = torch.load(buffer, weights_only=False) diff --git a/pyproject.toml b/pyproject.toml index 7eb9c797e..31ba64e02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "pydantic>=2.5", "scipy>=1.7", "typing-extensions", - "formulaic<=1.0.2", + "formulaic==1.0.2", ] [project.optional-dependencies] @@ -39,10 +39,10 @@ optimization = [ "numpy", "multiprocess", "plotly", - "formulaic>=1.0.1,<1.1", + "formulaic==1.0.2", "cloudpickle>=2.0.0", "sympy>=1.12", - "cvxpy[CLARABEL,SCIP]", + "cvxpy[CLARABEL,SCIP]>=1.5,<1.8", 
"scikit-learn>=1.0.0", "pymoo>=0.6.0", "shap>=0.48.0", @@ -52,22 +52,46 @@ entmoot = [ "entmoot>=2.0.6", ] # we pin the pyomo version here due to compatibility issues cheminfo = ["rdkit>=2023.3.2", "scikit-learn>=1.0.0", "mordredcommunity>=2.0.1"] +# pfn4bo requires scikitlearn<1.2. This version is not compatible with numpy 2.X +# We modify numpy, scikit-learn, cvpxy and shap to meet the pfn4bo requirements. +# NOTE: scikit-learn<1.2 is not compatible with Python 3.12+ due to removal of pkgutil.ImpImporter. +# Therefore, pfn extras should only be installed on Python 3.10 or 3.11. +pfn = [ + "botorch>=0.16.1; python_version < '3.12'", + "numpy>=1.22,<2.0; python_version < '3.12'", + "multiprocess; python_version < '3.12'", + "plotly; python_version < '3.12'", + "formulaic==1.0.2; python_version < '3.12'", + "cloudpickle>=2.0.0; python_version < '3.12'", + "sympy>=1.12; python_version < '3.12'", + "cvxpy[CLARABEL,SCIP]>=1.5,<1.8; python_version < '3.12'", + "scikit-learn>=0.24.2,<1.2; python_version < '3.12'", + "pymoo>=0.6.0; python_version < '3.12'", + "shap<0.50.0; python_version < '3.12'", + "pfns4bo; python_version < '3.12'", + "requests; python_version < '3.12'", + "pfns @ git+https://github.com/automl/PFNs.git", +] tests = ["pytest", "pytest-cov", "papermill"] docs = ["jupyter", "jupyter-cache", "matplotlib", "seaborn"] all = [ "botorch>=0.16.1", - "numpy", + "numpy>=1.22,<2.0; python_version < '3.12'", + "numpy; python_version >= '3.12'", "multiprocess", "plotly", - "formulaic==1.0.1", + "formulaic==1.0.2", "cloudpickle>=2.0.0", "sympy>=1.12", - "cvxpy[CLARABEL,SCIP]", - "scikit-learn>=1.0.0", + "scikit-learn>=0.24.2,<1.2; python_version < '3.12'", + "cvxpy[CLARABEL,SCIP]>=1.5,<1.8; python_version < '3.12'", "entmoot>=2.0.6", "pyomo<=6.9.4", "rdkit>=2023.3.2", "mordredcommunity>=2.0.1", + "pfns4bo; python_version < '3.12'", + "requests", + "pfns @ git+https://github.com/automl/PFNs.git", "mopti", "pytest", "pytest-cov", @@ -82,7 +106,8 @@ all = [ 
"matplotlib", "seaborn", "pymoo>=0.6.0", - "shap>=0.48.0", + "shap<0.50.0; python_version < '3.12'", + "shap>=0.48.0; python_version >= '3.12'", ] [tool.setuptools.packages] diff --git a/tests/bofire/data_models/specs/surrogates.py b/tests/bofire/data_models/specs/surrogates.py index 4f0a725ae..082cec73c 100644 --- a/tests/bofire/data_models/specs/surrogates.py +++ b/tests/bofire/data_models/specs/surrogates.py @@ -28,7 +28,7 @@ THREESIX_SCALE_PRIOR, LogNormalPrior, ) -from bofire.data_models.surrogates.api import Normalize, ScalerEnum +from bofire.data_models.surrogates.api import Normalize, ScalerEnum, Standardize from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPHyperconfig from bofire.data_models.surrogates.shape import PiecewiseLinearGPSurrogateHyperconfig from bofire.data_models.surrogates.single_task_gp import SingleTaskGPHyperconfig @@ -1171,3 +1171,128 @@ error=ValueError, message="Feature keys do not match input keys.", ) + +# PFN Surrogate specs +specs.add_valid( + models.PFNSurrogate, + lambda: { + "inputs": Inputs( + features=[ + ContinuousInput(key="a", bounds=[0, 1]), + ContinuousInput(key="b", bounds=[0, 1]), + ], + ).model_dump(), + "outputs": Outputs( + features=[ + features.valid(ContinuousOutput).obj(), + ], + ).model_dump(), + "engineered_features": EngineeredFeatures().model_dump(), + "checkpoint_url": "pfns4bo_hebo", + "load_training_checkpoint": False, + "cache_dir": None, + "batch_first": False, + "multivariate": False, + "constant_model_kwargs": {}, + "scaler": Normalize().model_dump(), + "output_scaler": ScalerEnum.STANDARDIZE, + "input_preprocessing_specs": {}, + "categorical_encodings": {}, + "dump": None, + "hyperconfig": None, + }, +) + +specs.add_valid( + models.PFNSurrogate, + lambda: { + "inputs": Inputs( + features=[ + ContinuousInput(key="a", bounds=[0, 1]), + ContinuousInput(key="b", bounds=[0, 1]), + ], + ).model_dump(), + "outputs": Outputs( + features=[ + features.valid(ContinuousOutput).obj(), + ], + 
).model_dump(), + "engineered_features": EngineeredFeatures().model_dump(), + "checkpoint_url": "pfns4bo_bnn", + "load_training_checkpoint": False, + "cache_dir": None, + "batch_first": True, + "multivariate": True, + "constant_model_kwargs": {"key": "value"}, + "scaler": Standardize().model_dump(), + "output_scaler": ScalerEnum.IDENTITY, + "input_preprocessing_specs": {}, + "categorical_encodings": {}, + "dump": None, + "hyperconfig": None, + }, +) + +specs.add_invalid( + models.PFNSurrogate, + lambda: { + "inputs": Inputs( + features=[ + ContinuousInput(key="a", bounds=[0, 1]), + ], + ).model_dump(), + "outputs": Outputs( + features=[ + features.valid(ContinuousOutput).obj(), + ], + ).model_dump(), + "checkpoint_url": "pfns4bo_hebo", + "load_training_checkpoint": False, + "cache_dir": None, + "batch_first": False, + "multivariate": False, + "constant_model_kwargs": {}, + "scaler": Normalize().model_dump(), + "output_scaler": ScalerEnum.LOG, + "input_preprocessing_specs": {}, + "categorical_encodings": {}, + "dump": None, + "hyperconfig": None, + }, + error=ValueError, + message="PFNSurrogate does not support output_scaler=LOG. " + "LOG and CHAINED_LOG_STANDARDIZE transforms are incompatible with " + "PFN's variance predictions. 
Use STANDARDIZE, NORMALIZE, or IDENTITY instead.", +) + +specs.add_invalid( + models.PFNSurrogate, + lambda: { + "inputs": Inputs( + features=[ + ContinuousInput(key="a", bounds=[0, 1]), + ], + ).model_dump(), + "outputs": Outputs( + features=[ + features.valid(ContinuousOutput).obj(), + ], + ).model_dump(), + "checkpoint_url": "pfns4bo_hebo", + "load_training_checkpoint": False, + "cache_dir": None, + "batch_first": False, + "multivariate": False, + "constant_model_kwargs": {}, + "scaler": Normalize().model_dump(), + "output_scaler": ScalerEnum.CHAINED_LOG_STANDARDIZE, + "input_preprocessing_specs": {}, + "categorical_encodings": {}, + "dump": None, + "hyperconfig": None, + }, + error=ValueError, + message="PFNSurrogate does not support output_scaler=CHAINED_LOG_STANDARDIZE. " + "LOG and CHAINED_LOG_STANDARDIZE transforms are incompatible with " + "PFN's variance predictions. Use STANDARDIZE, NORMALIZE, or IDENTITY instead.", +) diff --git a/tests/bofire/surrogates/test_pfn.py b/tests/bofire/surrogates/test_pfn.py new file mode 100644 index 000000000..42e20ab51 --- /dev/null +++ b/tests/bofire/surrogates/test_pfn.py @@ -0,0 +1,323 @@ +import pytest +from pandas.testing import assert_frame_equal + +import bofire.surrogates.api as surrogates +from bofire.benchmarks.single import Himmelblau, PositiveHimmelblau +from bofire.data_models.domain.api import Inputs, Outputs +from bofire.data_models.features.api import ( + CategoricalInput, + ContinuousInput, + ContinuousOutput, +) +from bofire.data_models.surrogates.api import ( + Normalize, + PFNSurrogate, + ScalerEnum, + Standardize, +) + + +try: + from botorch_community.models.prior_fitted_network import ( + MultivariatePFNModel, + PFNModel, + ) + + BOTORCH_COMMUNITY_AVAILABLE = True +except ImportError: + BOTORCH_COMMUNITY_AVAILABLE = False + + +pytestmark = pytest.mark.skipif( + not BOTORCH_COMMUNITY_AVAILABLE, + reason="botorch_community not installed", +) + + +@pytest.mark.parametrize( + "scaler, output_scaler", + [ 
+ [Normalize(), ScalerEnum.IDENTITY], + [Standardize(), ScalerEnum.STANDARDIZE], + [None, ScalerEnum.STANDARDIZE], + ], +) +def test_pfn_surrogate_fit(scaler, output_scaler): + """Test PFN surrogate fitting with different scalers.""" + bench = PositiveHimmelblau() + samples = bench.domain.inputs.sample(15) + experiments = bench.f(samples, return_complete=True) + + pfn = PFNSurrogate( + inputs=bench.domain.inputs, + outputs=bench.domain.outputs, + checkpoint_url="pfns4bo_hebo", + batch_first=False, + multivariate=False, + scaler=scaler, + output_scaler=output_scaler, + ) + surrogate = surrogates.map(pfn) + + surrogate.fit(experiments=experiments) + + # Check input transforms + if scaler is None: + assert surrogate._input_transform is None + + # Check outcome transforms + if output_scaler == ScalerEnum.STANDARDIZE: + assert hasattr(surrogate, "_outcome_transform") + # Note: PFN handles outcome transforms differently than standard GP models + elif output_scaler == ScalerEnum.IDENTITY: + # May or may not have outcome transform depending on implementation + pass + + # Test predictions + preds = surrogate.predict(experiments) + + assert "y_pred" in preds.columns + assert "y_sd" in preds.columns + assert preds.shape[0] == experiments.shape[0] + + # Test serialization/deserialization + dump = surrogate.dumps() + surrogate2 = surrogates.map(pfn) + surrogate2.loads(dump) + preds2 = surrogate2.predict(experiments) + assert_frame_equal(preds, preds2) + + +def test_pfn_surrogate_multivariate(): + """Test PFN surrogate with multivariate posterior.""" + bench = Himmelblau() + samples = bench.domain.inputs.sample(15) + experiments = bench.f(samples, return_complete=True) + + pfn = PFNSurrogate( + inputs=bench.domain.inputs, + outputs=bench.domain.outputs, + checkpoint_url="pfns4bo_hebo", + batch_first=False, + multivariate=True, # Enable multivariate posterior + scaler=Normalize(), + output_scaler=ScalerEnum.STANDARDIZE, + ) + surrogate = surrogates.map(pfn) + + 
surrogate.fit(experiments=experiments) + + # Check that MultivariatePFNModel is used + assert isinstance(surrogate.model, MultivariatePFNModel) + + # Test predictions + preds = surrogate.predict(experiments) + assert "y_pred" in preds.columns + assert "y_sd" in preds.columns + + +def test_pfn_surrogate_checkpoint_variants(): + """Test PFN surrogate with different checkpoint URLs.""" + bench = Himmelblau() + samples = bench.domain.inputs.sample(15) + experiments = bench.f(samples, return_complete=True) + + # Test with pfns4bo_bnn checkpoint + pfn = PFNSurrogate( + inputs=bench.domain.inputs, + outputs=bench.domain.outputs, + checkpoint_url="pfns4bo_bnn", + batch_first=False, + multivariate=False, + scaler=Normalize(), + output_scaler=ScalerEnum.STANDARDIZE, + ) + surrogate = surrogates.map(pfn) + + surrogate.fit(experiments=experiments) + + # Check that model is PFNModel + assert isinstance(surrogate.model, PFNModel) + + # Test predictions + preds = surrogate.predict(experiments) + assert "y_pred" in preds.columns + assert "y_sd" in preds.columns + + +def test_pfn_surrogate_batch_first(): + """Test PFN surrogate with batch_first=True.""" + bench = Himmelblau() + samples = bench.domain.inputs.sample(15) + experiments = bench.f(samples, return_complete=True) + + pfn = PFNSurrogate( + inputs=bench.domain.inputs, + outputs=bench.domain.outputs, + checkpoint_url="pfns4bo_hebo", + batch_first=True, # Batch dimension first + multivariate=False, + scaler=Normalize(), + output_scaler=ScalerEnum.STANDARDIZE, + ) + surrogate = surrogates.map(pfn) + + surrogate.fit(experiments=experiments) + + # Test predictions + preds = surrogate.predict(experiments) + assert "y_pred" in preds.columns + assert "y_sd" in preds.columns + + +def test_pfn_surrogate_categorical_input(): + """Test PFN surrogate with categorical inputs.""" + inputs = Inputs( + features=[ + ContinuousInput(key="x_1", bounds=(-4, 4)), + ContinuousInput(key="x_2", bounds=(-4, 4)), + CategoricalInput(key="x_cat", 
categories=["a", "b"]), + ], + ) + outputs = Outputs(features=[ContinuousOutput(key="y")]) + + experiments = inputs.sample(n=15) + experiments.eval("y=((x_1**2 + x_2 - 11)**2+(x_1 + x_2**2 -7)**2)", inplace=True) + experiments.loc[experiments.x_cat == "a", "y"] *= 2.0 + experiments.loc[experiments.x_cat == "b", "y"] /= 2.0 + experiments["valid_y"] = 1 + + pfn = PFNSurrogate( + inputs=inputs, + outputs=outputs, + checkpoint_url="pfns4bo_hebo", + batch_first=False, + multivariate=False, + scaler=Normalize(), + output_scaler=ScalerEnum.STANDARDIZE, + ) + surrogate = surrogates.map(pfn) + + surrogate.fit(experiments=experiments) + + # Test predictions + preds = surrogate.predict(experiments) + assert "y_pred" in preds.columns + assert "y_sd" in preds.columns + assert preds.shape[0] == experiments.shape[0] + + +def test_pfn_surrogate_small_dataset(): + """Test PFN surrogate with very small dataset.""" + bench = Himmelblau() + samples = bench.domain.inputs.sample(5) # Very small dataset + experiments = bench.f(samples, return_complete=True) + + pfn = PFNSurrogate( + inputs=bench.domain.inputs, + outputs=bench.domain.outputs, + checkpoint_url="pfns4bo_hebo", + batch_first=False, + multivariate=False, + scaler=Normalize(), + output_scaler=ScalerEnum.STANDARDIZE, + ) + surrogate = surrogates.map(pfn) + + surrogate.fit(experiments=experiments) + + # Test predictions + preds = surrogate.predict(experiments) + assert "y_pred" in preds.columns + assert "y_sd" in preds.columns + + +def test_pfn_surrogate_constant_model_kwargs(): + """Test PFN surrogate with constant_model_kwargs.""" + bench = Himmelblau() + samples = bench.domain.inputs.sample(15) + experiments = bench.f(samples, return_complete=True) + + pfn = PFNSurrogate( + inputs=bench.domain.inputs, + outputs=bench.domain.outputs, + checkpoint_url="pfns4bo_hebo", + batch_first=False, + multivariate=False, + constant_model_kwargs={}, # Empty dict for now + scaler=Normalize(), + output_scaler=ScalerEnum.STANDARDIZE, + ) + 
surrogate = surrogates.map(pfn) + + surrogate.fit(experiments=experiments) + + # Test predictions + preds = surrogate.predict(experiments) + assert "y_pred" in preds.columns + assert "y_sd" in preds.columns + + +def test_pfn_surrogate_is_fitted(): + """Test PFN surrogate is_fitted property.""" + bench = Himmelblau() + samples = bench.domain.inputs.sample(15) + experiments = bench.f(samples, return_complete=True) + + pfn = PFNSurrogate( + inputs=bench.domain.inputs, + outputs=bench.domain.outputs, + checkpoint_url="pfns4bo_hebo", + batch_first=False, + multivariate=False, + scaler=Normalize(), + output_scaler=ScalerEnum.STANDARDIZE, + ) + surrogate = surrogates.map(pfn) + + # Before fitting + assert not surrogate.is_fitted + + # After fitting + surrogate.fit(experiments=experiments) + assert surrogate.is_fitted + + # After serialization/deserialization + dump = surrogate.dumps() + surrogate2 = surrogates.map(pfn) + assert not surrogate2.is_fitted + surrogate2.loads(dump) + assert surrogate2.is_fitted + + +def test_pfn_surrogate_prediction_shapes(): + """Test PFN surrogate prediction output shapes.""" + bench = Himmelblau() + samples = bench.domain.inputs.sample(15) + experiments = bench.f(samples, return_complete=True) + + pfn = PFNSurrogate( + inputs=bench.domain.inputs, + outputs=bench.domain.outputs, + checkpoint_url="pfns4bo_hebo", + batch_first=False, + multivariate=False, + scaler=Normalize(), + output_scaler=ScalerEnum.STANDARDIZE, + ) + surrogate = surrogates.map(pfn) + + surrogate.fit(experiments=experiments) + + # Test predictions on training data + preds = surrogate.predict(experiments) + assert preds.shape[0] == experiments.shape[0] + assert "y_pred" in preds.columns + assert "y_sd" in preds.columns + + # Test predictions on new data + new_samples = bench.domain.inputs.sample(5) + new_preds = surrogate.predict(new_samples) + assert new_preds.shape[0] == new_samples.shape[0] + assert "y_pred" in new_preds.columns + assert "y_sd" in new_preds.columns