Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions bofire/data_models/_register_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Shared utilities for dynamic Pydantic model registration."""

import typing
from collections.abc import Sequence
from typing import Optional, Union


def patch_field(model_cls: type, field_name: str, new_union: type) -> None:
"""Patch a Pydantic model field annotation with a new union type.

Handles three annotation forms:
- ``Union[A, B, ...]`` — replaced with *new_union*
- ``Optional[Union[A, B, ...]]`` — wrapped as ``Optional[new_union]``
- ``Sequence[Union[A, B, ...]]`` — wrapped as ``Sequence[new_union]``
"""
old = model_cls.model_fields[field_name].annotation
args = typing.get_args(old)

if not args:
new = new_union
elif type(None) in args:
# Optional[X] is Union[X, None]
new = Optional[new_union]
elif typing.get_origin(old) in (list, Sequence):
# Sequence[Union[...]] or list[Union[...]]
new = Sequence[new_union]
else:
new = new_union

model_cls.__annotations__[field_name] = new
model_cls.model_fields[field_name].annotation = new


def append_to_union_field(model_cls: type, field_name: str, new_type: type) -> None:
"""Append a type to the union inside a model field annotation.

Detects the annotation structure (plain ``Union``, ``Optional[Union]``,
or ``Sequence[Union]``) and appends *new_type* to the inner union.
"""
old = model_cls.model_fields[field_name].annotation
origin = typing.get_origin(old)

if origin in (list, Sequence):
# Sequence[Union[A, B, ...]] → Sequence[Union[A, B, ..., new_type]]
inner = typing.get_args(old)[0]
inner_args = typing.get_args(inner)
if new_type not in inner_args:
new_inner = Union[tuple(list(inner_args) + [new_type])]
new_ann = Sequence[new_inner]
model_cls.__annotations__[field_name] = new_ann
model_cls.model_fields[field_name].annotation = new_ann
else:
# Union[A, B, ...] or Optional[Union[A, B, ...]]
args = typing.get_args(old)
has_none = type(None) in args
non_none = [a for a in args if a is not type(None)]
if new_type not in non_none:
non_none.append(new_type)
new_union = Union[tuple(non_none)]
new_ann = Optional[new_union] if has_none else new_union
model_cls.__annotations__[field_name] = new_ann
model_cls.model_fields[field_name].annotation = new_ann
38 changes: 36 additions & 2 deletions bofire/data_models/features/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Union
import typing
from collections.abc import Sequence
from typing import List, Type, Union
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use list instead of List. Don't import List.


from bofire.data_models.features.categorical import CategoricalInput, CategoricalOutput
from bofire.data_models.features.continuous import ContinuousInput, ContinuousOutput
Expand Down Expand Up @@ -65,11 +67,43 @@

AnyOutput = Union[ContinuousOutput, CategoricalOutput]

AnyEngineeredFeature = Union[
_ENGINEERED_FEATURE_TYPES: List[Type[EngineeredFeature]] = [
SumFeature,
MeanFeature,
WeightedSumFeature,
MolecularWeightedSumFeature,
ProductFeature,
CloneFeature,
]

AnyEngineeredFeature = Union[tuple(_ENGINEERED_FEATURE_TYPES)]


def register_engineered_feature(data_model_cls: Type[EngineeredFeature]) -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe don't clutter the api.py with implementations? Create a separate file and import the function here?

"""Register a custom engineered feature type so it is accepted in EngineeredFeatures.

This appends the type to the internal registry, rebuilds the
``AnyEngineeredFeature`` union, and calls ``model_rebuild`` on
``EngineeredFeatures`` so that Pydantic accepts the new type.

Args:
data_model_cls: A concrete subclass of ``EngineeredFeature``.
"""
global AnyEngineeredFeature
if data_model_cls in _ENGINEERED_FEATURE_TYPES:
return
_ENGINEERED_FEATURE_TYPES.append(data_model_cls)
AnyEngineeredFeature = Union[tuple(_ENGINEERED_FEATURE_TYPES)]

# Lazy import to avoid circular dependencies
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above. if you had separate file for the implementation, would the import still be circular?

from bofire.data_models.domain.features import EngineeredFeatures

# Patch the Sequence[Union[...]] annotation on EngineeredFeatures.features
old = EngineeredFeatures.model_fields["features"].annotation
inner_args = typing.get_args(typing.get_args(old)[0])
if data_model_cls not in inner_args:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this peace of code looks fragile. but currently i don't have a better idea. hmm... at least if this breaks at some point only the registration functinality is affected, not the existing modules.

new_inner = Union[tuple(list(inner_args) + [data_model_cls])]
new_ann = Sequence[new_inner]
EngineeredFeatures.__annotations__["features"] = new_ann
EngineeredFeatures.model_fields["features"].annotation = new_ann
EngineeredFeatures.model_rebuild(force=True)
107 changes: 103 additions & 4 deletions bofire/data_models/kernels/api.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need _rebuild_dependent_strategy similar to rebuild_dependent_models if we want full dynamic Pydantic acceptance of new strategy data-model types?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, correct. Should be fixed now. My Agent was overlooking this.

Copy link
Copy Markdown
Contributor

@bertiqwerty bertiqwerty Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, correct. Should be fixed now. My Agent was overlooking this. (last week)

Haha. Nice try blaming the agent. :D

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import List, Type, Union

from bofire.data_models.kernels.aggregation import (
AdditiveKernel,
Expand Down Expand Up @@ -40,7 +40,7 @@
AggregationKernel,
]

AnyContinuousKernel = Union[
_CONTINUOUS_KERNEL_TYPES: List[Type[ContinuousKernel]] = [
MaternKernel,
LinearKernel,
PolynomialKernel,
Expand All @@ -49,15 +49,19 @@
InfiniteWidthBNNKernel,
]

AnyCategoricalKernel = Union[
AnyContinuousKernel = Union[tuple(_CONTINUOUS_KERNEL_TYPES)]

_CATEGORICAL_KERNEL_TYPES: List[Type[CategoricalKernel]] = [
HammingDistanceKernel,
IndexKernel,
PositiveIndexKernel,
]

AnyCategoricalKernel = Union[tuple(_CATEGORICAL_KERNEL_TYPES)]

AnyMolecularKernel = TanimotoKernel

AnyKernel = Union[
_KERNEL_TYPES: List[Type[Kernel]] = [
AdditiveKernel,
MultiplicativeKernel,
PolynomialFeatureInteractionKernel,
Expand All @@ -75,3 +79,98 @@
WassersteinKernel,
WedgeKernel,
]

AnyKernel = Union[tuple(_KERNEL_TYPES)]


def _rebuild_dependent_models(new_kernel_cls: Type[Kernel]) -> None:
"""Rebuild all Pydantic models whose fields reference kernel unions."""
from bofire.data_models._register_utils import append_to_union_field, patch_field
from bofire.data_models.surrogates.botorch_surrogates import BotorchSurrogates
from bofire.data_models.surrogates.mixed_single_task_gp import (
MixedSingleTaskGPSurrogate,
)
from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPSurrogate
from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate
from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate
Comment thread
jduerholt marked this conversation as resolved.
Outdated

# Add new kernel type to aggregation kernel inline unions
# (handles both Sequence[Union[...]] and plain Union[...])
for model_cls, field_name in [
(AdditiveKernel, "kernels"),
(MultiplicativeKernel, "kernels"),
(PolynomialFeatureInteractionKernel, "kernels"),
(ScaleKernel, "base_kernel"),
]:
append_to_union_field(model_cls, field_name, new_kernel_cls)

# Rebuild aggregation kernels
for cls in [
AdditiveKernel,
MultiplicativeKernel,
ScaleKernel,
PolynomialFeatureInteractionKernel,
]:
cls.model_rebuild(force=True)

# Patch AnyKernel fields on surrogate models
for model_cls, field_name in [
(SingleTaskGPSurrogate, "kernel"),
(MultiTaskGPSurrogate, "kernel"),
(TanimotoGPSurrogate, "kernel"),
]:
patch_field(model_cls, field_name, AnyKernel)

# Patch sub-category kernel fields if the new type is a subclass
if issubclass(new_kernel_cls, ContinuousKernel):
patch_field(
MixedSingleTaskGPSurrogate, "continuous_kernel", AnyContinuousKernel
)
if issubclass(new_kernel_cls, CategoricalKernel):
patch_field(
MixedSingleTaskGPSurrogate, "categorical_kernel", AnyCategoricalKernel
)

# Rebuild surrogate models
for cls in [
SingleTaskGPSurrogate,
MultiTaskGPSurrogate,
TanimotoGPSurrogate,
MixedSingleTaskGPSurrogate,
]:
cls.model_rebuild(force=True)

# Rebuild BotorchSurrogates
BotorchSurrogates.model_rebuild(force=True)


def register_kernel(data_model_cls: Type[Kernel]) -> None:
"""Register a custom kernel type so it is accepted in AnyKernel fields.

This appends the type to the internal registry, rebuilds the
``AnyKernel`` union, and calls ``model_rebuild`` on all dependent
Pydantic models (aggregation kernels, surrogates) so that the new
type is accepted.

If the type is a subclass of ``ContinuousKernel`` or ``CategoricalKernel``,
it is also added to the corresponding sub-category union
(``AnyContinuousKernel`` / ``AnyCategoricalKernel``).

Args:
data_model_cls: A concrete subclass of ``Kernel``.
"""
global AnyKernel, AnyContinuousKernel, AnyCategoricalKernel
if data_model_cls in _KERNEL_TYPES:
return
_KERNEL_TYPES.append(data_model_cls)
AnyKernel = Union[tuple(_KERNEL_TYPES)]

# Auto-detect sub-category from base class
if issubclass(data_model_cls, ContinuousKernel):
_CONTINUOUS_KERNEL_TYPES.append(data_model_cls)
AnyContinuousKernel = Union[tuple(_CONTINUOUS_KERNEL_TYPES)]
elif issubclass(data_model_cls, CategoricalKernel):
_CATEGORICAL_KERNEL_TYPES.append(data_model_cls)
AnyCategoricalKernel = Union[tuple(_CATEGORICAL_KERNEL_TYPES)]

_rebuild_dependent_models(data_model_cls)
Loading