-
Notifications
You must be signed in to change notification settings - Fork 45
Add register() functions for strategies and surrogates #736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
6eaf66a
Add register() functions for strategies and surrogates
jduerholt 228a1df
decorator syntax
jduerholt 23d09ec
Add register() for kernels, priors, and engineered features with dyna…
jduerholt 216f686
Use @register decorator for built-in kernels, priors, and engineered …
jduerholt 1c1fe82
Address PR review comments: shared utils, dynamic kernel sub-unions, …
jduerholt 98b8311
Fix WedgeKernel.base_kernel not updated by register_kernel
jduerholt 2658f6e
Merge main into feature/register
jduerholt 93b40d5
Suppress ty invalid-type-form for dynamic Union[tuple()] patterns
jduerholt 3ed1a02
Add user guide for registering custom types
jduerholt 720bb6a
Suppress ty unresolved-attribute errors in _register_utils.py
jduerholt 3eeca14
Move ty: ignore comments to first line of multiline expressions
jduerholt 8da8684
implement behrangs comments
jduerholt eba56de
Clean up _register imports: move to top of file, remove noqa: E402 hacks
jduerholt 954851c
Use append_to_union_field in features/_register.py
jduerholt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| """Shared utilities for dynamic Pydantic model registration.""" | ||
|
|
||
| import typing | ||
| from collections.abc import Sequence | ||
| from typing import Optional, Union | ||
|
|
||
|
|
||
| def patch_field(model_cls: type, field_name: str, new_union: type) -> None: | ||
| """Patch a Pydantic model field annotation with a new union type. | ||
|
|
||
| Handles three annotation forms: | ||
| - ``Union[A, B, ...]`` — replaced with *new_union* | ||
| - ``Optional[Union[A, B, ...]]`` — wrapped as ``Optional[new_union]`` | ||
| - ``Sequence[Union[A, B, ...]]`` — wrapped as ``Sequence[new_union]`` | ||
| """ | ||
| old = model_cls.model_fields[ # ty: ignore[unresolved-attribute] | ||
| field_name | ||
| ].annotation | ||
| args = typing.get_args(old) | ||
|
|
||
| if not args: | ||
| new = new_union | ||
| elif type(None) in args: | ||
| # Optional[X] is Union[X, None] | ||
| new = Optional[new_union] | ||
| elif typing.get_origin(old) in (list, Sequence): | ||
| # Sequence[Union[...]] or list[Union[...]] | ||
| new = Sequence[new_union] | ||
| else: | ||
| new = new_union | ||
|
|
||
| model_cls.__annotations__[field_name] = new | ||
| model_cls.model_fields[ # ty: ignore[unresolved-attribute] | ||
| field_name | ||
| ].annotation = new | ||
|
|
||
|
|
||
| def append_to_union_field(model_cls: type, field_name: str, new_type: type) -> None: | ||
| """Append a type to the union inside a model field annotation. | ||
|
|
||
| Detects the annotation structure (plain ``Union``, ``Optional[Union]``, | ||
| or ``Sequence[Union]``) and appends *new_type* to the inner union. | ||
| """ | ||
| old = model_cls.model_fields[ # ty: ignore[unresolved-attribute] | ||
| field_name | ||
| ].annotation | ||
| origin = typing.get_origin(old) | ||
|
|
||
| if origin in (list, Sequence): | ||
| # Sequence[Union[A, B, ...]] → Sequence[Union[A, B, ..., new_type]] | ||
| inner = typing.get_args(old)[0] | ||
| inner_args = typing.get_args(inner) | ||
| if new_type not in inner_args: | ||
| new_inner = Union[tuple(list(inner_args) + [new_type])] | ||
| new_ann = Sequence[new_inner] | ||
| model_cls.__annotations__[field_name] = new_ann | ||
| model_cls.model_fields[ # ty: ignore[unresolved-attribute] | ||
| field_name | ||
| ].annotation = new_ann | ||
| else: | ||
| # Union[A, B, ...] or Optional[Union[A, B, ...]] | ||
| args = typing.get_args(old) | ||
| has_none = type(None) in args | ||
| non_none = [a for a in args if a is not type(None)] | ||
| if new_type not in non_none: | ||
| non_none.append(new_type) | ||
| new_union = Union[tuple(non_none)] | ||
| new_ann = Optional[new_union] if has_none else new_union | ||
| model_cls.__annotations__[field_name] = new_ann | ||
| model_cls.model_fields[ # ty: ignore[unresolved-attribute] | ||
| field_name | ||
| ].annotation = new_ann |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| """Registration utilities for custom engineered feature types.""" | ||
|
|
||
| from typing import Union | ||
|
|
||
|
|
||
| def register_engineered_feature(data_model_cls: type) -> None: | ||
| """Register a custom engineered feature type so it is accepted in EngineeredFeatures. | ||
|
|
||
| This appends the type to the internal registry, rebuilds the | ||
| ``AnyEngineeredFeature`` union, and calls ``model_rebuild`` on | ||
| ``EngineeredFeatures`` so that Pydantic accepts the new type. | ||
|
|
||
| Args: | ||
| data_model_cls: A concrete subclass of ``EngineeredFeature``. | ||
| """ | ||
| import bofire.data_models.features.api as features_api | ||
|
|
||
| if data_model_cls in features_api._ENGINEERED_FEATURE_TYPES: | ||
| return | ||
| features_api._ENGINEERED_FEATURE_TYPES.append(data_model_cls) | ||
| features_api.AnyEngineeredFeature = Union[ | ||
| tuple(features_api._ENGINEERED_FEATURE_TYPES) | ||
| ] | ||
|
|
||
| from bofire.data_models._register_utils import append_to_union_field | ||
| from bofire.data_models.domain.features import EngineeredFeatures | ||
|
|
||
| append_to_union_field(EngineeredFeatures, "features", data_model_cls) | ||
| EngineeredFeatures.model_rebuild(force=True) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| """Registration utilities for custom kernel types.""" | ||
|
|
||
| from typing import Union | ||
|
|
||
|
|
||
| def _rebuild_dependent_models(new_kernel_cls: type) -> None: | ||
| """Rebuild all Pydantic models whose fields reference kernel unions.""" | ||
| import bofire.data_models.kernels.api as kernels_api | ||
| from bofire.data_models._register_utils import append_to_union_field, patch_field | ||
| from bofire.data_models.kernels.aggregation import ( | ||
| AdditiveKernel, | ||
| MultiplicativeKernel, | ||
| PolynomialFeatureInteractionKernel, | ||
| ScaleKernel, | ||
| ) | ||
| from bofire.data_models.kernels.categorical import CategoricalKernel | ||
| from bofire.data_models.kernels.conditional import ( | ||
| ConditionalEmbeddingKernel, | ||
| WedgeKernel, | ||
| ) | ||
| from bofire.data_models.kernels.continuous import ContinuousKernel | ||
| from bofire.data_models.surrogates.botorch_surrogates import BotorchSurrogates | ||
| from bofire.data_models.surrogates.mixed_single_task_gp import ( | ||
| MixedSingleTaskGPSurrogate, | ||
| ) | ||
| from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPSurrogate | ||
| from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate | ||
| from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate | ||
|
|
||
| # Add new kernel type to aggregation kernel inline unions | ||
| # (handles both Sequence[Union[...]] and plain Union[...]) | ||
| for model_cls, field_name in [ | ||
| (AdditiveKernel, "kernels"), | ||
| (MultiplicativeKernel, "kernels"), | ||
| (PolynomialFeatureInteractionKernel, "kernels"), | ||
| (ScaleKernel, "base_kernel"), | ||
| (ConditionalEmbeddingKernel, "base_kernel"), | ||
| (WedgeKernel, "base_kernel"), | ||
| ]: | ||
| append_to_union_field(model_cls, field_name, new_kernel_cls) | ||
|
|
||
| # Rebuild aggregation and conditional kernels | ||
| for cls in [ | ||
| AdditiveKernel, | ||
| MultiplicativeKernel, | ||
| ScaleKernel, | ||
| PolynomialFeatureInteractionKernel, | ||
| ConditionalEmbeddingKernel, | ||
| WedgeKernel, | ||
| ]: | ||
| cls.model_rebuild(force=True) | ||
|
|
||
| # Patch AnyKernel fields on surrogate models | ||
| for model_cls, field_name in [ | ||
| (SingleTaskGPSurrogate, "kernel"), | ||
| (MultiTaskGPSurrogate, "kernel"), | ||
| (TanimotoGPSurrogate, "kernel"), | ||
| ]: | ||
| patch_field(model_cls, field_name, kernels_api.AnyKernel) | ||
|
|
||
| # Patch sub-category kernel fields if the new type is a subclass | ||
| if issubclass(new_kernel_cls, ContinuousKernel): | ||
| patch_field( | ||
| MixedSingleTaskGPSurrogate, | ||
| "continuous_kernel", | ||
| kernels_api.AnyContinuousKernel, | ||
| ) | ||
| if issubclass(new_kernel_cls, CategoricalKernel): | ||
| patch_field( | ||
| MixedSingleTaskGPSurrogate, | ||
| "categorical_kernel", | ||
| kernels_api.AnyCategoricalKernel, | ||
| ) | ||
|
|
||
| # Rebuild surrogate models | ||
| for cls in [ | ||
| SingleTaskGPSurrogate, | ||
| MultiTaskGPSurrogate, | ||
| TanimotoGPSurrogate, | ||
| MixedSingleTaskGPSurrogate, | ||
| ]: | ||
| cls.model_rebuild(force=True) | ||
|
|
||
| # Rebuild BotorchSurrogates | ||
| BotorchSurrogates.model_rebuild(force=True) | ||
|
|
||
|
|
||
| def register_kernel(data_model_cls: type) -> None: | ||
| """Register a custom kernel type so it is accepted in AnyKernel fields. | ||
|
|
||
| This appends the type to the internal registry, rebuilds the | ||
| ``AnyKernel`` union, and calls ``model_rebuild`` on all dependent | ||
| Pydantic models (aggregation kernels, surrogates) so that the new | ||
| type is accepted. | ||
|
|
||
| If the type is a subclass of ``ContinuousKernel`` or ``CategoricalKernel``, | ||
| it is also added to the corresponding sub-category union | ||
| (``AnyContinuousKernel`` / ``AnyCategoricalKernel``). | ||
|
|
||
| Args: | ||
| data_model_cls: A concrete subclass of ``Kernel``. | ||
| """ | ||
| import bofire.data_models.kernels.api as kernels_api | ||
| from bofire.data_models.kernels.categorical import CategoricalKernel | ||
| from bofire.data_models.kernels.continuous import ContinuousKernel | ||
|
|
||
| if data_model_cls in kernels_api._KERNEL_TYPES: | ||
| return | ||
| kernels_api._KERNEL_TYPES.append(data_model_cls) | ||
| kernels_api.AnyKernel = Union[tuple(kernels_api._KERNEL_TYPES)] | ||
|
|
||
| # Auto-detect sub-category from base class | ||
| if issubclass(data_model_cls, ContinuousKernel): | ||
| kernels_api._CONTINUOUS_KERNEL_TYPES.append(data_model_cls) | ||
| kernels_api.AnyContinuousKernel = Union[ | ||
| tuple(kernels_api._CONTINUOUS_KERNEL_TYPES) | ||
| ] | ||
| elif issubclass(data_model_cls, CategoricalKernel): | ||
| kernels_api._CATEGORICAL_KERNEL_TYPES.append(data_model_cls) | ||
| kernels_api.AnyCategoricalKernel = Union[ | ||
| tuple(kernels_api._CATEGORICAL_KERNEL_TYPES) | ||
| ] | ||
|
|
||
| _rebuild_dependent_models(data_model_cls) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we also need
_rebuild_dependent_strategysimilar torebuild_dependent_modelsif we want full dynamic Pydantic acceptance of new strategy data-model types?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, correct. Should be fixed now. My Agent was overlooking this.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Haha. Nice try blaming the agent. :D