Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions _quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ website:
href: docs/userguides/surrogates.qmd
- text: "Data Models vs Functionals"
href: docs/userguides/data_models_functionals.qmd
- text: "Registering Custom Types"
href: docs/userguides/custom_types.qmd
- text: "Examples"
menu:
- text: "Basic Examples"
Expand Down Expand Up @@ -192,6 +194,7 @@ website:
- docs/userguides/strategies.qmd
- docs/userguides/surrogates.qmd
- docs/userguides/data_models_functionals.qmd
- docs/userguides/custom_types.qmd

- title: "Examples"
contents:
Expand Down
72 changes: 72 additions & 0 deletions bofire/data_models/_register_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Shared utilities for dynamic Pydantic model registration."""

import typing
from collections.abc import Sequence
from typing import Optional, Union


def patch_field(model_cls: type, field_name: str, new_union: type) -> None:
"""Patch a Pydantic model field annotation with a new union type.

Handles three annotation forms:
- ``Union[A, B, ...]`` — replaced with *new_union*
- ``Optional[Union[A, B, ...]]`` — wrapped as ``Optional[new_union]``
- ``Sequence[Union[A, B, ...]]`` — wrapped as ``Sequence[new_union]``
"""
old = model_cls.model_fields[ # ty: ignore[unresolved-attribute]
field_name
].annotation
args = typing.get_args(old)

if not args:
new = new_union
elif type(None) in args:
# Optional[X] is Union[X, None]
new = Optional[new_union]
elif typing.get_origin(old) in (list, Sequence):
# Sequence[Union[...]] or list[Union[...]]
new = Sequence[new_union]
else:
new = new_union

model_cls.__annotations__[field_name] = new
model_cls.model_fields[ # ty: ignore[unresolved-attribute]
field_name
].annotation = new


def append_to_union_field(model_cls: type, field_name: str, new_type: type) -> None:
"""Append a type to the union inside a model field annotation.

Detects the annotation structure (plain ``Union``, ``Optional[Union]``,
or ``Sequence[Union]``) and appends *new_type* to the inner union.
"""
old = model_cls.model_fields[ # ty: ignore[unresolved-attribute]
field_name
].annotation
origin = typing.get_origin(old)

if origin in (list, Sequence):
# Sequence[Union[A, B, ...]] → Sequence[Union[A, B, ..., new_type]]
inner = typing.get_args(old)[0]
inner_args = typing.get_args(inner)
if new_type not in inner_args:
new_inner = Union[tuple(list(inner_args) + [new_type])]
new_ann = Sequence[new_inner]
model_cls.__annotations__[field_name] = new_ann
model_cls.model_fields[ # ty: ignore[unresolved-attribute]
field_name
].annotation = new_ann
else:
# Union[A, B, ...] or Optional[Union[A, B, ...]]
args = typing.get_args(old)
has_none = type(None) in args
non_none = [a for a in args if a is not type(None)]
if new_type not in non_none:
non_none.append(new_type)
new_union = Union[tuple(non_none)]
new_ann = Optional[new_union] if has_none else new_union
model_cls.__annotations__[field_name] = new_ann
model_cls.model_fields[ # ty: ignore[unresolved-attribute]
field_name
].annotation = new_ann
29 changes: 29 additions & 0 deletions bofire/data_models/features/_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Registration utilities for custom engineered feature types."""

from typing import Union


def register_engineered_feature(data_model_cls: type) -> None:
"""Register a custom engineered feature type so it is accepted in EngineeredFeatures.

This appends the type to the internal registry, rebuilds the
``AnyEngineeredFeature`` union, and calls ``model_rebuild`` on
``EngineeredFeatures`` so that Pydantic accepts the new type.

Args:
data_model_cls: A concrete subclass of ``EngineeredFeature``.
"""
import bofire.data_models.features.api as features_api

if data_model_cls in features_api._ENGINEERED_FEATURE_TYPES:
return
features_api._ENGINEERED_FEATURE_TYPES.append(data_model_cls)
features_api.AnyEngineeredFeature = Union[
tuple(features_api._ENGINEERED_FEATURE_TYPES)
]

from bofire.data_models._register_utils import append_to_union_field
from bofire.data_models.domain.features import EngineeredFeatures

append_to_union_field(EngineeredFeatures, "features", data_model_cls)
EngineeredFeatures.model_rebuild(force=True)
5 changes: 4 additions & 1 deletion bofire/data_models/features/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union

from bofire.data_models.features._register import register_engineered_feature
from bofire.data_models.features.categorical import CategoricalInput, CategoricalOutput
from bofire.data_models.features.continuous import ContinuousInput, ContinuousOutput
from bofire.data_models.features.descriptor import (
Expand Down Expand Up @@ -71,7 +72,7 @@

AnyOutput = Union[ContinuousOutput, CategoricalOutput]

AnyEngineeredFeature = Union[
_ENGINEERED_FEATURE_TYPES: list[type[EngineeredFeature]] = [
SumFeature,
MeanFeature,
WeightedMeanFeature,
Expand All @@ -82,3 +83,5 @@
InterpolateFeature,
CloneFeature,
]

AnyEngineeredFeature = Union[tuple(_ENGINEERED_FEATURE_TYPES)]
124 changes: 124 additions & 0 deletions bofire/data_models/kernels/_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Registration utilities for custom kernel types."""

from typing import Union


def _rebuild_dependent_models(new_kernel_cls: type) -> None:
"""Rebuild all Pydantic models whose fields reference kernel unions."""
import bofire.data_models.kernels.api as kernels_api
from bofire.data_models._register_utils import append_to_union_field, patch_field
from bofire.data_models.kernels.aggregation import (
AdditiveKernel,
MultiplicativeKernel,
PolynomialFeatureInteractionKernel,
ScaleKernel,
)
from bofire.data_models.kernels.categorical import CategoricalKernel
from bofire.data_models.kernels.conditional import (
ConditionalEmbeddingKernel,
WedgeKernel,
)
from bofire.data_models.kernels.continuous import ContinuousKernel
from bofire.data_models.surrogates.botorch_surrogates import BotorchSurrogates
from bofire.data_models.surrogates.mixed_single_task_gp import (
MixedSingleTaskGPSurrogate,
)
from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPSurrogate
from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate
from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate

# Add new kernel type to aggregation kernel inline unions
# (handles both Sequence[Union[...]] and plain Union[...])
for model_cls, field_name in [
(AdditiveKernel, "kernels"),
(MultiplicativeKernel, "kernels"),
(PolynomialFeatureInteractionKernel, "kernels"),
(ScaleKernel, "base_kernel"),
(ConditionalEmbeddingKernel, "base_kernel"),
(WedgeKernel, "base_kernel"),
]:
append_to_union_field(model_cls, field_name, new_kernel_cls)

# Rebuild aggregation and conditional kernels
for cls in [
AdditiveKernel,
MultiplicativeKernel,
ScaleKernel,
PolynomialFeatureInteractionKernel,
ConditionalEmbeddingKernel,
WedgeKernel,
]:
cls.model_rebuild(force=True)

# Patch AnyKernel fields on surrogate models
for model_cls, field_name in [
(SingleTaskGPSurrogate, "kernel"),
(MultiTaskGPSurrogate, "kernel"),
(TanimotoGPSurrogate, "kernel"),
]:
patch_field(model_cls, field_name, kernels_api.AnyKernel)

# Patch sub-category kernel fields if the new type is a subclass
if issubclass(new_kernel_cls, ContinuousKernel):
patch_field(
MixedSingleTaskGPSurrogate,
"continuous_kernel",
kernels_api.AnyContinuousKernel,
)
if issubclass(new_kernel_cls, CategoricalKernel):
patch_field(
MixedSingleTaskGPSurrogate,
"categorical_kernel",
kernels_api.AnyCategoricalKernel,
)

# Rebuild surrogate models
for cls in [
SingleTaskGPSurrogate,
MultiTaskGPSurrogate,
TanimotoGPSurrogate,
MixedSingleTaskGPSurrogate,
]:
cls.model_rebuild(force=True)

# Rebuild BotorchSurrogates
BotorchSurrogates.model_rebuild(force=True)


def register_kernel(data_model_cls: type) -> None:
"""Register a custom kernel type so it is accepted in AnyKernel fields.

This appends the type to the internal registry, rebuilds the
``AnyKernel`` union, and calls ``model_rebuild`` on all dependent
Pydantic models (aggregation kernels, surrogates) so that the new
type is accepted.

If the type is a subclass of ``ContinuousKernel`` or ``CategoricalKernel``,
it is also added to the corresponding sub-category union
(``AnyContinuousKernel`` / ``AnyCategoricalKernel``).

Args:
data_model_cls: A concrete subclass of ``Kernel``.
"""
import bofire.data_models.kernels.api as kernels_api
from bofire.data_models.kernels.categorical import CategoricalKernel
from bofire.data_models.kernels.continuous import ContinuousKernel

if data_model_cls in kernels_api._KERNEL_TYPES:
return
kernels_api._KERNEL_TYPES.append(data_model_cls)
kernels_api.AnyKernel = Union[tuple(kernels_api._KERNEL_TYPES)]

# Auto-detect sub-category from base class
if issubclass(data_model_cls, ContinuousKernel):
kernels_api._CONTINUOUS_KERNEL_TYPES.append(data_model_cls)
kernels_api.AnyContinuousKernel = Union[
tuple(kernels_api._CONTINUOUS_KERNEL_TYPES)
]
elif issubclass(data_model_cls, CategoricalKernel):
kernels_api._CATEGORICAL_KERNEL_TYPES.append(data_model_cls)
kernels_api.AnyCategoricalKernel = Union[
tuple(kernels_api._CATEGORICAL_KERNEL_TYPES)
]

_rebuild_dependent_models(data_model_cls)
13 changes: 10 additions & 3 deletions bofire/data_models/kernels/api.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need _rebuild_dependent_strategy similar to rebuild_dependent_models if we want full dynamic Pydantic acceptance of new strategy data-model types?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, correct. Should be fixed now. My Agent was overlooking this.

Copy link
Copy Markdown
Contributor

@bertiqwerty bertiqwerty Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, correct. Should be fixed now. My Agent was overlooking this. (last week)

Haha. Nice try blaming the agent. :D

Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union

from bofire.data_models.kernels._register import register_kernel # noqa: F401
from bofire.data_models.kernels.aggregation import (
AdditiveKernel,
MultiplicativeKernel,
Expand Down Expand Up @@ -40,7 +41,7 @@
AggregationKernel,
]

AnyContinuousKernel = Union[
_CONTINUOUS_KERNEL_TYPES: list[type[ContinuousKernel]] = [
MaternKernel,
LinearKernel,
PolynomialKernel,
Expand All @@ -49,15 +50,19 @@
InfiniteWidthBNNKernel,
]

AnyCategoricalKernel = Union[
AnyContinuousKernel = Union[tuple(_CONTINUOUS_KERNEL_TYPES)]

_CATEGORICAL_KERNEL_TYPES: list[type[CategoricalKernel]] = [
HammingDistanceKernel,
IndexKernel,
PositiveIndexKernel,
]

AnyCategoricalKernel = Union[tuple(_CATEGORICAL_KERNEL_TYPES)]

AnyMolecularKernel = TanimotoKernel

AnyKernel = Union[
_KERNEL_TYPES: list[type[Kernel]] = [
AdditiveKernel,
MultiplicativeKernel,
PolynomialFeatureInteractionKernel,
Expand All @@ -75,3 +80,5 @@
WassersteinKernel,
WedgeKernel,
]

AnyKernel = Union[tuple(_KERNEL_TYPES)]
Loading
Loading