Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,6 @@ notebook_test_stats.csv
**/*.quarto_ipynb

**/.jupyter_cache

# ignore scripts
scripts/*
1 change: 0 additions & 1 deletion _quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ quartodoc:
- EmpiricalSurrogate
- RandomForestSurrogate
- MLPEnsemble
- PiecewiseLinearGPSurrogate

website:
title: "BoFire"
Expand Down
8 changes: 7 additions & 1 deletion bofire/data_models/kernels/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from bofire.data_models.kernels.kernel import AggregationKernel
from bofire.data_models.kernels.molecular import TanimotoKernel
from bofire.data_models.kernels.shape import WassersteinKernel
from bofire.data_models.kernels.shape import ExactWassersteinKernel, WassersteinKernel
from bofire.data_models.priors.api import AnyGeneralPrior, AnyPriorConstraint


Expand All @@ -32,6 +32,8 @@ class AdditiveKernel(AggregationKernel):
IndexKernel,
PositiveIndexKernel,
TanimotoKernel,
WassersteinKernel,
ExactWassersteinKernel,
"AdditiveKernel",
"MultiplicativeKernel",
"ScaleKernel",
Expand All @@ -53,6 +55,8 @@ class MultiplicativeKernel(AggregationKernel):
PositiveIndexKernel,
AdditiveKernel,
TanimotoKernel,
WassersteinKernel,
ExactWassersteinKernel,
"MultiplicativeKernel",
"ScaleKernel",
]
Expand All @@ -74,6 +78,7 @@ class ScaleKernel(AggregationKernel):
TanimotoKernel,
"ScaleKernel",
WassersteinKernel,
ExactWassersteinKernel,
]
outputscale_prior: Optional[AnyGeneralPrior] = None
outputscale_constraint: Optional[AnyPriorConstraint] = None
Expand Down Expand Up @@ -136,6 +141,7 @@ class PolynomialFeatureInteractionKernel(AggregationKernel):
TanimotoKernel,
InfiniteWidthBNNKernel,
WassersteinKernel,
ExactWassersteinKernel,
]
]
max_degree: int
Expand Down
3 changes: 2 additions & 1 deletion bofire/data_models/kernels/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
Kernel,
)
from bofire.data_models.kernels.molecular import MolecularKernel, TanimotoKernel
from bofire.data_models.kernels.shape import WassersteinKernel
from bofire.data_models.kernels.shape import ExactWassersteinKernel, WassersteinKernel


AbstractKernel = Union[
Expand Down Expand Up @@ -78,6 +78,7 @@
TanimotoKernel,
InfiniteWidthBNNKernel,
WassersteinKernel,
ExactWassersteinKernel,
WedgeKernel,
]

Expand Down
27 changes: 23 additions & 4 deletions bofire/data_models/kernels/shape.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Literal, Optional
from typing import List, Literal, Optional

from bofire.data_models.kernels.kernel import Kernel
from bofire.data_models.priors.api import AnyPrior
from bofire.data_models.kernels.kernel import FeatureSpecificKernel
from bofire.data_models.priors.api import AnyPrior, AnyPriorConstraint


class WassersteinKernel(Kernel):
class WassersteinKernel(FeatureSpecificKernel):
"""Kernel based on the Wasserstein distance.

It only works for 1D data that is monotonically increasing, as it is just
Expand All @@ -27,3 +27,22 @@ class WassersteinKernel(Kernel):
type: Literal["WassersteinKernel"] = "WassersteinKernel"
squared: bool = False
lengthscale_prior: Optional[AnyPrior] = None
lengthscale_constraint: Optional[AnyPriorConstraint] = None


class ExactWassersteinKernel(FeatureSpecificKernel):
"""Kernel based on the exact 1D Wasserstein distance for piecewise-linear curves."""

type: Literal["ExactWassersteinKernel"] = "ExactWassersteinKernel"
squared: bool = False
lengthscale_prior: Optional[AnyPrior] = None
lengthscale_constraint: Optional[AnyPriorConstraint] = None
idx_x: List[int]
idx_y: List[int]
prepend_x: List[float] = []
prepend_y: List[float] = []
append_x: List[float] = []
append_y: List[float] = []
normalize_y: float = 1.0
normalize_x: bool = True
order: Literal[1, 2] = 1
11 changes: 6 additions & 5 deletions bofire/data_models/priors/_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def _rebuild_dependent_models() -> None:
RBFKernel,
SphericalLinearKernel,
)
from bofire.data_models.kernels.shape import WassersteinKernel
from bofire.data_models.kernels.shape import (
ExactWassersteinKernel,
WassersteinKernel,
)
from bofire.data_models.surrogates.botorch_surrogates import BotorchSurrogates
from bofire.data_models.surrogates.linear import LinearSurrogate
from bofire.data_models.surrogates.mixed_single_task_gp import (
Expand All @@ -37,7 +40,6 @@ def _rebuild_dependent_models() -> None:
from bofire.data_models.surrogates.robust_single_task_gp import (
RobustSingleTaskGPSurrogate,
)
from bofire.data_models.surrogates.shape import PiecewiseLinearGPSurrogate
from bofire.data_models.surrogates.single_task_gp import (
SingleTaskGPHyperconfig,
SingleTaskGPSurrogate,
Expand All @@ -61,12 +63,11 @@ def _rebuild_dependent_models() -> None:
(WedgeKernel, "angle_prior"),
(WedgeKernel, "radius_prior"),
(WassersteinKernel, "lengthscale_prior"),
(ExactWassersteinKernel, "lengthscale_prior"),
(SingleTaskGPSurrogate, "noise_prior"),
(MultiTaskGPSurrogate, "noise_prior"),
(MixedSingleTaskGPSurrogate, "noise_prior"),
(TanimotoGPSurrogate, "noise_prior"),
(PiecewiseLinearGPSurrogate, "outputscale_prior"),
(PiecewiseLinearGPSurrogate, "noise_prior"),
(PolynomialSurrogate, "noise_prior"),
(LinearSurrogate, "noise_prior"),
(RobustSingleTaskGPSurrogate, "noise_prior"),
Expand Down Expand Up @@ -104,6 +105,7 @@ def _rebuild_dependent_models() -> None:
IndexKernel,
PositiveIndexKernel,
WassersteinKernel,
ExactWassersteinKernel,
]:
cls.model_rebuild(force=True)

Expand All @@ -126,7 +128,6 @@ def _rebuild_dependent_models() -> None:
MultiTaskGPSurrogate,
MixedSingleTaskGPSurrogate,
TanimotoGPSurrogate,
PiecewiseLinearGPSurrogate,
PolynomialSurrogate,
LinearSurrogate,
RobustSingleTaskGPSurrogate,
Expand Down
4 changes: 0 additions & 4 deletions bofire/data_models/surrogates/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
ScalerEnum,
Standardize,
)
from bofire.data_models.surrogates.shape import PiecewiseLinearGPSurrogate
from bofire.data_models.surrogates.single_task_gp import (
SingleTaskGPHyperconfig,
SingleTaskGPSurrogate,
Expand Down Expand Up @@ -79,7 +78,6 @@
CategoricalDeterministicSurrogate,
MultiTaskGPSurrogate,
SingleTaskIBNNSurrogate,
PiecewiseLinearGPSurrogate,
AdditiveMapSaasSingleTaskGPSurrogate,
EnsembleMapSaasSingleTaskGPSurrogate,
]
Expand All @@ -96,7 +94,6 @@
PolynomialSurrogate,
SingleTaskIBNNSurrogate,
TanimotoGPSurrogate,
PiecewiseLinearGPSurrogate,
AdditiveMapSaasSingleTaskGPSurrogate,
EnsembleMapSaasSingleTaskGPSurrogate,
]
Expand All @@ -115,7 +112,6 @@
LinearDeterministicSurrogate,
MultiTaskGPSurrogate,
SingleTaskIBNNSurrogate,
PiecewiseLinearGPSurrogate,
AdditiveMapSaasSingleTaskGPSurrogate,
EnsembleMapSaasSingleTaskGPSurrogate,
]
Expand Down
2 changes: 0 additions & 2 deletions bofire/data_models/surrogates/botorch_surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPSurrogate
from bofire.data_models.surrogates.polynomial import PolynomialSurrogate
from bofire.data_models.surrogates.random_forest import RandomForestSurrogate
from bofire.data_models.surrogates.shape import PiecewiseLinearGPSurrogate
from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate
from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate
from bofire.data_models.types import InputTransformSpecs
Expand All @@ -49,7 +48,6 @@
LinearDeterministicSurrogate,
CategoricalDeterministicSurrogate,
MultiTaskGPSurrogate,
PiecewiseLinearGPSurrogate,
AdditiveMapSaasSingleTaskGPSurrogate,
EnsembleMapSaasSingleTaskGPSurrogate,
]
Expand Down
183 changes: 0 additions & 183 deletions bofire/data_models/surrogates/shape.py

This file was deleted.

Loading
Loading