diff --git a/_quarto.yml b/_quarto.yml index b1fcc1571..d3ed68092 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -160,6 +160,8 @@ website: href: docs/userguides/surrogates.qmd - text: "Data Models vs Functionals" href: docs/userguides/data_models_functionals.qmd + - text: "Registering Custom Types" + href: docs/userguides/custom_types.qmd - text: "Examples" menu: - text: "Basic Examples" @@ -192,6 +194,7 @@ website: - docs/userguides/strategies.qmd - docs/userguides/surrogates.qmd - docs/userguides/data_models_functionals.qmd + - docs/userguides/custom_types.qmd - title: "Examples" contents: diff --git a/bofire/data_models/_register_utils.py b/bofire/data_models/_register_utils.py new file mode 100644 index 000000000..bcc698974 --- /dev/null +++ b/bofire/data_models/_register_utils.py @@ -0,0 +1,72 @@ +"""Shared utilities for dynamic Pydantic model registration.""" + +import typing +from collections.abc import Sequence +from typing import Optional, Union + + +def patch_field(model_cls: type, field_name: str, new_union: type) -> None: + """Patch a Pydantic model field annotation with a new union type. 
+ + Handles three annotation forms: + - ``Union[A, B, ...]`` — replaced with *new_union* + - ``Optional[Union[A, B, ...]]`` — wrapped as ``Optional[new_union]`` + - ``Sequence[Union[A, B, ...]]`` — wrapped as ``Sequence[new_union]`` + """ + old = model_cls.model_fields[ # ty: ignore[unresolved-attribute] + field_name + ].annotation + args = typing.get_args(old) + + if not args: + new = new_union + elif type(None) in args: + # Optional[X] is Union[X, None] + new = Optional[new_union] + elif typing.get_origin(old) in (list, Sequence): + # Sequence[Union[...]] or list[Union[...]] + new = Sequence[new_union] + else: + new = new_union + + model_cls.__annotations__[field_name] = new + model_cls.model_fields[ # ty: ignore[unresolved-attribute] + field_name + ].annotation = new + + +def append_to_union_field(model_cls: type, field_name: str, new_type: type) -> None: + """Append a type to the union inside a model field annotation. + + Detects the annotation structure (plain ``Union``, ``Optional[Union]``, + or ``Sequence[Union]``) and appends *new_type* to the inner union. + """ + old = model_cls.model_fields[ # ty: ignore[unresolved-attribute] + field_name + ].annotation + origin = typing.get_origin(old) + + if origin in (list, Sequence): + # Sequence[Union[A, B, ...]] → Sequence[Union[A, B, ..., new_type]] + inner = typing.get_args(old)[0] + inner_args = typing.get_args(inner) + if new_type not in inner_args: + new_inner = Union[tuple(list(inner_args) + [new_type])] + new_ann = Sequence[new_inner] + model_cls.__annotations__[field_name] = new_ann + model_cls.model_fields[ # ty: ignore[unresolved-attribute] + field_name + ].annotation = new_ann + else: + # Union[A, B, ...] 
or Optional[Union[A, B, ...]] + args = typing.get_args(old) + has_none = type(None) in args + non_none = [a for a in args if a is not type(None)] + if new_type not in non_none: + non_none.append(new_type) + new_union = Union[tuple(non_none)] + new_ann = Optional[new_union] if has_none else new_union + model_cls.__annotations__[field_name] = new_ann + model_cls.model_fields[ # ty: ignore[unresolved-attribute] + field_name + ].annotation = new_ann diff --git a/bofire/data_models/features/_register.py b/bofire/data_models/features/_register.py new file mode 100644 index 000000000..bcc3bd9e8 --- /dev/null +++ b/bofire/data_models/features/_register.py @@ -0,0 +1,29 @@ +"""Registration utilities for custom engineered feature types.""" + +from typing import Union + + +def register_engineered_feature(data_model_cls: type) -> None: + """Register a custom engineered feature type so it is accepted in EngineeredFeatures. + + This appends the type to the internal registry, rebuilds the + ``AnyEngineeredFeature`` union, and calls ``model_rebuild`` on + ``EngineeredFeatures`` so that Pydantic accepts the new type. + + Args: + data_model_cls: A concrete subclass of ``EngineeredFeature``. 
+ """ + import bofire.data_models.features.api as features_api + + if data_model_cls in features_api._ENGINEERED_FEATURE_TYPES: + return + features_api._ENGINEERED_FEATURE_TYPES.append(data_model_cls) + features_api.AnyEngineeredFeature = Union[ + tuple(features_api._ENGINEERED_FEATURE_TYPES) + ] + + from bofire.data_models._register_utils import append_to_union_field + from bofire.data_models.domain.features import EngineeredFeatures + + append_to_union_field(EngineeredFeatures, "features", data_model_cls) + EngineeredFeatures.model_rebuild(force=True) diff --git a/bofire/data_models/features/api.py b/bofire/data_models/features/api.py index cb6ca6ad1..f8dbd8ee3 100644 --- a/bofire/data_models/features/api.py +++ b/bofire/data_models/features/api.py @@ -1,5 +1,6 @@ from typing import Union +from bofire.data_models.features._register import register_engineered_feature from bofire.data_models.features.categorical import CategoricalInput, CategoricalOutput from bofire.data_models.features.continuous import ContinuousInput, ContinuousOutput from bofire.data_models.features.descriptor import ( @@ -71,7 +72,7 @@ AnyOutput = Union[ContinuousOutput, CategoricalOutput] -AnyEngineeredFeature = Union[ +_ENGINEERED_FEATURE_TYPES: list[type[EngineeredFeature]] = [ SumFeature, MeanFeature, WeightedMeanFeature, @@ -82,3 +83,5 @@ InterpolateFeature, CloneFeature, ] + +AnyEngineeredFeature = Union[tuple(_ENGINEERED_FEATURE_TYPES)] diff --git a/bofire/data_models/kernels/_register.py b/bofire/data_models/kernels/_register.py new file mode 100644 index 000000000..8af0f4ac9 --- /dev/null +++ b/bofire/data_models/kernels/_register.py @@ -0,0 +1,124 @@ +"""Registration utilities for custom kernel types.""" + +from typing import Union + + +def _rebuild_dependent_models(new_kernel_cls: type) -> None: + """Rebuild all Pydantic models whose fields reference kernel unions.""" + import bofire.data_models.kernels.api as kernels_api + from bofire.data_models._register_utils import 
append_to_union_field, patch_field + from bofire.data_models.kernels.aggregation import ( + AdditiveKernel, + MultiplicativeKernel, + PolynomialFeatureInteractionKernel, + ScaleKernel, + ) + from bofire.data_models.kernels.categorical import CategoricalKernel + from bofire.data_models.kernels.conditional import ( + ConditionalEmbeddingKernel, + WedgeKernel, + ) + from bofire.data_models.kernels.continuous import ContinuousKernel + from bofire.data_models.surrogates.botorch_surrogates import BotorchSurrogates + from bofire.data_models.surrogates.mixed_single_task_gp import ( + MixedSingleTaskGPSurrogate, + ) + from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPSurrogate + from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate + from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate + + # Add new kernel type to aggregation kernel inline unions + # (handles both Sequence[Union[...]] and plain Union[...]) + for model_cls, field_name in [ + (AdditiveKernel, "kernels"), + (MultiplicativeKernel, "kernels"), + (PolynomialFeatureInteractionKernel, "kernels"), + (ScaleKernel, "base_kernel"), + (ConditionalEmbeddingKernel, "base_kernel"), + (WedgeKernel, "base_kernel"), + ]: + append_to_union_field(model_cls, field_name, new_kernel_cls) + + # Rebuild aggregation and conditional kernels + for cls in [ + AdditiveKernel, + MultiplicativeKernel, + ScaleKernel, + PolynomialFeatureInteractionKernel, + ConditionalEmbeddingKernel, + WedgeKernel, + ]: + cls.model_rebuild(force=True) + + # Patch AnyKernel fields on surrogate models + for model_cls, field_name in [ + (SingleTaskGPSurrogate, "kernel"), + (MultiTaskGPSurrogate, "kernel"), + (TanimotoGPSurrogate, "kernel"), + ]: + patch_field(model_cls, field_name, kernels_api.AnyKernel) + + # Patch sub-category kernel fields if the new type is a subclass + if issubclass(new_kernel_cls, ContinuousKernel): + patch_field( + MixedSingleTaskGPSurrogate, + "continuous_kernel", + 
kernels_api.AnyContinuousKernel, + ) + if issubclass(new_kernel_cls, CategoricalKernel): + patch_field( + MixedSingleTaskGPSurrogate, + "categorical_kernel", + kernels_api.AnyCategoricalKernel, + ) + + # Rebuild surrogate models + for cls in [ + SingleTaskGPSurrogate, + MultiTaskGPSurrogate, + TanimotoGPSurrogate, + MixedSingleTaskGPSurrogate, + ]: + cls.model_rebuild(force=True) + + # Rebuild BotorchSurrogates + BotorchSurrogates.model_rebuild(force=True) + + +def register_kernel(data_model_cls: type) -> None: + """Register a custom kernel type so it is accepted in AnyKernel fields. + + This appends the type to the internal registry, rebuilds the + ``AnyKernel`` union, and calls ``model_rebuild`` on all dependent + Pydantic models (aggregation kernels, surrogates) so that the new + type is accepted. + + If the type is a subclass of ``ContinuousKernel`` or ``CategoricalKernel``, + it is also added to the corresponding sub-category union + (``AnyContinuousKernel`` / ``AnyCategoricalKernel``). + + Args: + data_model_cls: A concrete subclass of ``Kernel``. 
+ """ + import bofire.data_models.kernels.api as kernels_api + from bofire.data_models.kernels.categorical import CategoricalKernel + from bofire.data_models.kernels.continuous import ContinuousKernel + + if data_model_cls in kernels_api._KERNEL_TYPES: + return + kernels_api._KERNEL_TYPES.append(data_model_cls) + kernels_api.AnyKernel = Union[tuple(kernels_api._KERNEL_TYPES)] + + # Auto-detect sub-category from base class + if issubclass(data_model_cls, ContinuousKernel): + kernels_api._CONTINUOUS_KERNEL_TYPES.append(data_model_cls) + kernels_api.AnyContinuousKernel = Union[ + tuple(kernels_api._CONTINUOUS_KERNEL_TYPES) + ] + elif issubclass(data_model_cls, CategoricalKernel): + kernels_api._CATEGORICAL_KERNEL_TYPES.append(data_model_cls) + kernels_api.AnyCategoricalKernel = Union[ + tuple(kernels_api._CATEGORICAL_KERNEL_TYPES) + ] + + _rebuild_dependent_models(data_model_cls) diff --git a/bofire/data_models/kernels/api.py b/bofire/data_models/kernels/api.py index 96074e735..14f27c041 100644 --- a/bofire/data_models/kernels/api.py +++ b/bofire/data_models/kernels/api.py @@ -1,5 +1,6 @@ from typing import Union +from bofire.data_models.kernels._register import register_kernel # noqa: F401 from bofire.data_models.kernels.aggregation import ( AdditiveKernel, MultiplicativeKernel, @@ -40,7 +41,7 @@ AggregationKernel, ] -AnyContinuousKernel = Union[ +_CONTINUOUS_KERNEL_TYPES: list[type[ContinuousKernel]] = [ MaternKernel, LinearKernel, PolynomialKernel, @@ -49,15 +50,19 @@ InfiniteWidthBNNKernel, ] -AnyCategoricalKernel = Union[ +AnyContinuousKernel = Union[tuple(_CONTINUOUS_KERNEL_TYPES)] + +_CATEGORICAL_KERNEL_TYPES: list[type[CategoricalKernel]] = [ HammingDistanceKernel, IndexKernel, PositiveIndexKernel, ] +AnyCategoricalKernel = Union[tuple(_CATEGORICAL_KERNEL_TYPES)] + AnyMolecularKernel = TanimotoKernel -AnyKernel = Union[ +_KERNEL_TYPES: list[type[Kernel]] = [ AdditiveKernel, MultiplicativeKernel, PolynomialFeatureInteractionKernel, @@ -75,3 +80,5 @@ 
WassersteinKernel, WedgeKernel, ] + +AnyKernel = Union[tuple(_KERNEL_TYPES)] diff --git a/bofire/data_models/priors/_register.py b/bofire/data_models/priors/_register.py new file mode 100644 index 000000000..53c76b04c --- /dev/null +++ b/bofire/data_models/priors/_register.py @@ -0,0 +1,169 @@ +"""Registration utilities for custom prior and prior constraint types.""" + +from typing import Union + + +def _rebuild_dependent_models() -> None: + """Rebuild all Pydantic models whose fields reference AnyPrior or AnyPriorConstraint.""" + import bofire.data_models.priors.api as priors_api + from bofire.data_models._register_utils import patch_field + + # Lazy imports to avoid circular dependencies + from bofire.data_models.kernels.aggregation import ( + AdditiveKernel, + MultiplicativeKernel, + PolynomialFeatureInteractionKernel, + ScaleKernel, + ) + from bofire.data_models.kernels.categorical import ( + HammingDistanceKernel, + IndexKernel, + PositiveIndexKernel, + ) + from bofire.data_models.kernels.conditional import WedgeKernel + from bofire.data_models.kernels.continuous import ( + MaternKernel, + RBFKernel, + SphericalLinearKernel, + ) + from bofire.data_models.kernels.shape import WassersteinKernel + from bofire.data_models.surrogates.botorch_surrogates import BotorchSurrogates + from bofire.data_models.surrogates.linear import LinearSurrogate + from bofire.data_models.surrogates.mixed_single_task_gp import ( + MixedSingleTaskGPSurrogate, + ) + from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPSurrogate + from bofire.data_models.surrogates.polynomial import PolynomialSurrogate + from bofire.data_models.surrogates.robust_single_task_gp import ( + RobustSingleTaskGPSurrogate, + ) + from bofire.data_models.surrogates.shape import PiecewiseLinearGPSurrogate + from bofire.data_models.surrogates.single_task_gp import ( + SingleTaskGPHyperconfig, + SingleTaskGPSurrogate, + ) + from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate + + 
AnyPrior = priors_api.AnyPrior + AnyPriorConstraint = priors_api.AnyPriorConstraint + + # Patch AnyPrior fields + for model_cls, field_name in [ + (RBFKernel, "lengthscale_prior"), + (MaternKernel, "lengthscale_prior"), + (SphericalLinearKernel, "lengthscale_prior"), + (HammingDistanceKernel, "lengthscale_prior"), + (IndexKernel, "prior"), + (PositiveIndexKernel, "prior"), + (PositiveIndexKernel, "task_prior"), + (PositiveIndexKernel, "diag_prior"), + (WedgeKernel, "lengthscale_prior"), + (WedgeKernel, "angle_prior"), + (WedgeKernel, "radius_prior"), + (WassersteinKernel, "lengthscale_prior"), + (SingleTaskGPSurrogate, "noise_prior"), + (MultiTaskGPSurrogate, "noise_prior"), + (MixedSingleTaskGPSurrogate, "noise_prior"), + (TanimotoGPSurrogate, "noise_prior"), + (PiecewiseLinearGPSurrogate, "outputscale_prior"), + (PiecewiseLinearGPSurrogate, "noise_prior"), + (PolynomialSurrogate, "noise_prior"), + (LinearSurrogate, "noise_prior"), + (RobustSingleTaskGPSurrogate, "noise_prior"), + ]: + patch_field(model_cls, field_name, AnyPrior) + + # Patch AnyPriorConstraint fields + for model_cls, field_name in [ + (RBFKernel, "lengthscale_constraint"), + (MaternKernel, "lengthscale_constraint"), + (SphericalLinearKernel, "lengthscale_constraint"), + (HammingDistanceKernel, "lengthscale_constraint"), + (IndexKernel, "var_constraint"), + (PositiveIndexKernel, "var_constraint"), + (WedgeKernel, "lengthscale_constraint"), + (ScaleKernel, "outputscale_constraint"), + (SingleTaskGPHyperconfig, "lengthscale_constraint"), + (SingleTaskGPHyperconfig, "outputscale_constraint"), + ]: + patch_field(model_cls, field_name, AnyPriorConstraint) + + # Rebuild in dependency order: + # 1. Leaf kernel models + for cls in [ + RBFKernel, + MaternKernel, + SphericalLinearKernel, + HammingDistanceKernel, + IndexKernel, + PositiveIndexKernel, + WassersteinKernel, + ]: + cls.model_rebuild(force=True) + + # 2. Conditional kernels (reference leaf kernels) + WedgeKernel.model_rebuild(force=True) + + # 3. 
Aggregation kernels (embed leaf kernel types) + for cls in [ + AdditiveKernel, + MultiplicativeKernel, + ScaleKernel, + PolynomialFeatureInteractionKernel, + ]: + cls.model_rebuild(force=True) + + # 4. Surrogate models + SingleTaskGPHyperconfig.model_rebuild(force=True) + for cls in [ + SingleTaskGPSurrogate, + MultiTaskGPSurrogate, + MixedSingleTaskGPSurrogate, + TanimotoGPSurrogate, + PiecewiseLinearGPSurrogate, + PolynomialSurrogate, + LinearSurrogate, + RobustSingleTaskGPSurrogate, + ]: + cls.model_rebuild(force=True) + + # 5. BotorchSurrogates + BotorchSurrogates.model_rebuild(force=True) + + +def register_prior(data_model_cls: type) -> None: + """Register a custom prior type so it is accepted in AnyPrior fields. + + This appends the type to the internal registry, rebuilds the + ``AnyPrior`` union, and calls ``model_rebuild`` on all dependent + Pydantic models (kernels, surrogates) so that the new type is accepted. + + Args: + data_model_cls: A concrete subclass of ``Prior``. + """ + import bofire.data_models.priors.api as priors_api + + if data_model_cls in priors_api._PRIOR_TYPES: + return + priors_api._PRIOR_TYPES.append(data_model_cls) + priors_api.AnyPrior = Union[tuple(priors_api._PRIOR_TYPES)] + _rebuild_dependent_models() + + +def register_prior_constraint(data_model_cls: type) -> None: + """Register a custom prior constraint type so it is accepted in AnyPriorConstraint fields. + + This appends the type to the internal registry, rebuilds the + ``AnyPriorConstraint`` union, and calls ``model_rebuild`` on all + dependent Pydantic models. + + Args: + data_model_cls: A concrete subclass of ``PriorConstraint`` or ``Interval``. 
+ """ + import bofire.data_models.priors.api as priors_api + + if data_model_cls in priors_api._PRIOR_CONSTRAINT_TYPES: + return + priors_api._PRIOR_CONSTRAINT_TYPES.append(data_model_cls) + priors_api.AnyPriorConstraint = Union[tuple(priors_api._PRIOR_CONSTRAINT_TYPES)] + _rebuild_dependent_models() diff --git a/bofire/data_models/priors/api.py b/bofire/data_models/priors/api.py index 02540c80f..2ce33ebc9 100644 --- a/bofire/data_models/priors/api.py +++ b/bofire/data_models/priors/api.py @@ -1,6 +1,10 @@ from functools import partial from typing import Union +from bofire.data_models.priors._register import ( + register_prior, + register_prior_constraint, +) from bofire.data_models.priors.constraint import ( GreaterThan, LessThan, @@ -25,7 +29,7 @@ AbstractPrior = Prior AbstractPriorConstraint = Union[PriorConstraint, Interval] -AnyPrior = Union[ +_PRIOR_TYPES: list[type[Prior]] = [ GammaPrior, NormalPrior, LKJPrior, @@ -33,10 +37,19 @@ DimensionalityScaledLogNormalPrior, ] -AnyPriorConstraint = Union[ - NonTransformedInterval, LogTransformedInterval, Positive, GreaterThan, LessThan +AnyPrior = Union[tuple(_PRIOR_TYPES)] + +_PRIOR_CONSTRAINT_TYPES: list[type] = [ + NonTransformedInterval, + LogTransformedInterval, + Positive, + GreaterThan, + LessThan, ] +AnyPriorConstraint = Union[tuple(_PRIOR_CONSTRAINT_TYPES)] + + # these are priors that are generally applicable # and do not depend on problem specific extra parameters AnyGeneralPrior = Union[GammaPrior, NormalPrior, LKJPrior, LogNormalPrior] diff --git a/bofire/data_models/strategies/_register.py b/bofire/data_models/strategies/_register.py new file mode 100644 index 000000000..c993f1399 --- /dev/null +++ b/bofire/data_models/strategies/_register.py @@ -0,0 +1,28 @@ +"""Registration utilities for custom strategy types.""" + +from typing import Union + + +def register_strategy(data_model_cls: type) -> None: + """Register a custom strategy type so it is accepted in ActualStrategy fields. 
+ + This appends the type to the internal registry, rebuilds the + ``ActualStrategy`` union, and calls ``model_rebuild`` on the + ``Step`` and ``StepwiseStrategy`` models so that Pydantic accepts the + new type. + + Args: + data_model_cls: A concrete subclass of ``Strategy``. + """ + import bofire.data_models.strategies.actual_strategy_type as ast_mod + from bofire.data_models._register_utils import patch_field + from bofire.data_models.strategies.stepwise.stepwise import Step, StepwiseStrategy + + if data_model_cls in ast_mod._ACTUAL_STRATEGY_TYPES: + return + ast_mod._ACTUAL_STRATEGY_TYPES.append(data_model_cls) + ast_mod.ActualStrategy = Union[tuple(ast_mod._ACTUAL_STRATEGY_TYPES)] + + patch_field(Step, "strategy_data", ast_mod.ActualStrategy) + Step.model_rebuild(force=True) + StepwiseStrategy.model_rebuild(force=True) diff --git a/bofire/data_models/strategies/actual_strategy_type.py b/bofire/data_models/strategies/actual_strategy_type.py index ca4b21b21..4a003b68a 100644 --- a/bofire/data_models/strategies/actual_strategy_type.py +++ b/bofire/data_models/strategies/actual_strategy_type.py @@ -23,9 +23,10 @@ ) from bofire.data_models.strategies.random import RandomStrategy from bofire.data_models.strategies.shortest_path import ShortestPathStrategy +from bofire.data_models.strategies.strategy import Strategy -ActualStrategy = Union[ +_ACTUAL_STRATEGY_TYPES: list[type[Strategy]] = [ SoboStrategy, AdditiveSoboStrategy, ActiveLearningStrategy, @@ -42,3 +43,5 @@ ShortestPathStrategy, FractionalFactorialStrategy, ] + +ActualStrategy = Union[tuple(_ACTUAL_STRATEGY_TYPES)] diff --git a/bofire/data_models/strategies/api.py b/bofire/data_models/strategies/api.py index 3d96adc7f..9647806aa 100644 --- a/bofire/data_models/strategies/api.py +++ b/bofire/data_models/strategies/api.py @@ -1,5 +1,6 @@ from typing import Union +from bofire.data_models.strategies._register import register_strategy from bofire.data_models.strategies.actual_strategy_type import ActualStrategy 
from bofire.data_models.strategies.doe import ( AnyDoEOptimalityCriterion, diff --git a/bofire/data_models/surrogates/api.py b/bofire/data_models/surrogates/api.py index 7a98b97af..664f798aa 100644 --- a/bofire/data_models/surrogates/api.py +++ b/bofire/data_models/surrogates/api.py @@ -5,6 +5,7 @@ from bofire.data_models.surrogates.botorch_surrogates import ( AnyBotorchSurrogate, BotorchSurrogates, + register_botorch_surrogate, ) from bofire.data_models.surrogates.deterministic import ( CategoricalDeterministicSurrogate, diff --git a/bofire/data_models/surrogates/botorch_surrogates.py b/bofire/data_models/surrogates/botorch_surrogates.py index bffb016e1..c460470bf 100644 --- a/bofire/data_models/surrogates/botorch_surrogates.py +++ b/bofire/data_models/surrogates/botorch_surrogates.py @@ -1,10 +1,11 @@ import itertools -from typing import List, Union +from typing import List, Type, Union from pydantic import field_validator from bofire.data_models.base import BaseModel from bofire.data_models.domain.api import Inputs, Outputs +from bofire.data_models.surrogates.botorch import BotorchSurrogate from bofire.data_models.surrogates.deterministic import ( CategoricalDeterministicSurrogate, LinearDeterministicSurrogate, @@ -34,7 +35,7 @@ from bofire.data_models.types import InputTransformSpecs -AnyBotorchSurrogate = Union[ +_BOTORCH_SURROGATE_TYPES: List[Type[BotorchSurrogate]] = [ EmpiricalSurrogate, RandomForestSurrogate, SingleTaskGPSurrogate, @@ -53,6 +54,31 @@ EnsembleMapSaasSingleTaskGPSurrogate, ] +AnyBotorchSurrogate = Union[tuple(_BOTORCH_SURROGATE_TYPES)] + + +def register_botorch_surrogate( + data_model_cls: Type[BotorchSurrogate], +) -> None: + """Register a custom BotorchSurrogate type so it is accepted by BotorchSurrogates. + + This appends the type to the internal registry, rebuilds the + ``AnyBotorchSurrogate`` union, and calls ``model_rebuild`` on + ``BotorchSurrogates`` so that Pydantic picks up the new type. 
+ + Args: + data_model_cls: A concrete subclass of ``BotorchSurrogate``. + """ + global AnyBotorchSurrogate + if data_model_cls in _BOTORCH_SURROGATE_TYPES: + return + _BOTORCH_SURROGATE_TYPES.append(data_model_cls) + AnyBotorchSurrogate = Union[tuple(_BOTORCH_SURROGATE_TYPES)] + new_annotation = List[AnyBotorchSurrogate] + BotorchSurrogates.__annotations__["surrogates"] = new_annotation + BotorchSurrogates.model_fields["surrogates"].annotation = new_annotation + BotorchSurrogates.model_rebuild(force=True) + class BotorchSurrogates(BaseModel): """ "List of botorch surrogates. diff --git a/bofire/kernels/api.py b/bofire/kernels/api.py index a395b6f89..3340a9d39 100644 --- a/bofire/kernels/api.py +++ b/bofire/kernels/api.py @@ -1 +1 @@ -from bofire.kernels.mapper import map # noqa: F401 +from bofire.kernels.mapper import map, register # noqa: F401 diff --git a/bofire/kernels/mapper.py b/bofire/kernels/mapper.py index 2e174368a..b323b6997 100644 --- a/bofire/kernels/mapper.py +++ b/bofire/kernels/mapper.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Type import gpytorch import torch @@ -21,6 +21,47 @@ from bofire.kernels.spherical_kernels import SphericalLinearKernel +def register( + data_model_cls: Type[data_models.Kernel], + map_fn: Optional[Callable] = None, +): + """Register a custom kernel mapping from data model to factory function. + + Can be used as a decorator or as a direct function call:: + + # Decorator form + @register(MyKernelDataModel) + def map_my_kernel(data_model, batch_shape, active_dims, features_to_idx_mapper): + return MyGpytorchKernel(...) + + # Direct call form + register(MyKernelDataModel, map_my_kernel) + + Args: + data_model_cls: The Pydantic data model class. + map_fn: A callable that takes ``(data_model, batch_shape, active_dims, + features_to_idx_mapper)`` and returns a gpytorch kernel. If not + provided, returns a decorator. 
+ + Returns: + The mapping function (unchanged) when used as a decorator, None otherwise. + """ + + def _register(fn: Callable) -> Callable: + KERNEL_MAP[data_model_cls] = fn + + # Also register with the data model union so Pydantic accepts the type + data_models.register_kernel(data_model_cls) + + return fn + + if map_fn is not None: + _register(map_fn) + return None + + return _register + + def _compute_active_dims( data_model: data_models.FeatureSpecificKernel, active_dims: List[int], @@ -409,22 +450,22 @@ def map_SphericalLinearKernel( KERNEL_MAP = { - data_models.WassersteinKernel: map_WassersteinKernel, data_models.RBFKernel: map_RBFKernel, data_models.MaternKernel: map_MaternKernel, + data_models.InfiniteWidthBNNKernel: map_InfiniteWidthBNNKernel, data_models.LinearKernel: map_LinearKernel, data_models.PolynomialKernel: map_PolynomialKernel, data_models.AdditiveKernel: map_AdditiveKernel, data_models.MultiplicativeKernel: map_MultiplicativeKernel, data_models.ScaleKernel: map_ScaleKernel, - data_models.SphericalLinearKernel: map_SphericalLinearKernel, data_models.TanimotoKernel: map_TanimotoKernel, data_models.HammingDistanceKernel: map_HammingDistanceKernel, data_models.IndexKernel: map_IndexKernel, data_models.PositiveIndexKernel: map_PositiveIndexKernel, - data_models.InfiniteWidthBNNKernel: map_InfiniteWidthBNNKernel, + data_models.WassersteinKernel: map_WassersteinKernel, data_models.PolynomialFeatureInteractionKernel: map_PolynomialFeatureInteractionKernel, data_models.WedgeKernel: map_WedgeKernel, + data_models.SphericalLinearKernel: map_SphericalLinearKernel, } diff --git a/bofire/priors/api.py b/bofire/priors/api.py index 6311f6b89..575ed9bab 100644 --- a/bofire/priors/api.py +++ b/bofire/priors/api.py @@ -1 +1 @@ -from bofire.priors.mapper import map +from bofire.priors.mapper import map, register diff --git a/bofire/priors/mapper.py b/bofire/priors/mapper.py index e19f22c3a..733a9c82e 100644 --- a/bofire/priors/mapper.py +++ 
b/bofire/priors/mapper.py @@ -1,5 +1,5 @@ import math -from typing import Union +from typing import Callable, Optional, Type, Union import gpytorch from botorch.utils.constraints import LogTransformedInterval, NonTransformedInterval @@ -13,6 +13,51 @@ ] +def register( + data_model_cls: Type, + map_fn: Optional[Callable] = None, +): + """Register a custom prior/constraint mapping from data model to factory function. + + Can be used as a decorator or as a direct function call:: + + # Decorator form + @register(MyPriorDataModel) + def map_my_prior(data_model, **kwargs): + return MyGpytorchPrior(...) + + # Direct call form + register(MyPriorDataModel, map_my_prior) + + Args: + data_model_cls: The Pydantic data model class. + map_fn: A callable that takes ``(data_model, **kwargs)`` and returns a + gpytorch prior or constraint. If not provided, returns a decorator. + + Returns: + The mapping function (unchanged) when used as a decorator, None otherwise. + """ + + def _register(fn: Callable) -> Callable: + PRIOR_MAP[data_model_cls] = fn + + # Also register with the data model unions so Pydantic accepts the type + if issubclass(data_model_cls, data_models.Prior): + data_models.register_prior(data_model_cls) + elif issubclass( + data_model_cls, (data_models.PriorConstraint, data_models.Interval) + ): + data_models.register_prior_constraint(data_model_cls) + + return fn + + if map_fn is not None: + _register(map_fn) + return None + + return _register + + def map_NormalPrior( data_model: data_models.NormalPrior, **kwargs, diff --git a/bofire/strategies/api.py b/bofire/strategies/api.py index 7cad2a275..a5785b15a 100644 --- a/bofire/strategies/api.py +++ b/bofire/strategies/api.py @@ -1,6 +1,6 @@ from bofire.strategies.doe_strategy import DoEStrategy from bofire.strategies.fractional_factorial import FractionalFactorialStrategy -from bofire.strategies.mapper import map +from bofire.strategies.mapper import map, register from bofire.strategies.predictives.acqf_optimization 
import ( AcquisitionOptimizer, get_optimizer, diff --git a/bofire/strategies/mapper.py b/bofire/strategies/mapper.py index da2d85ff1..97890ab88 100644 --- a/bofire/strategies/mapper.py +++ b/bofire/strategies/mapper.py @@ -1,9 +1,51 @@ +from typing import Optional, Type + import bofire.data_models.strategies.api as data_models from bofire.strategies.mapper_actual import STRATEGY_MAP as ACTUAL_MAP from bofire.strategies.mapper_meta import STRATEGY_MAP as META_MAP from bofire.strategies.strategy import Strategy +def register( + data_model_cls: Type[data_models.Strategy], + strategy_cls: Optional[Type[Strategy]] = None, +): + """Register a custom strategy mapping from data model to functional class. + + Can be used as a decorator or as a direct function call:: + + # Decorator form + @register(MyDataModel) + class MyStrategy(Strategy): + ... + + # Direct call form + register(MyDataModel, MyStrategy) + + Args: + data_model_cls: The Pydantic data model class. + strategy_cls: The functional strategy class. If not provided, + returns a decorator. + + Returns: + The strategy class (unchanged) when used as a decorator, None otherwise. 
+ """ + + def _register(cls: Type[Strategy]) -> Type[Strategy]: + ACTUAL_MAP[data_model_cls] = cls + + # Also register with the data model union so Pydantic accepts the type + data_models.register_strategy(data_model_cls) + + return cls + + if strategy_cls is not None: + _register(strategy_cls) + return None + + return _register + + def map(data_model: data_models.Strategy) -> Strategy: data_cls = data_model.__class__ if data_cls in META_MAP: diff --git a/bofire/surrogates/api.py b/bofire/surrogates/api.py index a9e36d2a0..dd3b149df 100644 --- a/bofire/surrogates/api.py +++ b/bofire/surrogates/api.py @@ -5,7 +5,7 @@ AdditiveMapSaasSingleTaskGPSurrogate, EnsembleMapSaasSingleTaskGPSurrogate, ) -from bofire.surrogates.mapper import map +from bofire.surrogates.mapper import map, register from bofire.surrogates.mlp import ( ClassificationMLPEnsemble, MLPEnsemble, diff --git a/bofire/surrogates/engineered_features.py b/bofire/surrogates/engineered_features.py index 8ec598d11..cfcea1122 100644 --- a/bofire/surrogates/engineered_features.py +++ b/bofire/surrogates/engineered_features.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Callable +from typing import Callable, Optional, Type import pandas as pd import torch @@ -22,6 +22,49 @@ from bofire.utils.torch_tools import interp1d +def register( + data_model_cls: Type[EngineeredFeature], + map_fn: Optional[Callable] = None, +): + """Register a custom engineered feature mapping from data model to factory function. + + Can be used as a decorator or as a direct function call:: + + # Decorator form + @register(MyEngineeredFeatureDataModel) + def map_my_feature(inputs, transform_specs, feature): + return AppendFeatures(...) + + # Direct call form + register(MyEngineeredFeatureDataModel, map_my_feature) + + Args: + data_model_cls: The Pydantic data model class. + map_fn: A callable that takes ``(inputs, transform_specs, feature)`` + and returns a ``botorch.models.transforms.input.AppendFeatures`` + instance. 
If not provided, returns a decorator. + + Returns: + The mapping function (unchanged) when used as a decorator, None otherwise. + """ + + def _register(fn: Callable) -> Callable: + AGGREGATE_MAP[data_model_cls] = fn + + # Also register with the data model union so Pydantic accepts the type + from bofire.data_models.features.api import register_engineered_feature + + register_engineered_feature(data_model_cls) + + return fn + + if map_fn is not None: + _register(map_fn) + return None + + return _register + + def _weighted_features( X: torch.Tensor, indices: torch.Tensor, @@ -226,7 +269,6 @@ def clone_features(X: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: ) -# Mapper bindings map_sum_feature = partial(_map_reduction_feature, reducer=torch.sum) map_product_feature = partial(_map_reduction_feature, reducer=torch.prod) map_mean_feature = partial(_map_reduction_feature, reducer=torch.mean) @@ -239,15 +281,14 @@ def clone_features(X: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: _map_molecular_weighted_feature, normalize=True ) - AGGREGATE_MAP = { SumFeature: map_sum_feature, ProductFeature: map_product_feature, MeanFeature: map_mean_feature, - WeightedMeanFeature: map_weighted_mean_feature, WeightedSumFeature: map_weighted_sum_feature, - MolecularWeightedMeanFeature: map_molecular_weighted_mean_feature, + WeightedMeanFeature: map_weighted_mean_feature, MolecularWeightedSumFeature: map_molecular_weighted_sum_feature, + MolecularWeightedMeanFeature: map_molecular_weighted_mean_feature, InterpolateFeature: map_interpolate_feature, CloneFeature: map_clone_feature, } diff --git a/bofire/surrogates/mapper.py b/bofire/surrogates/mapper.py index 34bddaba7..a20bf29de 100644 --- a/bofire/surrogates/mapper.py +++ b/bofire/surrogates/mapper.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Type +from typing import Callable, Dict, Optional, Type from bofire.data_models.kernels.api import ( AdditiveKernel, @@ -118,6 +118,59 @@ def map_TanimotoGPSurrogate( } 
+def register( + data_model_cls: Type[data_models.Surrogate], + surrogate_cls: Optional[Type[Surrogate]] = None, + data_model_transform: Optional[Callable] = None, +): + """Register a custom surrogate mapping from data model to functional class. + + Can be used as a decorator or as a direct function call:: + + # Decorator form + @register(MyDataModel) + class MySurrogate(Surrogate): + ... + + # Direct call form + register(MyDataModel, MySurrogate) + + If ``data_model_cls`` is a subclass of + :class:`~bofire.data_models.surrogates.botorch.BotorchSurrogate`, it is + also registered with :class:`BotorchSurrogates` so that it can be used + in botorch-based strategies out of the box. + + Args: + data_model_cls: The Pydantic data model class. + surrogate_cls: The functional surrogate class. If not provided, + returns a decorator. + data_model_transform: Optional function that transforms the data model + before instantiation (e.g. to convert to a simpler representation). + + Returns: + The surrogate class (unchanged) when used as a decorator, None otherwise. 
+ """ + + def _register(cls: Type[Surrogate]) -> Type[Surrogate]: + SURROGATE_MAP[data_model_cls] = cls + if data_model_transform is not None: + DATA_MODEL_MAP[data_model_cls] = data_model_transform + + if issubclass(data_model_cls, data_models.BotorchSurrogate): + from bofire.data_models.surrogates.botorch_surrogates import ( + register_botorch_surrogate, + ) + + register_botorch_surrogate(data_model_cls) + return cls + + if surrogate_cls is not None: + _register(surrogate_cls) + return None + + return _register + + def map(data_model: data_models.Surrogate, **kwargs) -> Surrogate: new_data_model = data_model if data_model.__class__ in DATA_MODEL_MAP: diff --git a/docs/userguides/custom_types.qmd b/docs/userguides/custom_types.qmd new file mode 100644 index 000000000..9294ca186 --- /dev/null +++ b/docs/userguides/custom_types.qmd @@ -0,0 +1,219 @@ +# Registering Custom Types + +BoFire ships with a collection of built-in strategies, surrogates, kernels, and priors. If you need a component that is not provided out of the box, you can implement your own and **register** it so that BoFire's mapping and serialization infrastructure works with it seamlessly. + +Registration does two things: + +1. It adds your functional class to the mapper so that `strategies.map()` / `surrogates.map()` can instantiate it. +2. It updates the Pydantic unions used for validation so that your custom data model is accepted wherever the corresponding built-in types are accepted (e.g. in `BotorchSurrogates`, `StepwiseStrategy`, or surrogate kernel fields). 
+ +## Registering a custom strategy + +Define a data model (Pydantic) and a functional class, then call `register`: + +```python +from typing import Literal, Type + +import pandas as pd + +import bofire.strategies.api as strategies +from bofire.data_models.constraints.api import Constraint +from bofire.data_models.features.api import Feature +from bofire.data_models.strategies.strategy import Strategy as StrategyDataModel +from bofire.strategies.strategy import Strategy + + +# 1. Data model — holds configuration, is serializable +class MyStrategyDataModel(StrategyDataModel): + type: Literal["MyStrategy"] = "MyStrategy" + my_param: float = 1.0 + + def is_constraint_implemented(self, my_type: Type[Constraint]) -> bool: + return True + + @classmethod + def is_feature_implemented(cls, my_type: Type[Feature]) -> bool: + return True + + +# 2. Functional class — implements ask/tell +class MyStrategy(Strategy): + def _ask(self, candidate_count): + return pd.DataFrame() + + def has_sufficient_experiments(self) -> bool: + return True + + +# 3. Register (direct call) +strategies.register(MyStrategyDataModel, MyStrategy) +``` + +You can also use the decorator form: + +```python +@strategies.register(MyStrategyDataModel) +class MyStrategy(Strategy): + ... +``` + +After registration, `strategies.map(MyStrategyDataModel(domain=domain))` returns an instance of `MyStrategy`. The custom type is also accepted inside `StepwiseStrategy` steps. + +## Registering a custom surrogate + +Surrogate registration follows the same pattern via `surrogates.register()`: + +```python +from typing import Literal, Type + +import bofire.surrogates.api as surrogates +from bofire.data_models.features.api import AnyOutput, ContinuousOutput +from bofire.data_models.surrogates.trainable_botorch import TrainableBotorchSurrogate +from bofire.surrogates.botorch import TrainableBotorchSurrogate as BotorchImpl + + +# 1. 
Data model +class MySurrogateDataModel(TrainableBotorchSurrogate): + type: Literal["MySurrogate"] = "MySurrogate" + + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + return isinstance(my_type, type(ContinuousOutput)) + + +# 2. Functional class +class MySurrogate(BotorchImpl): + def _fit(self, train_X, train_Y, **kwargs): + ... + + def _predict(self, transformed_X): + ... + + +# 3. Register +surrogates.register(MySurrogateDataModel, MySurrogate) +``` + +Because `MySurrogateDataModel` inherits from `BotorchSurrogate`, the registration automatically adds it to the `BotorchSurrogates` collection. This means it can be used as a surrogate specification in any BoTorch-based strategy without extra work: + +```python +from bofire.data_models.surrogates.botorch_surrogates import BotorchSurrogates + +specs = BotorchSurrogates(surrogates=[ + MySurrogateDataModel(inputs=domain.inputs, outputs=domain.outputs), +]) +``` + +### Data model transforms + +Some surrogates are conceptually a special case of another. For example, `TanimotoGPSurrogate` is converted to a `SingleTaskGPSurrogate` before instantiation. 
You can do the same with the `data_model_transform` argument: + +```python +def my_transform(data_model): + """Convert MySurrogateDataModel to SingleTaskGPSurrogate.""" + from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate + return SingleTaskGPSurrogate( + inputs=data_model.inputs, + outputs=data_model.outputs, + kernel=data_model.kernel, + noise_prior=data_model.noise_prior, + ) + +surrogates.register( + MySurrogateDataModel, + SingleTaskGPSurrogate, # functional class after transform + data_model_transform=my_transform, +) +``` + +## Registering a custom kernel + +Custom GP kernels can be registered so they are accepted in all kernel fields (surrogate models, aggregation kernels, etc.): + +```python +from typing import Literal + +import gpytorch +import torch + +import bofire.kernels.api as kernels +from bofire.data_models.kernels.kernel import Kernel + + +# 1. Data model +class MyKernelDataModel(Kernel): + type: Literal["MyKernel"] = "MyKernel" + lengthscale: float = 1.0 + + +# 2. 
Mapping function — returns a gpytorch kernel +@kernels.register(MyKernelDataModel) +def map_my_kernel(data_model, batch_shape, active_dims, features_to_idx_mapper): + k = gpytorch.kernels.RBFKernel( + batch_shape=batch_shape, + active_dims=active_dims, + ) + k.lengthscale = torch.tensor(data_model.lengthscale) + return k +``` + +After registration the custom kernel is accepted anywhere a built-in kernel is accepted: + +```python +from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate +from bofire.data_models.kernels.aggregation import ScaleKernel, AdditiveKernel + +# As the main kernel of a GP +surrogate = SingleTaskGPSurrogate( + inputs=..., outputs=..., + kernel=MyKernelDataModel(lengthscale=2.0), +) + +# Inside aggregation kernels +additive = AdditiveKernel(kernels=[MyKernelDataModel(), ...]) +scaled = ScaleKernel(base_kernel=MyKernelDataModel()) +``` + +If the custom kernel subclasses `ContinuousKernel` or `CategoricalKernel`, it is also added to the corresponding sub-union (`AnyContinuousKernel` / `AnyCategoricalKernel`) and accepted in `MixedSingleTaskGPSurrogate`. + +## Registering a custom prior + +Custom GP priors work the same way: + +```python +from typing import Literal + +import gpytorch + +import bofire.priors.api as priors +from bofire.data_models.priors.prior import Prior + + +class MyPriorDataModel(Prior): + type: Literal["MyPrior"] = "MyPrior" + loc: float = 0.0 + scale: float = 1.0 + + +@priors.register(MyPriorDataModel) +def map_my_prior(data_model, **kwargs): + return gpytorch.priors.NormalPrior( + loc=data_model.loc, + scale=data_model.scale, + ) +``` + +After registration the custom prior is accepted in all prior fields (`noise_prior`, `lengthscale_prior`, etc.) across all surrogate and kernel data models. + +## How it works + +BoFire uses [Pydantic discriminated unions](https://docs.pydantic.dev/latest/concepts/unions/) to validate fields such as `SingleTaskGPSurrogate.kernel` or `BotorchSurrogates.surrogates`. 
These unions are defined as module-level type aliases (e.g. `AnyKernel`, `AnyPrior`, `AnyBotorchSurrogate`). + +When you call a `register` function, BoFire: + +1. Appends your type to the internal type list backing the union. +2. Rebuilds the `Union` type alias. +3. Patches the annotation on every Pydantic model field that references that union. +4. Calls `model_rebuild(force=True)` on affected models so Pydantic picks up the new validator. + +This means registration has a small one-time cost at import time and should be done early — ideally at module level, before creating any data model instances. diff --git a/pyproject.toml b/pyproject.toml index fb49eff91..6383f4fb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,6 +158,9 @@ invalid-assignment = "warn" invalid-return-type = "warn" # BoFire intentionally narrows parameter types in subclass overrides. invalid-method-override = "warn" +# Dynamic union rebuilding uses Union[tuple(list)] which is valid at +# runtime but not statically analysable. 
+invalid-type-form = "ignore" [tool.ty.analysis] # Optional dependencies that may not be installed or lack py.typed markers diff --git a/tests/bofire/test_register.py b/tests/bofire/test_register.py new file mode 100644 index 000000000..0b5c4f586 --- /dev/null +++ b/tests/bofire/test_register.py @@ -0,0 +1,1033 @@ +import inspect +from typing import Literal, Type +from unittest.mock import MagicMock + +import pandas as pd + +from bofire.data_models.constraints.api import Constraint +from bofire.data_models.domain.api import Domain, Inputs, Outputs +from bofire.data_models.features.api import ( + AnyOutput, + CategoricalInput, + ContinuousInput, + ContinuousOutput, + EngineeredFeature, + Feature, +) +from bofire.data_models.kernels.continuous import ContinuousKernel as _ContinuousBase +from bofire.data_models.kernels.kernel import Kernel as KernelDataModel +from bofire.data_models.priors.prior import Prior as PriorDataModel +from bofire.data_models.strategies.strategy import Strategy as StrategyDataModel +from bofire.data_models.surrogates.botorch import BotorchSurrogate +from bofire.data_models.surrogates.surrogate import Surrogate as SurrogateDataModel +from bofire.strategies.strategy import Strategy +from bofire.surrogates.surrogate import Surrogate + + +# --------------------------------------------------------------------------- +# Stub data model and functional classes for strategies +# --------------------------------------------------------------------------- + + +class _CustomStrategyDataModel(StrategyDataModel): + type: str = "CustomStrategy" + + def is_constraint_implemented(self, my_type: Type[Constraint]) -> bool: + return True + + @classmethod + def is_feature_implemented(cls, my_type: Type[Feature]) -> bool: + return True + + +class _CustomStrategy(Strategy): + def _ask(self, candidate_count): + return pd.DataFrame() + + def has_sufficient_experiments(self) -> bool: + return True + + +# 
--------------------------------------------------------------------------- +# Stub data model and functional classes for surrogates +# --------------------------------------------------------------------------- + +_INPUTS = Inputs(features=[ContinuousInput(key="x", bounds=(0, 1))]) +_OUTPUTS = Outputs(features=[ContinuousOutput(key="y")]) + + +class _CustomSurrogateDataModel(SurrogateDataModel): + type: str = "CustomSurrogate" + inputs: Inputs = _INPUTS + outputs: Outputs = _OUTPUTS + + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + return True + + +class _CustomSurrogate(Surrogate): + def __init__(self, data_model, **kwargs): + # simplified init that skips heavy base-class logic + self.data_model = data_model + + def predict(self, X): + pass + + def _predict(self, transformed_X): + pass + + def loads(self, data): + pass + + def dumps(self): + return "" + + def _dumps(self): + return "" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_domain(): + return Domain( + inputs=Inputs(features=[ContinuousInput(key="x", bounds=(0, 1))]), + outputs=Outputs(features=[ContinuousOutput(key="y")]), + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestRegisterStrategy: + def test_register_and_map(self): + import bofire.strategies.api as strategies_api + from bofire.strategies.mapper_actual import STRATEGY_MAP as ACTUAL_MAP + + # ensure not registered yet + ACTUAL_MAP.pop(_CustomStrategyDataModel, None) + + strategies_api.register(_CustomStrategyDataModel, _CustomStrategy) + assert ACTUAL_MAP[_CustomStrategyDataModel] is _CustomStrategy + + # round-trip through strategies.map + dm = _CustomStrategyDataModel(domain=_make_domain()) + result = strategies_api.map(dm) + assert 
isinstance(result, _CustomStrategy) + + # cleanup + ACTUAL_MAP.pop(_CustomStrategyDataModel, None) + + def test_register_decorator_syntax(self): + import bofire.strategies.api as strategies_api + from bofire.strategies.mapper_actual import STRATEGY_MAP as ACTUAL_MAP + + ACTUAL_MAP.pop(_CustomStrategyDataModel, None) + + @strategies_api.register(_CustomStrategyDataModel) + class _DecoratedStrategy(_CustomStrategy): + pass + + assert ACTUAL_MAP[_CustomStrategyDataModel] is _DecoratedStrategy + + dm = _CustomStrategyDataModel(domain=_make_domain()) + result = strategies_api.map(dm) + assert isinstance(result, _DecoratedStrategy) + + # cleanup + ACTUAL_MAP.pop(_CustomStrategyDataModel, None) + + def test_register_updates_stepwise_strategy(self): + """Registering a strategy should update the StepwiseStrategy data model.""" + import bofire.strategies.api as strategies_api + from bofire.data_models.strategies.actual_strategy_type import ( + _ACTUAL_STRATEGY_TYPES, + ) + from bofire.data_models.strategies.stepwise.stepwise import Step + from bofire.strategies.mapper_actual import STRATEGY_MAP as ACTUAL_MAP + + ACTUAL_MAP.pop(_CustomStrategyDataModel, None) + if _CustomStrategyDataModel in _ACTUAL_STRATEGY_TYPES: + _ACTUAL_STRATEGY_TYPES.remove(_CustomStrategyDataModel) + + strategies_api.register(_CustomStrategyDataModel, _CustomStrategy) + + # Step.strategy_data should now accept the custom type + step = Step( + strategy_data=_CustomStrategyDataModel(domain=_make_domain()), + condition={"type": "AlwaysTrueCondition"}, + ) + assert type(step.strategy_data) is _CustomStrategyDataModel + + # cleanup + ACTUAL_MAP.pop(_CustomStrategyDataModel, None) + if _CustomStrategyDataModel in _ACTUAL_STRATEGY_TYPES: + _ACTUAL_STRATEGY_TYPES.remove(_CustomStrategyDataModel) + + def test_register_exported_from_api(self): + from bofire.strategies.api import register + + assert callable(register) + + +class TestRegisterSurrogate: + def test_register_and_map(self): + import 
bofire.surrogates.api as surrogates_api + from bofire.surrogates.mapper import SURROGATE_MAP + + SURROGATE_MAP.pop(_CustomSurrogateDataModel, None) + + surrogates_api.register(_CustomSurrogateDataModel, _CustomSurrogate) + assert SURROGATE_MAP[_CustomSurrogateDataModel] is _CustomSurrogate + + dm = _CustomSurrogateDataModel() + result = surrogates_api.map(dm) + assert isinstance(result, _CustomSurrogate) + + # cleanup + SURROGATE_MAP.pop(_CustomSurrogateDataModel, None) + + def test_register_with_data_model_transform(self): + import bofire.surrogates.api as surrogates_api + from bofire.surrogates.mapper import DATA_MODEL_MAP, SURROGATE_MAP + + SURROGATE_MAP.pop(_CustomSurrogateDataModel, None) + DATA_MODEL_MAP.pop(_CustomSurrogateDataModel, None) + + transform_called_with = [] + + def my_transform(dm): + transform_called_with.append(dm) + # Return the same data model (identity transform) + return dm + + surrogates_api.register( + _CustomSurrogateDataModel, + _CustomSurrogate, + data_model_transform=my_transform, + ) + assert SURROGATE_MAP[_CustomSurrogateDataModel] is _CustomSurrogate + assert DATA_MODEL_MAP[_CustomSurrogateDataModel] is my_transform + + dm = _CustomSurrogateDataModel() + result = surrogates_api.map(dm) + assert isinstance(result, _CustomSurrogate) + assert len(transform_called_with) == 1 + assert transform_called_with[0] is dm + + # cleanup + SURROGATE_MAP.pop(_CustomSurrogateDataModel, None) + DATA_MODEL_MAP.pop(_CustomSurrogateDataModel, None) + + def test_register_without_data_model_transform_does_not_add_to_data_model_map( + self, + ): + import bofire.surrogates.api as surrogates_api + from bofire.surrogates.mapper import DATA_MODEL_MAP, SURROGATE_MAP + + SURROGATE_MAP.pop(_CustomSurrogateDataModel, None) + DATA_MODEL_MAP.pop(_CustomSurrogateDataModel, None) + + surrogates_api.register(_CustomSurrogateDataModel, _CustomSurrogate) + assert _CustomSurrogateDataModel not in DATA_MODEL_MAP + + # cleanup + 
        # Tail of the preceding test: restore the registry so state does not
        # leak into other tests (the enclosing test class starts above this chunk).
        SURROGATE_MAP.pop(_CustomSurrogateDataModel, None)

    def test_register_decorator_syntax(self):
        """register() used as a decorator should add the class to SURROGATE_MAP."""
        import bofire.surrogates.api as surrogates_api
        from bofire.surrogates.mapper import SURROGATE_MAP

        # Start from a clean registry in case a previous test leaked state.
        SURROGATE_MAP.pop(_CustomSurrogateDataModel, None)

        @surrogates_api.register(_CustomSurrogateDataModel)
        class _DecoratedSurrogate(_CustomSurrogate):
            pass

        assert SURROGATE_MAP[_CustomSurrogateDataModel] is _DecoratedSurrogate

        dm = _CustomSurrogateDataModel()
        result = surrogates_api.map(dm)
        assert isinstance(result, _DecoratedSurrogate)

        # cleanup
        SURROGATE_MAP.pop(_CustomSurrogateDataModel, None)

    def test_register_decorator_with_transform(self):
        """A data_model_transform passed to register() should be stored in
        DATA_MODEL_MAP and invoked exactly once during map()."""
        import bofire.surrogates.api as surrogates_api
        from bofire.surrogates.mapper import DATA_MODEL_MAP, SURROGATE_MAP

        SURROGATE_MAP.pop(_CustomSurrogateDataModel, None)
        DATA_MODEL_MAP.pop(_CustomSurrogateDataModel, None)

        # Record every data model the transform sees so we can assert call count.
        transform_called = []

        def my_transform(dm):
            transform_called.append(dm)
            return dm

        @surrogates_api.register(
            _CustomSurrogateDataModel, data_model_transform=my_transform
        )
        class _DecoratedSurrogate(_CustomSurrogate):
            pass

        assert SURROGATE_MAP[_CustomSurrogateDataModel] is _DecoratedSurrogate
        assert DATA_MODEL_MAP[_CustomSurrogateDataModel] is my_transform

        dm = _CustomSurrogateDataModel()
        result = surrogates_api.map(dm)
        assert isinstance(result, _DecoratedSurrogate)
        assert len(transform_called) == 1

        # cleanup
        SURROGATE_MAP.pop(_CustomSurrogateDataModel, None)
        DATA_MODEL_MAP.pop(_CustomSurrogateDataModel, None)

    def test_register_exported_from_api(self):
        """register should be part of the public surrogates API surface."""
        from bofire.surrogates.api import register

        assert callable(register)


# ---------------------------------------------------------------------------
# Stub for botorch surrogate registration
# ---------------------------------------------------------------------------


class _CustomBotorchSurrogateDataModel(BotorchSurrogate):
    """Minimal BotorchSurrogate data model stub used to exercise registration."""

    type: Literal["_CustomBotorchSurrogate"] = "_CustomBotorchSurrogate"

    @classmethod
    def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool:
        # Accept every output type; output filtering is not under test here.
        return True


class _CustomBotorchSurrogate(Surrogate):
    """Minimal functional surrogate stub; all abstract hooks are no-ops."""

    def __init__(self, data_model, **kwargs):
        # NOTE(review): deliberately does not call super().__init__ — presumably
        # to avoid the base class's validation/IO; confirm against Surrogate.
        self.data_model = data_model

    def predict(self, X):
        pass

    def _predict(self, transformed_X):
        pass

    def loads(self, data):
        pass

    def dumps(self):
        return ""

    def _dumps(self):
        return ""


class TestRegisterBotorchSurrogate:
    """Registration tests specific to BotorchSurrogate data models, which must
    additionally be appended to the _BOTORCH_SURROGATE_TYPES registry."""

    def _cleanup(self):
        # Remove every trace of the stub from all three registries so each
        # test starts and ends with pristine global state.
        from bofire.data_models.surrogates.botorch_surrogates import (
            _BOTORCH_SURROGATE_TYPES,
        )
        from bofire.surrogates.mapper import DATA_MODEL_MAP, SURROGATE_MAP

        SURROGATE_MAP.pop(_CustomBotorchSurrogateDataModel, None)
        DATA_MODEL_MAP.pop(_CustomBotorchSurrogateDataModel, None)
        if _CustomBotorchSurrogateDataModel in _BOTORCH_SURROGATE_TYPES:
            _BOTORCH_SURROGATE_TYPES.remove(_CustomBotorchSurrogateDataModel)

    def test_register_adds_to_botorch_surrogates(self):
        """Registering a BotorchSurrogate subclass should also update
        AnyBotorchSurrogate so that BotorchSurrogates accepts it."""
        import typing

        import bofire.surrogates.api as surrogates_api
        from bofire.data_models.surrogates.botorch_surrogates import (
            _BOTORCH_SURROGATE_TYPES,
            BotorchSurrogates,
        )

        self._cleanup()
        n_before = len(_BOTORCH_SURROGATE_TYPES)

        surrogates_api.register(
            _CustomBotorchSurrogateDataModel, _CustomBotorchSurrogate
        )

        # type was appended to the registry
        assert _CustomBotorchSurrogateDataModel in _BOTORCH_SURROGATE_TYPES
        assert len(_BOTORCH_SURROGATE_TYPES) == n_before + 1

        # BotorchSurrogates now accepts our custom surrogate
        dm = _CustomBotorchSurrogateDataModel(
            inputs=_INPUTS,
            outputs=_OUTPUTS,
        )
        bs = BotorchSurrogates(surrogates=[dm])
        assert len(bs.surrogates) == 1
        assert isinstance(bs.surrogates[0], _CustomBotorchSurrogateDataModel)

        # the module-level AnyBotorchSurrogate union includes our type
        from bofire.data_models.surrogates import botorch_surrogates

        args = typing.get_args(botorch_surrogates.AnyBotorchSurrogate)
        assert _CustomBotorchSurrogateDataModel in args

        self._cleanup()

    def test_register_idempotent(self):
        """Calling register twice with the same type should not duplicate it."""
        import bofire.surrogates.api as surrogates_api
        from bofire.data_models.surrogates.botorch_surrogates import (
            _BOTORCH_SURROGATE_TYPES,
        )

        self._cleanup()

        surrogates_api.register(
            _CustomBotorchSurrogateDataModel, _CustomBotorchSurrogate
        )
        surrogates_api.register(
            _CustomBotorchSurrogateDataModel, _CustomBotorchSurrogate
        )

        count = _BOTORCH_SURROGATE_TYPES.count(_CustomBotorchSurrogateDataModel)
        assert count == 1

        self._cleanup()

    def test_map_round_trip(self):
        """A registered botorch surrogate should be mappable via surrogates.map."""
        import bofire.surrogates.api as surrogates_api

        self._cleanup()

        surrogates_api.register(
            _CustomBotorchSurrogateDataModel, _CustomBotorchSurrogate
        )

        dm = _CustomBotorchSurrogateDataModel(
            inputs=_INPUTS,
            outputs=_OUTPUTS,
        )
        result = surrogates_api.map(dm)
        assert isinstance(result, _CustomBotorchSurrogate)

        self._cleanup()

    def test_register_decorator_syntax(self):
        """Decorator syntax should also trigger botorch registration."""
        import bofire.surrogates.api as surrogates_api
        from bofire.data_models.surrogates.botorch_surrogates import (
            _BOTORCH_SURROGATE_TYPES,
        )

        self._cleanup()

        @surrogates_api.register(_CustomBotorchSurrogateDataModel)
        class _DecoratedBotorchSurrogate(_CustomBotorchSurrogate):
            pass

        assert _CustomBotorchSurrogateDataModel in _BOTORCH_SURROGATE_TYPES

        dm = _CustomBotorchSurrogateDataModel(
            inputs=_INPUTS,
            outputs=_OUTPUTS,
        )
        result = surrogates_api.map(dm)
        assert isinstance(result, _DecoratedBotorchSurrogate)

        self._cleanup()
# ---------------------------------------------------------------------------
# Stub data models for kernels and priors
# ---------------------------------------------------------------------------


class _CustomKernelDataModel(KernelDataModel):
    """Minimal kernel data model stub used to exercise kernel registration."""

    type: str = "CustomKernel"


class _CustomPriorDataModel(PriorDataModel):
    """Minimal prior data model stub used to exercise prior registration."""

    type: str = "CustomPrior"


# ---------------------------------------------------------------------------
# Kernel registration tests
# ---------------------------------------------------------------------------


class TestRegisterKernel:
    """Tests for registering custom kernel mapping functions via the public
    kernels API, both by direct call and by decorator syntax."""

    def test_register_and_map(self):
        """register(type, fn) stores fn in KERNEL_MAP and map() dispatches to it."""
        import torch

        import bofire.kernels.api as kernels_api
        from bofire.kernels.mapper import KERNEL_MAP

        # Start from a clean registry in case a previous test leaked state.
        KERNEL_MAP.pop(_CustomKernelDataModel, None)

        sentinel = MagicMock(name="gpytorch_kernel")

        def my_map_fn(data_model, batch_shape, active_dims, features_to_idx_mapper):
            return sentinel

        # try/finally guarantees the registry is restored even when an
        # assertion fails, so state cannot leak into other tests.
        try:
            kernels_api.register(_CustomKernelDataModel, my_map_fn)
            assert KERNEL_MAP[_CustomKernelDataModel] is my_map_fn

            dm = _CustomKernelDataModel()
            result = kernels_api.map(dm, torch.Size(), [0], None)
            assert result is sentinel
        finally:
            KERNEL_MAP.pop(_CustomKernelDataModel, None)

    def test_register_decorator_syntax(self):
        """register(type) used as a decorator registers the decorated function."""
        import torch

        import bofire.kernels.api as kernels_api
        from bofire.kernels.mapper import KERNEL_MAP

        KERNEL_MAP.pop(_CustomKernelDataModel, None)

        sentinel = MagicMock(name="gpytorch_kernel")

        try:

            @kernels_api.register(_CustomKernelDataModel)
            def my_map_fn(
                data_model, batch_shape, active_dims, features_to_idx_mapper
            ):
                return sentinel

            assert KERNEL_MAP[_CustomKernelDataModel] is my_map_fn

            dm = _CustomKernelDataModel()
            result = kernels_api.map(dm, torch.Size(), [0], None)
            assert result is sentinel
        finally:
            # Always restore the registry (see test_register_and_map).
            KERNEL_MAP.pop(_CustomKernelDataModel, None)

    def test_register_exported_from_api(self):
        """register should be part of the public kernels API surface."""
        from bofire.kernels.api import register

        assert callable(register)
# ---------------------------------------------------------------------------
# Prior registration tests
# ---------------------------------------------------------------------------


class TestRegisterPrior:
    """Tests for registering custom prior mapping functions via the public
    priors API, both by direct call and by decorator syntax."""

    def test_register_and_map(self):
        """register(type, fn) stores fn in PRIOR_MAP and map() dispatches to it."""
        import bofire.priors.api as priors_api
        from bofire.priors.mapper import PRIOR_MAP

        # Start from a clean registry in case a previous test leaked state.
        PRIOR_MAP.pop(_CustomPriorDataModel, None)

        sentinel = MagicMock(name="gpytorch_prior")

        def my_map_fn(data_model, **kwargs):
            return sentinel

        priors_api.register(_CustomPriorDataModel, my_map_fn)
        assert PRIOR_MAP[_CustomPriorDataModel] is my_map_fn

        dm = _CustomPriorDataModel()
        result = priors_api.map(dm)
        assert result is sentinel

        # cleanup
        PRIOR_MAP.pop(_CustomPriorDataModel, None)

    def test_register_decorator_syntax(self):
        """register(type) used as a decorator registers the decorated function."""
        import bofire.priors.api as priors_api
        from bofire.priors.mapper import PRIOR_MAP

        PRIOR_MAP.pop(_CustomPriorDataModel, None)

        sentinel = MagicMock(name="gpytorch_prior")

        @priors_api.register(_CustomPriorDataModel)
        def my_map_fn(data_model, **kwargs):
            return sentinel

        assert PRIOR_MAP[_CustomPriorDataModel] is my_map_fn

        dm = _CustomPriorDataModel()
        result = priors_api.map(dm)
        assert result is sentinel

        # cleanup
        PRIOR_MAP.pop(_CustomPriorDataModel, None)

    def test_register_exported_from_api(self):
        """register should be part of the public priors API surface."""
        from bofire.priors.api import register

        assert callable(register)


# ---------------------------------------------------------------------------
# Integration tests: custom types accepted by Pydantic validation
# ---------------------------------------------------------------------------


class _IntegrationKernelDataModel(KernelDataModel):
    """Kernel stub with an extra field, for end-to-end Pydantic validation tests."""

    type: Literal["_IntegrationKernel"] = "_IntegrationKernel"
    # Extra payload field to verify custom attributes survive validation.
    my_param: float = 1.0


class _IntegrationContinuousKernel(_ContinuousBase):
    """Continuous-kernel stub; should be auto-added to AnyContinuousKernel."""

    type: Literal["_IntegrationContinuousKernel"] = "_IntegrationContinuousKernel"


class _IntegrationPriorDataModel(PriorDataModel):
    """Prior stub with an extra field, for end-to-end Pydantic validation tests."""

    type: Literal["_IntegrationPrior"] = "_IntegrationPrior"
    value: float = 1.0


class TestKernelPydanticIntegration:
    """After register_kernel, custom kernel types should pass Pydantic validation
    in surrogate models and aggregation kernels."""

    def test_custom_kernel_in_surrogate(self):
        from bofire.data_models.kernels.api import register_kernel

        register_kernel(_IntegrationKernelDataModel)

        # Imported after registration so the rebuilt union is in effect.
        from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate

        s = SingleTaskGPSurrogate(
            inputs=_INPUTS,
            outputs=_OUTPUTS,
            kernel=_IntegrationKernelDataModel(my_param=42.0),
        )
        assert isinstance(s.kernel, _IntegrationKernelDataModel)
        assert s.kernel.my_param == 42.0

    def test_custom_kernel_in_additive_kernel(self):
        from bofire.data_models.kernels.aggregation import AdditiveKernel
        from bofire.data_models.kernels.api import register_kernel
        from bofire.data_models.kernels.continuous import RBFKernel

        register_kernel(_IntegrationKernelDataModel)

        # Mix a built-in kernel with the custom one inside the same container.
        ak = AdditiveKernel(
            kernels=[RBFKernel(), _IntegrationKernelDataModel(my_param=7.0)]
        )
        assert len(ak.kernels) == 2
        assert isinstance(ak.kernels[1], _IntegrationKernelDataModel)

    def test_custom_kernel_in_scale_kernel(self):
        from bofire.data_models.kernels.aggregation import ScaleKernel
        from bofire.data_models.kernels.api import register_kernel

        register_kernel(_IntegrationKernelDataModel)

        sk = ScaleKernel(base_kernel=_IntegrationKernelDataModel(my_param=3.0))
        assert isinstance(sk.base_kernel, _IntegrationKernelDataModel)

    def test_custom_continuous_kernel_in_mixed_surrogate(self):
        """A ContinuousKernel subclass should be auto-added to AnyContinuousKernel."""
        from bofire.data_models.kernels.api import (
            _CONTINUOUS_KERNEL_TYPES,
            register_kernel,
        )
        from bofire.data_models.surrogates.mixed_single_task_gp import (
            MixedSingleTaskGPSurrogate,
        )

        register_kernel(_IntegrationContinuousKernel)

        assert _IntegrationContinuousKernel in _CONTINUOUS_KERNEL_TYPES

        # Mixed surrogate requires at least one categorical input alongside
        # the continuous ones.
        s = MixedSingleTaskGPSurrogate(
            inputs=Inputs(
                features=[
                    ContinuousInput(key="a", bounds=(0, 1)),
                    ContinuousInput(key="b", bounds=(0, 1)),
                    CategoricalInput(key="c", categories=["x", "y"]),
                ]
            ),
            outputs=_OUTPUTS,
            continuous_kernel=_IntegrationContinuousKernel(),
        )
        assert isinstance(s.continuous_kernel, _IntegrationContinuousKernel)

    def test_custom_kernel_in_wedge_kernel(self):
        """A registered kernel should be usable as WedgeKernel.base_kernel."""
        from bofire.data_models.kernels.api import register_kernel
        from bofire.data_models.kernels.conditional import WedgeKernel

        class _WedgeTestKernel(KernelDataModel):
            type: Literal["_WedgeTestKernel"] = "_WedgeTestKernel"

        register_kernel(_WedgeTestKernel)

        wk = WedgeKernel(
            base_kernel=_WedgeTestKernel(),
            conditions=[],
        )
        assert isinstance(wk.base_kernel, _WedgeTestKernel)

    def test_mapper_register_also_updates_pydantic(self):
        """The mapper-level register() should trigger data model registration."""
        import bofire.kernels.api as kernels_api
        from bofire.data_models.kernels.api import _KERNEL_TYPES

        # Use a fresh class to ensure it's not already registered
        class _MapperKernel(KernelDataModel):
            type: Literal["_MapperKernel"] = "_MapperKernel"

        sentinel = MagicMock(name="gpytorch_kernel")

        kernels_api.register(_MapperKernel, lambda *a: sentinel)

        assert _MapperKernel in _KERNEL_TYPES

        from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate

        s = SingleTaskGPSurrogate(
            inputs=_INPUTS,
            outputs=_OUTPUTS,
            kernel=_MapperKernel(),
        )
        assert isinstance(s.kernel, _MapperKernel)


class TestPriorPydanticIntegration:
    """After register_prior, custom prior types should pass Pydantic validation
    in kernel and surrogate model fields."""

    def test_custom_prior_as_noise_prior(self):
        from bofire.data_models.priors.api import register_prior

        register_prior(_IntegrationPriorDataModel)

        from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate

        s = SingleTaskGPSurrogate(
            inputs=_INPUTS,
            outputs=_OUTPUTS,
            noise_prior=_IntegrationPriorDataModel(value=2.0),
        )
        assert isinstance(s.noise_prior, _IntegrationPriorDataModel)
        assert s.noise_prior.value == 2.0

    def test_custom_prior_as_lengthscale_prior(self):
        from bofire.data_models.kernels.continuous import RBFKernel
        from bofire.data_models.priors.api import register_prior

        register_prior(_IntegrationPriorDataModel)

        k = RBFKernel(lengthscale_prior=_IntegrationPriorDataModel(value=3.0))
        assert isinstance(k.lengthscale_prior, _IntegrationPriorDataModel)

    def test_mapper_register_also_updates_pydantic(self):
        """The mapper-level register() should trigger data model registration."""
        import bofire.priors.api as priors_api
        from bofire.data_models.priors.api import _PRIOR_TYPES

        class _MapperPrior(PriorDataModel):
            type: Literal["_MapperPrior"] = "_MapperPrior"

        priors_api.register(_MapperPrior, lambda dm, **kw: MagicMock())

        assert _MapperPrior in _PRIOR_TYPES

        from bofire.data_models.kernels.continuous import RBFKernel

        k = RBFKernel(lengthscale_prior=_MapperPrior())
        assert isinstance(k.lengthscale_prior, _MapperPrior)


# ---------------------------------------------------------------------------
# Engineered feature registration tests
# ---------------------------------------------------------------------------


class _IntegrationEngineeredFeature(EngineeredFeature):
    """Engineered-feature stub for registration tests."""

    type: Literal["_IntegrationEngineered"] = "_IntegrationEngineered"
    # NOTE(review): unannotated, so Pydantic treats this as a plain class
    # attribute rather than a model field — presumably an ordering hint used
    # by the registry; confirm against EngineeredFeature.
    order_id = 99

    @property
    def n_transformed_inputs(self) -> int:
        return 1


class TestEngineeredFeatureRegistration:
    """Tests for registering custom EngineeredFeature types, at both the data
    model level and the mapper level."""

    def test_register_data_model(self):
        from bofire.data_models.domain.features import EngineeredFeatures

        from bofire.data_models.features.api import register_engineered_feature

        register_engineered_feature(_IntegrationEngineeredFeature)

        ef = EngineeredFeatures(
            features=[_IntegrationEngineeredFeature(key="test", features=["a", "b"])]
        )
        assert isinstance(ef.features[0], _IntegrationEngineeredFeature)

    def test_register_in_surrogate(self):
        from bofire.data_models.domain.features import EngineeredFeatures
        from bofire.data_models.features.api import (
            ContinuousInput,
            register_engineered_feature,
        )
        from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate

        register_engineered_feature(_IntegrationEngineeredFeature)

        ef = EngineeredFeatures(
            features=[_IntegrationEngineeredFeature(key="test", features=["a", "b"])]
        )
        s = SingleTaskGPSurrogate(
            inputs=Inputs(
                features=[
                    ContinuousInput(key="a", bounds=(0, 1)),
                    ContinuousInput(key="b", bounds=(0, 1)),
                ]
            ),
            outputs=_OUTPUTS,
            engineered_features=ef,
        )
        assert len(s.engineered_features.features) == 1

    def test_mapper_register_decorator(self):
        """Decorator-based mapper registration also updates the data model registry."""
        from bofire.data_models.features.api import _ENGINEERED_FEATURE_TYPES
        from bofire.surrogates.engineered_features import AGGREGATE_MAP, register

        class _MapperEngineered(EngineeredFeature):
            type: Literal["_MapperEngineered"] = "_MapperEngineered"
            order_id = 100

            @property
            def n_transformed_inputs(self) -> int:
                return 1

        sentinel = MagicMock(name="append_features")

        @register(_MapperEngineered)
        def my_map_fn(inputs, transform_specs, feature):
            return sentinel

        assert AGGREGATE_MAP[_MapperEngineered] is my_map_fn
        assert _MapperEngineered in _ENGINEERED_FEATURE_TYPES

    def test_mapper_register_direct_call(self):
        """Direct-call mapper registration also updates the data model registry."""
        from bofire.data_models.features.api import _ENGINEERED_FEATURE_TYPES
        from bofire.surrogates.engineered_features import AGGREGATE_MAP, register

        class _DirectEngineered(EngineeredFeature):
            type: Literal["_DirectEngineered"] = "_DirectEngineered"
            order_id = 101

            @property
            def n_transformed_inputs(self) -> int:
                return 1

        sentinel = MagicMock(name="append_features")
        register(_DirectEngineered, lambda i, t, f: sentinel)

        assert _DirectEngineered in AGGREGATE_MAP
        assert _DirectEngineered in _ENGINEERED_FEATURE_TYPES


# ---------------------------------------------------------------------------
# Introspection test: ensure _rebuild_dependent_models covers all fields
# ---------------------------------------------------------------------------


class TestRebuildCoverage:
    """Verify that the explicit field lists in _rebuild_dependent_models
    cover every Pydantic model field typed with AnyPrior, AnyPriorConstraint,
    or AnyKernel. This catches regressions when new surrogates or kernels
    are added without updating the rebuild functions."""

    @staticmethod
    def _collect_fields(base_module: str, target_union_names: set[str]):
        """Walk all Pydantic models under *base_module* and return
        ``{(ModelClass, field_name)}`` for every field whose annotation
        string matches one of the *target_union_names*.
        """
        import importlib
        import inspect
        import pkgutil

        import pydantic

        pkg = importlib.import_module(base_module)
        found: set[tuple[type, str]] = set()

        for _importer, modname, _ispkg in pkgutil.walk_packages(
            pkg.__path__, prefix=pkg.__name__ + "."
        ):
            try:
                mod = importlib.import_module(modname)
            except Exception:
                # Best-effort walk: skip modules with optional dependencies
                # that fail to import.
                continue
            for _name, obj in inspect.getmembers(mod, inspect.isclass):
                # Only inspect Pydantic models declared in this module, so
                # re-exports are not counted twice.
                if (
                    not issubclass(obj, pydantic.BaseModel)
                    or obj is pydantic.BaseModel
                    or obj.__module__ != modname
                ):
                    continue
                for field_name, field_info in obj.model_fields.items():
                    ann = field_info.annotation
                    # String match on the annotation repr — substring-based,
                    # e.g. "AnyPrior" also matches "AnyPriorConstraint".
                    ann_str = str(ann)
                    for target in target_union_names:
                        if target in ann_str:
                            found.add((obj, field_name))
        return found

    def test_prior_rebuild_covers_all_anyprior_fields(self):
        """Every model field typed as AnyPrior should appear in
        priors._rebuild_dependent_models."""
        from bofire.data_models.priors._register import _rebuild_dependent_models

        src = inspect.getsource(_rebuild_dependent_models)

        patched: set[tuple[str, str]] = set()

        # The function calls patch_field(Model, "field", union) — extract those pairs
        import re

        for match in re.finditer(r"\((\w+),\s*\"(\w+)\"\)", src):
            model_name, field_name = match.group(1), match.group(2)
            patched.add((model_name, field_name))

        # Now collect all fields in the codebase typed with AnyPrior or AnyPriorConstraint
        all_fields = self._collect_fields(
            "bofire.data_models", {"AnyPrior", "AnyPriorConstraint"}
        )

        # Convert to (class_name, field_name) for comparison
        all_field_names = {(cls.__name__, fname) for cls, fname in all_fields}

        missing = all_field_names - patched
        assert not missing, (
            f"Fields typed as AnyPrior/AnyPriorConstraint not covered by "
            f"_rebuild_dependent_models: {missing}"
        )

    def test_kernel_rebuild_covers_all_anykernel_fields(self):
        """Every model field typed as AnyKernel or as an inline Union of
        Kernel subclasses should appear in kernels._rebuild_dependent_models
        or be handled via append_to_union_field."""

        from bofire.data_models.kernels._register import _rebuild_dependent_models
        from bofire.data_models.kernels.kernel import Kernel

        src = inspect.getsource(_rebuild_dependent_models)

        import re

        # Same source-regex extraction as the prior test: find every
        # (Model, "field") pair mentioned in the rebuild function.
        patched: set[tuple[str, str]] = set()
        for match in re.finditer(r"\((\w+),\s*\"(\w+)\"\)", src):
            model_name, field_name = match.group(1), match.group(2)
            patched.add((model_name, field_name))

        # Collect AnyKernel-typed fields
        all_fields = self._collect_fields("bofire.data_models", {"AnyKernel"})
        all_field_names = {(cls.__name__, fname) for cls, fname in all_fields}

        # AnyContinuousKernel / AnyCategoricalKernel are handled separately
        # so exclude them from this check
        continuous_cat_fields = self._collect_fields(
            "bofire.data_models", {"AnyContinuousKernel", "AnyCategoricalKernel"}
        )
        continuous_cat_names = {
            (cls.__name__, fname) for cls, fname in continuous_cat_fields
        }

        # Only check pure AnyKernel fields
        anykernel_only = all_field_names - continuous_cat_names

        # Also find inline Union fields whose members are all Kernel subclasses
        # (like ConditionalEmbeddingKernel.base_kernel)
        inline_kernel_fields = self._collect_inline_kernel_union_fields(
            "bofire.data_models", Kernel
        )
        inline_names = {(cls.__name__, fname) for cls, fname in inline_kernel_fields}

        all_kernel_fields = anykernel_only | inline_names
        missing = all_kernel_fields - patched
        assert not missing, (
            f"Fields typed as AnyKernel or inline Union[Kernel, ...] not covered by "
            f"_rebuild_dependent_models: {missing}"
        )

    @staticmethod
    def _collect_inline_kernel_union_fields(
        base_module: str, kernel_base: type
    ) -> set[tuple[type, str]]:
        """Find inline Union-of-Kernel fields on Kernel container classes.

        Only scans classes that are themselves Kernel subclasses (aggregation
        and conditional kernels) to find fields like ``base_kernel`` or
        ``kernels`` that accept kernel instances. Surrogate models with
        intentionally restricted kernel fields are excluded because those
        restrictions are by design.

        Also skips inherited fields (e.g. WedgeKernel.base_kernel inherited
        from ConditionalEmbeddingKernel) to avoid double-counting.
        """
        import importlib
        import pkgutil
        import typing
        from collections.abc import Sequence

        import pydantic

        pkg = importlib.import_module(base_module)
        found: set[tuple[type, str]] = set()

        for _importer, modname, _ispkg in pkgutil.walk_packages(
            pkg.__path__, prefix=pkg.__name__ + "."
        ):
            try:
                mod = importlib.import_module(modname)
            except Exception:
                continue
            for _name, obj in inspect.getmembers(mod, inspect.isclass):
                if (
                    not issubclass(obj, pydantic.BaseModel)
                    or obj is pydantic.BaseModel
                    or obj.__module__ != modname
                ):
                    continue
                # Only check Kernel subclasses (kernel containers)
                if not issubclass(obj, kernel_base):
                    continue
                for field_name, field_info in obj.model_fields.items():
                    # Skip inherited fields — only check fields declared
                    # directly on this class
                    if field_name not in obj.__annotations__:
                        continue

                    ann = field_info.annotation
                    ann_str = str(ann)
                    if "AnyKernel" in ann_str:
                        continue

                    # Unwrap Sequence[Union[...]] if needed
                    inner = ann
                    origin = typing.get_origin(ann)
                    if origin in (list, Sequence):
                        inner_args = typing.get_args(ann)
                        if inner_args:
                            inner = inner_args[0]

                    # Check if this is a Union whose args are all Kernel subclasses
                    if typing.get_origin(inner) is not typing.Union:
                        continue
                    args = typing.get_args(inner)
                    non_none = [a for a in args if a is not type(None)]
                    if non_none and all(
                        isinstance(a, type) and issubclass(a, kernel_base)
                        for a in non_none
                    ):
                        found.add((obj, field_name))
        return found