-
Notifications
You must be signed in to change notification settings - Fork 46
Add register() functions for strategies and surrogates #736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
6eaf66a
228a1df
23d09ec
216f686
1c1fe82
98b8311
2658f6e
93b40d5
3ed1a02
720bb6a
3eeca14
8da8684
eba56de
954851c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| """Shared utilities for dynamic Pydantic model registration.""" | ||
|
|
||
| import typing | ||
| from collections.abc import Sequence | ||
| from typing import Optional, Union | ||
|
|
||
|
|
||
| def patch_field(model_cls: type, field_name: str, new_union: type) -> None: | ||
| """Patch a Pydantic model field annotation with a new union type. | ||
|
|
||
| Handles three annotation forms: | ||
| - ``Union[A, B, ...]`` — replaced with *new_union* | ||
| - ``Optional[Union[A, B, ...]]`` — wrapped as ``Optional[new_union]`` | ||
| - ``Sequence[Union[A, B, ...]]`` — wrapped as ``Sequence[new_union]`` | ||
| """ | ||
| old = model_cls.model_fields[field_name].annotation | ||
| args = typing.get_args(old) | ||
|
|
||
| if not args: | ||
| new = new_union | ||
| elif type(None) in args: | ||
| # Optional[X] is Union[X, None] | ||
| new = Optional[new_union] | ||
| elif typing.get_origin(old) in (list, Sequence): | ||
| # Sequence[Union[...]] or list[Union[...]] | ||
| new = Sequence[new_union] | ||
| else: | ||
| new = new_union | ||
|
|
||
| model_cls.__annotations__[field_name] = new | ||
| model_cls.model_fields[field_name].annotation = new | ||
|
|
||
|
|
||
| def append_to_union_field(model_cls: type, field_name: str, new_type: type) -> None: | ||
| """Append a type to the union inside a model field annotation. | ||
|
|
||
| Detects the annotation structure (plain ``Union``, ``Optional[Union]``, | ||
| or ``Sequence[Union]``) and appends *new_type* to the inner union. | ||
| """ | ||
| old = model_cls.model_fields[field_name].annotation | ||
| origin = typing.get_origin(old) | ||
|
|
||
| if origin in (list, Sequence): | ||
| # Sequence[Union[A, B, ...]] → Sequence[Union[A, B, ..., new_type]] | ||
| inner = typing.get_args(old)[0] | ||
| inner_args = typing.get_args(inner) | ||
| if new_type not in inner_args: | ||
| new_inner = Union[tuple(list(inner_args) + [new_type])] | ||
| new_ann = Sequence[new_inner] | ||
| model_cls.__annotations__[field_name] = new_ann | ||
| model_cls.model_fields[field_name].annotation = new_ann | ||
| else: | ||
| # Union[A, B, ...] or Optional[Union[A, B, ...]] | ||
| args = typing.get_args(old) | ||
| has_none = type(None) in args | ||
| non_none = [a for a in args if a is not type(None)] | ||
| if new_type not in non_none: | ||
| non_none.append(new_type) | ||
| new_union = Union[tuple(non_none)] | ||
| new_ann = Optional[new_union] if has_none else new_union | ||
| model_cls.__annotations__[field_name] = new_ann | ||
| model_cls.model_fields[field_name].annotation = new_ann |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| from typing import Union | ||
| import typing | ||
| from collections.abc import Sequence | ||
| from typing import List, Type, Union | ||
|
|
||
| from bofire.data_models.features.categorical import CategoricalInput, CategoricalOutput | ||
| from bofire.data_models.features.continuous import ContinuousInput, ContinuousOutput | ||
|
|
@@ -65,11 +67,43 @@ | |
|
|
||
| AnyOutput = Union[ContinuousOutput, CategoricalOutput] | ||
|
|
||
| AnyEngineeredFeature = Union[ | ||
| _ENGINEERED_FEATURE_TYPES: List[Type[EngineeredFeature]] = [ | ||
| SumFeature, | ||
| MeanFeature, | ||
| WeightedSumFeature, | ||
| MolecularWeightedSumFeature, | ||
| ProductFeature, | ||
| CloneFeature, | ||
| ] | ||
|
|
||
| AnyEngineeredFeature = Union[tuple(_ENGINEERED_FEATURE_TYPES)] | ||
|
|
||
|
|
||
| def register_engineered_feature(data_model_cls: Type[EngineeredFeature]) -> None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe don't clutter the |
||
| """Register a custom engineered feature type so it is accepted in EngineeredFeatures. | ||
|
|
||
| This appends the type to the internal registry, rebuilds the | ||
| ``AnyEngineeredFeature`` union, and calls ``model_rebuild`` on | ||
| ``EngineeredFeatures`` so that Pydantic accepts the new type. | ||
|
|
||
| Args: | ||
| data_model_cls: A concrete subclass of ``EngineeredFeature``. | ||
| """ | ||
| global AnyEngineeredFeature | ||
| if data_model_cls in _ENGINEERED_FEATURE_TYPES: | ||
| return | ||
| _ENGINEERED_FEATURE_TYPES.append(data_model_cls) | ||
| AnyEngineeredFeature = Union[tuple(_ENGINEERED_FEATURE_TYPES)] | ||
|
|
||
| # Lazy import to avoid circular dependencies | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see above. if you had separate file for the implementation, would the import still be circular? |
||
| from bofire.data_models.domain.features import EngineeredFeatures | ||
|
|
||
| # Patch the Sequence[Union[...]] annotation on EngineeredFeatures.features | ||
| old = EngineeredFeatures.model_fields["features"].annotation | ||
| inner_args = typing.get_args(typing.get_args(old)[0]) | ||
| if data_model_cls not in inner_args: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this peace of code looks fragile. but currently i don't have a better idea. hmm... at least if this breaks at some point only the registration functinality is affected, not the existing modules. |
||
| new_inner = Union[tuple(list(inner_args) + [data_model_cls])] | ||
| new_ann = Sequence[new_inner] | ||
| EngineeredFeatures.__annotations__["features"] = new_ann | ||
| EngineeredFeatures.model_fields["features"].annotation = new_ann | ||
| EngineeredFeatures.model_rebuild(force=True) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we also need
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, correct. Should be fixed now. My Agent was overlooking this.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Haha. Nice try blaming the agent. :D |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use
listinstead ofList. Don't importList.