Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,30 @@ jobs:
- name: Run pip freeze
shell: bash -l {0}
run: pip freeze

testing_pfn:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.10', '3.11' ]
steps:
- name: Check out repo
uses: actions/checkout@v4

- name: Setup Conda
uses: conda-incubator/setup-miniconda@v3
with:
miniconda-version: "latest"
activate-environment: test
python-version: ${{ matrix.python-version }}

- name: Install Dependencies
shell: bash -l {0}
run: |
conda install -c conda-forge cyipopt=1.5.0
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install ".[pfn,tests]"

- name: Run PFN tests
shell: bash -l {0}
run: pytest tests/bofire/surrogates/test_pfn.py -ra --cov=bofire.surrogates.pfn --cov-report term-missing
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Pragmatic Versioning](https://github.com/experiment

### Added

- `PFNSurrogate` - Prior-data Fitted Networks (PFN) surrogate model for Bayesian optimization using pre-trained transformers from the `pfns4bo` library. Includes support for both univariate and multivariate outputs with custom serialization for outcome transforms. **Note:** The `pfn` extras require Python 3.10 or 3.11 due to dependencies on older versions of scikit-learn (<1.2) that are incompatible with Python 3.12+.
- `CloneFeatures` engineered feature, that can be used to create a copy of a set of features, this can be useful if one wants to further process features differently (different scalers, different kernels etc.)
- Explicit Interaction features (like `x_1 * x_2`) for botorch based surrogates via the engineered features mechanism.
- Support for custom formulas including discrete and categorical features in the DoE module.
Expand Down
4 changes: 4 additions & 0 deletions bofire/data_models/surrogates/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
MultiTaskGPHyperconfig,
MultiTaskGPSurrogate,
)
from bofire.data_models.surrogates.pfn import PFNSurrogate
from bofire.data_models.surrogates.polynomial import PolynomialSurrogate
from bofire.data_models.surrogates.random_forest import RandomForestSurrogate
from bofire.data_models.surrogates.robust_single_task_gp import (
Expand Down Expand Up @@ -81,6 +82,7 @@
PiecewiseLinearGPSurrogate,
AdditiveMapSaasSingleTaskGPSurrogate,
EnsembleMapSaasSingleTaskGPSurrogate,
PFNSurrogate,
]

AnyTrainableSurrogate = Union[
Expand All @@ -98,6 +100,7 @@
PiecewiseLinearGPSurrogate,
AdditiveMapSaasSingleTaskGPSurrogate,
EnsembleMapSaasSingleTaskGPSurrogate,
PFNSurrogate,
]

AnyRegressionSurrogate = Union[
Expand All @@ -117,6 +120,7 @@
PiecewiseLinearGPSurrogate,
AdditiveMapSaasSingleTaskGPSurrogate,
EnsembleMapSaasSingleTaskGPSurrogate,
PFNSurrogate,
]

AnyClassificationSurrogate = ClassificationMLPEnsemble
2 changes: 2 additions & 0 deletions bofire/data_models/surrogates/botorch_surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
RegressionMLPEnsemble,
)
from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPSurrogate
from bofire.data_models.surrogates.pfn import PFNSurrogate
from bofire.data_models.surrogates.polynomial import PolynomialSurrogate
from bofire.data_models.surrogates.random_forest import RandomForestSurrogate
from bofire.data_models.surrogates.shape import PiecewiseLinearGPSurrogate
Expand All @@ -51,6 +52,7 @@
PiecewiseLinearGPSurrogate,
AdditiveMapSaasSingleTaskGPSurrogate,
EnsembleMapSaasSingleTaskGPSurrogate,
PFNSurrogate,
]


Expand Down
132 changes: 132 additions & 0 deletions bofire/data_models/surrogates/pfn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from typing import Any, Literal, Optional, Type

from pydantic import Field, field_validator

from bofire.data_models.enum import CategoricalEncodingEnum
from bofire.data_models.features.api import (
AnyOutput,
CategoricalDescriptorInput,
CategoricalInput,
CategoricalMolecularInput,
ContinuousOutput,
TaskInput,
)
from bofire.data_models.molfeatures.api import Fingerprints
from bofire.data_models.surrogates.scaler import AnyScaler, Normalize, ScalerEnum
from bofire.data_models.surrogates.trainable_botorch import TrainableBotorchSurrogate


class PFNSurrogate(TrainableBotorchSurrogate):
"""Prior-data Fitted Network (PFN) surrogate model.

PFN is a pre-trained neural network that can be used for Bayesian optimization
without requiring training on the specific task. The model is loaded from a
checkpoint URL and makes predictions based on training data context.

Attributes:
type: Discriminator for the surrogate type.
checkpoint_url: URL or path to the pre-trained PFN model checkpoint.
Defaults to the pfns4bo_hebo model. Can also use ModelPaths enum values:
- "pfns4bo_hebo": HEBO-style model with more budget and unused features
- "pfns4bo_bnn": BNN-style model sampled with warp for HPOB
Or provide a custom URL to a .pt.gz model file.
batch_first: Whether the batch dimension is the first dimension of input tensors.
For batch-first, X has shape `batch x seq_len x features`.
For non-batch-first, X has shape `seq_len x batch x features`.
constant_model_kwargs: Dictionary of constant keyword arguments that will be
passed to the model in each forward pass. Use this to configure model-specific
behavior during inference.
load_training_checkpoint: If True, loads a training checkpoint as produced by
the PFNs training code. If False, loads a pre-trained inference model.
cache_dir: Directory path for caching downloaded models. If None, uses
/tmp/botorch_pfn_models.
multivariate: If True, uses MultivariatePFNModel which returns a joint posterior
over batch inputs. This requires an additional forward pass and approximation.
If False, uses standard PFNModel with independent predictions.
scaler: Scaler to use for input features.
output_scaler: Scaler to use for output targets.

Note:
PFN models are pre-trained and do not require fitting in the traditional sense.
The "fit" operation simply loads the model and stores the training data as context
for inference.
"""

type: Literal["PFNSurrogate"] = "PFNSurrogate"

# Model loading configuration
checkpoint_url: str = "pfns4bo_hebo"
load_training_checkpoint: bool = False
cache_dir: Optional[str] = None

# Model architecture configuration
batch_first: bool = False
multivariate: bool = False

# Model inference configuration
constant_model_kwargs: dict[str, Any] = Field(default_factory=dict)
# num_samples: int = Field(
# default=128,
# description="Number of samples to draw from the posterior distribution for prediction.",
# )

# Scaling configuration
scaler: AnyScaler = Normalize()
output_scaler: ScalerEnum = ScalerEnum.STANDARDIZE

@field_validator("output_scaler")
@classmethod
def validate_output_scaler(cls, v: ScalerEnum) -> ScalerEnum:
"""Validate that output_scaler is not LOG or CHAINED_LOG_STANDARDIZE.

PFN models return variance estimates that are incompatible with LOG transforms,
as BoTorch's Log transform does not support untransforming variance.

Args:
v: The output scaler value to validate.

Returns:
The validated output scaler value.

Raises:
ValueError: If output_scaler is LOG or CHAINED_LOG_STANDARDIZE.
"""
if v in (ScalerEnum.LOG, ScalerEnum.CHAINED_LOG_STANDARDIZE):
raise ValueError(
f"PFNSurrogate does not support output_scaler={v.name}. "
"LOG and CHAINED_LOG_STANDARDIZE transforms are incompatible with "
"PFN's variance predictions. Use STANDARDIZE, NORMALIZE, or IDENTITY instead."
)
return v

@classmethod
def _default_categorical_encodings(
cls,
) -> dict[Type[CategoricalInput], CategoricalEncodingEnum | Fingerprints]:
"""Override default categorical encodings for PFN models.

PFN models work better with ordinal encodings instead of one-hot encoding
to avoid exceeding the pretrained model's feature dimension limits.
Pretrained PFN checkpoints have fixed input dimensionality constraints.

Returns:
Dictionary mapping categorical input types to their encoding strategies.
"""
return {
CategoricalInput: CategoricalEncodingEnum.ORDINAL,
CategoricalMolecularInput: Fingerprints(),
CategoricalDescriptorInput: CategoricalEncodingEnum.DESCRIPTOR,
TaskInput: CategoricalEncodingEnum.ORDINAL,
}

@classmethod
def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool:
"""Check if the output type is implemented for this surrogate.

Args:
my_type: The output feature type to check.

Returns:
True if the output type is ContinuousOutput, False otherwise.
"""
return my_type == ContinuousOutput
1 change: 1 addition & 0 deletions bofire/surrogates/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
RegressionMLPEnsemble,
)
from bofire.surrogates.multi_task_gp import MultiTaskGPSurrogate
from bofire.surrogates.pfn import PFNSurrogate
from bofire.surrogates.random_forest import RandomForestSurrogate
from bofire.surrogates.shape import PiecewiseLinearGPSurrogate
from bofire.surrogates.single_task_gp import SingleTaskGPSurrogate
Expand Down
2 changes: 2 additions & 0 deletions bofire/surrogates/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from bofire.surrogates.mlp import ClassificationMLPEnsemble, RegressionMLPEnsemble
from bofire.surrogates.multi_task_gp import MultiTaskGPSurrogate
from bofire.surrogates.pfn import PFNSurrogate
from bofire.surrogates.random_forest import RandomForestSurrogate
from bofire.surrogates.robust_single_task_gp import RobustSingleTaskGPSurrogate
from bofire.surrogates.shape import PiecewiseLinearGPSurrogate
Expand Down Expand Up @@ -115,6 +116,7 @@ def map_TanimotoGPSurrogate(
data_models.CategoricalDeterministicSurrogate: CategoricalDeterministicSurrogate,
data_models.AdditiveMapSaasSingleTaskGPSurrogate: AdditiveMapSaasSingleTaskGPSurrogate,
data_models.EnsembleMapSaasSingleTaskGPSurrogate: EnsembleMapSaasSingleTaskGPSurrogate,
data_models.PFNSurrogate: PFNSurrogate,
}


Expand Down
Loading
Loading