Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions bofire/data_models/features/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Union
import typing
from collections.abc import Sequence
from typing import List, Type, Union
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use list instead of List. Don't import List.


from bofire.data_models.features.categorical import CategoricalInput, CategoricalOutput
from bofire.data_models.features.continuous import ContinuousInput, ContinuousOutput
Expand Down Expand Up @@ -65,11 +67,43 @@

AnyOutput = Union[ContinuousOutput, CategoricalOutput]

AnyEngineeredFeature = Union[
_ENGINEERED_FEATURE_TYPES: List[Type[EngineeredFeature]] = [
SumFeature,
MeanFeature,
WeightedSumFeature,
MolecularWeightedSumFeature,
ProductFeature,
CloneFeature,
]

AnyEngineeredFeature = Union[tuple(_ENGINEERED_FEATURE_TYPES)]


def register_engineered_feature(data_model_cls: Type[EngineeredFeature]) -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe don't clutter the api.py with implementations? Create a separate file and import the function here?

"""Register a custom engineered feature type so it is accepted in EngineeredFeatures.

This appends the type to the internal registry, rebuilds the
``AnyEngineeredFeature`` union, and calls ``model_rebuild`` on
``EngineeredFeatures`` so that Pydantic accepts the new type.

Args:
data_model_cls: A concrete subclass of ``EngineeredFeature``.
"""
global AnyEngineeredFeature
if data_model_cls in _ENGINEERED_FEATURE_TYPES:
return
_ENGINEERED_FEATURE_TYPES.append(data_model_cls)
AnyEngineeredFeature = Union[tuple(_ENGINEERED_FEATURE_TYPES)]

# Lazy import to avoid circular dependencies
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above. if you had separate file for the implementation, would the import still be circular?

from bofire.data_models.domain.features import EngineeredFeatures

# Patch the Sequence[Union[...]] annotation on EngineeredFeatures.features
old = EngineeredFeatures.model_fields["features"].annotation
inner_args = typing.get_args(typing.get_args(old)[0])
if data_model_cls not in inner_args:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this peace of code looks fragile. but currently i don't have a better idea. hmm... at least if this breaks at some point only the registration functinality is affected, not the existing modules.

new_inner = Union[tuple(list(inner_args) + [data_model_cls])]
new_ann = Sequence[new_inner]
EngineeredFeatures.__annotations__["features"] = new_ann
EngineeredFeatures.model_fields["features"].annotation = new_ann
EngineeredFeatures.model_rebuild(force=True)
85 changes: 83 additions & 2 deletions bofire/data_models/kernels/api.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need _rebuild_dependent_strategy similar to rebuild_dependent_models if we want full dynamic Pydantic acceptance of new strategy data-model types?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, correct. Should be fixed now. My Agent was overlooking this.

Copy link
Copy Markdown
Contributor

@bertiqwerty bertiqwerty Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, correct. Should be fixed now. My Agent was overlooking this. (last week)

Haha. Nice try blaming the agent. :D

Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Union
import typing
from collections.abc import Sequence
from typing import List, Type, Union

from bofire.data_models.kernels.aggregation import (
AdditiveKernel,
Expand Down Expand Up @@ -57,7 +59,7 @@

AnyMolecularKernel = TanimotoKernel

AnyKernel = Union[
_KERNEL_TYPES: List[Type[Kernel]] = [
AdditiveKernel,
MultiplicativeKernel,
PolynomialFeatureInteractionKernel,
Expand All @@ -75,3 +77,82 @@
WassersteinKernel,
WedgeKernel,
]

AnyKernel = Union[tuple(_KERNEL_TYPES)]


def _rebuild_dependent_models(new_kernel_cls: Type[Kernel]) -> None:
"""Rebuild all Pydantic models whose fields reference AnyKernel."""
from bofire.data_models.surrogates.botorch_surrogates import BotorchSurrogates
from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPSurrogate
from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate
from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate
Comment thread
jduerholt marked this conversation as resolved.
Outdated

# Add new kernel type to aggregation kernel inline unions
for model_cls, field_name in [
(AdditiveKernel, "kernels"),
(MultiplicativeKernel, "kernels"),
(PolynomialFeatureInteractionKernel, "kernels"),
]:
old = model_cls.model_fields[field_name].annotation
# Annotation is Sequence[Union[...]]
inner = typing.get_args(old)[0]
inner_args = typing.get_args(inner)
if new_kernel_cls not in inner_args:
new_inner = Union[tuple(list(inner_args) + [new_kernel_cls])]
new_ann = Sequence[new_inner]
model_cls.__annotations__[field_name] = new_ann
model_cls.model_fields[field_name].annotation = new_ann

# ScaleKernel.base_kernel is Union[...] (not Sequence)
old = ScaleKernel.model_fields["base_kernel"].annotation
Comment thread
jduerholt marked this conversation as resolved.
Outdated
args = typing.get_args(old)
if new_kernel_cls not in args:
new_ann = Union[tuple(list(args) + [new_kernel_cls])]
ScaleKernel.__annotations__["base_kernel"] = new_ann
ScaleKernel.model_fields["base_kernel"].annotation = new_ann

# Rebuild aggregation kernels
for cls in [
AdditiveKernel,
MultiplicativeKernel,
ScaleKernel,
PolynomialFeatureInteractionKernel,
]:
cls.model_rebuild(force=True)

# Patch AnyKernel fields on surrogate models
any_kernel = AnyKernel
for model_cls, field_name in [
(SingleTaskGPSurrogate, "kernel"),
(MultiTaskGPSurrogate, "kernel"),
(TanimotoGPSurrogate, "kernel"),
]:
model_cls.__annotations__[field_name] = any_kernel
model_cls.model_fields[field_name].annotation = any_kernel

# Rebuild surrogate models
for cls in [SingleTaskGPSurrogate, MultiTaskGPSurrogate, TanimotoGPSurrogate]:
cls.model_rebuild(force=True)

# Rebuild BotorchSurrogates
BotorchSurrogates.model_rebuild(force=True)


def register_kernel(data_model_cls: Type[Kernel]) -> None:
"""Register a custom kernel type so it is accepted in AnyKernel fields.

This appends the type to the internal registry, rebuilds the
``AnyKernel`` union, and calls ``model_rebuild`` on all dependent
Pydantic models (aggregation kernels, surrogates) so that the new
type is accepted.

Args:
data_model_cls: A concrete subclass of ``Kernel``.
"""
global AnyKernel
if data_model_cls in _KERNEL_TYPES:
return
_KERNEL_TYPES.append(data_model_cls)
AnyKernel = Union[tuple(_KERNEL_TYPES)]
_rebuild_dependent_models(data_model_cls)
189 changes: 185 additions & 4 deletions bofire/data_models/priors/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Union
from typing import List, Optional, Type, Union

from bofire.data_models.priors.constraint import (
GreaterThan,
Expand All @@ -25,18 +25,199 @@
AbstractPrior = Prior
AbstractPriorConstraint = Union[PriorConstraint, Interval]

AnyPrior = Union[
_PRIOR_TYPES: List[Type[Prior]] = [
GammaPrior,
NormalPrior,
LKJPrior,
LogNormalPrior,
DimensionalityScaledLogNormalPrior,
]

AnyPriorConstraint = Union[
NonTransformedInterval, LogTransformedInterval, Positive, GreaterThan, LessThan
AnyPrior = Union[tuple(_PRIOR_TYPES)]

_PRIOR_CONSTRAINT_TYPES: List[Type] = [
NonTransformedInterval,
LogTransformedInterval,
Positive,
GreaterThan,
LessThan,
]

AnyPriorConstraint = Union[tuple(_PRIOR_CONSTRAINT_TYPES)]


def _patch_field(model_cls: type, field_name: str, new_union: type) -> None:
Comment thread
jduerholt marked this conversation as resolved.
Outdated
"""Patch a Pydantic model field annotation, preserving Optional wrapping."""
import typing

old = model_cls.model_fields[field_name].annotation
args = typing.get_args(old)
if type(None) in args:
new = Optional[new_union]
else:
new = new_union
model_cls.__annotations__[field_name] = new
model_cls.model_fields[field_name].annotation = new


def _rebuild_dependent_models() -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer a separate file to keep the api.py as clean as possible.

"""Rebuild all Pydantic models whose fields reference AnyPrior or AnyPriorConstraint."""
# Lazy imports to avoid circular dependencies
from bofire.data_models.kernels.aggregation import (
AdditiveKernel,
MultiplicativeKernel,
PolynomialFeatureInteractionKernel,
ScaleKernel,
)
from bofire.data_models.kernels.categorical import (
HammingDistanceKernel,
IndexKernel,
PositiveIndexKernel,
)
from bofire.data_models.kernels.conditional import WedgeKernel
from bofire.data_models.kernels.continuous import (
MaternKernel,
RBFKernel,
SphericalLinearKernel,
)
from bofire.data_models.kernels.shape import WassersteinKernel
from bofire.data_models.surrogates.botorch_surrogates import BotorchSurrogates
from bofire.data_models.surrogates.linear import LinearSurrogate
from bofire.data_models.surrogates.mixed_single_task_gp import (
MixedSingleTaskGPSurrogate,
)
from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPSurrogate
from bofire.data_models.surrogates.polynomial import PolynomialSurrogate
from bofire.data_models.surrogates.robust_single_task_gp import (
RobustSingleTaskGPSurrogate,
)
from bofire.data_models.surrogates.shape import PiecewiseLinearGPSurrogate
from bofire.data_models.surrogates.single_task_gp import (
SingleTaskGPHyperconfig,
SingleTaskGPSurrogate,
)
from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate

# Patch AnyPrior fields
for model_cls, field_name in [
Comment thread
jduerholt marked this conversation as resolved.
Outdated
(RBFKernel, "lengthscale_prior"),
(MaternKernel, "lengthscale_prior"),
(SphericalLinearKernel, "lengthscale_prior"),
(HammingDistanceKernel, "lengthscale_prior"),
(IndexKernel, "prior"),
(PositiveIndexKernel, "prior"),
(PositiveIndexKernel, "task_prior"),
(PositiveIndexKernel, "diag_prior"),
(WedgeKernel, "lengthscale_prior"),
(WedgeKernel, "angle_prior"),
(WedgeKernel, "radius_prior"),
(WassersteinKernel, "lengthscale_prior"),
(SingleTaskGPSurrogate, "noise_prior"),
(MultiTaskGPSurrogate, "noise_prior"),
(MixedSingleTaskGPSurrogate, "noise_prior"),
(TanimotoGPSurrogate, "noise_prior"),
(PiecewiseLinearGPSurrogate, "outputscale_prior"),
(PiecewiseLinearGPSurrogate, "noise_prior"),
(PolynomialSurrogate, "noise_prior"),
(LinearSurrogate, "noise_prior"),
(RobustSingleTaskGPSurrogate, "noise_prior"),
]:
_patch_field(model_cls, field_name, AnyPrior)

# Patch AnyPriorConstraint fields
for model_cls, field_name in [
(RBFKernel, "lengthscale_constraint"),
(MaternKernel, "lengthscale_constraint"),
(SphericalLinearKernel, "lengthscale_constraint"),
(HammingDistanceKernel, "lengthscale_constraint"),
(IndexKernel, "var_constraint"),
(PositiveIndexKernel, "var_constraint"),
(WedgeKernel, "lengthscale_constraint"),
(ScaleKernel, "outputscale_constraint"),
(SingleTaskGPHyperconfig, "lengthscale_constraint"),
(SingleTaskGPHyperconfig, "outputscale_constraint"),
]:
_patch_field(model_cls, field_name, AnyPriorConstraint)

# Rebuild in dependency order:
# 1. Leaf kernel models
for cls in [
RBFKernel,
MaternKernel,
SphericalLinearKernel,
HammingDistanceKernel,
IndexKernel,
PositiveIndexKernel,
WassersteinKernel,
]:
cls.model_rebuild(force=True)

# 2. Conditional kernels (reference leaf kernels)
WedgeKernel.model_rebuild(force=True)

# 3. Aggregation kernels (embed leaf kernel types)
for cls in [
AdditiveKernel,
MultiplicativeKernel,
ScaleKernel,
PolynomialFeatureInteractionKernel,
]:
cls.model_rebuild(force=True)

# 4. Surrogate models
SingleTaskGPHyperconfig.model_rebuild(force=True)
for cls in [
SingleTaskGPSurrogate,
MultiTaskGPSurrogate,
MixedSingleTaskGPSurrogate,
TanimotoGPSurrogate,
PiecewiseLinearGPSurrogate,
PolynomialSurrogate,
LinearSurrogate,
RobustSingleTaskGPSurrogate,
]:
cls.model_rebuild(force=True)

# 5. BotorchSurrogates
BotorchSurrogates.model_rebuild(force=True)


def register_prior(data_model_cls: Type[Prior]) -> None:
"""Register a custom prior type so it is accepted in AnyPrior fields.

This appends the type to the internal registry, rebuilds the
``AnyPrior`` union, and calls ``model_rebuild`` on all dependent
Pydantic models (kernels, surrogates) so that the new type is accepted.

Args:
data_model_cls: A concrete subclass of ``Prior``.
"""
global AnyPrior
if data_model_cls in _PRIOR_TYPES:
return
_PRIOR_TYPES.append(data_model_cls)
AnyPrior = Union[tuple(_PRIOR_TYPES)]
_rebuild_dependent_models()


def register_prior_constraint(data_model_cls: Type) -> None:
"""Register a custom prior constraint type so it is accepted in AnyPriorConstraint fields.

This appends the type to the internal registry, rebuilds the
``AnyPriorConstraint`` union, and calls ``model_rebuild`` on all
dependent Pydantic models.

Args:
data_model_cls: A concrete subclass of ``PriorConstraint`` or ``Interval``.
"""
global AnyPriorConstraint
if data_model_cls in _PRIOR_CONSTRAINT_TYPES:
return
_PRIOR_CONSTRAINT_TYPES.append(data_model_cls)
AnyPriorConstraint = Union[tuple(_PRIOR_CONSTRAINT_TYPES)]
_rebuild_dependent_models()


# these are priors that are generally applicable
# and do not depend on problem specific extra parameters
AnyGeneralPrior = Union[GammaPrior, NormalPrior, LKJPrior, LogNormalPrior]
Expand Down
1 change: 1 addition & 0 deletions bofire/data_models/surrogates/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from bofire.data_models.surrogates.botorch_surrogates import (
AnyBotorchSurrogate,
BotorchSurrogates,
register_botorch_surrogate,
)
from bofire.data_models.surrogates.deterministic import (
CategoricalDeterministicSurrogate,
Expand Down
Loading
Loading