-
Notifications
You must be signed in to change notification settings - Fork 46
Add register() functions for strategies and surrogates #736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
6eaf66a
228a1df
23d09ec
216f686
1c1fe82
98b8311
2658f6e
93b40d5
3ed1a02
720bb6a
3eeca14
8da8684
eba56de
954851c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| from typing import Union | ||
| import typing | ||
| from collections.abc import Sequence | ||
| from typing import List, Type, Union | ||
|
|
||
| from bofire.data_models.features.categorical import CategoricalInput, CategoricalOutput | ||
| from bofire.data_models.features.continuous import ContinuousInput, ContinuousOutput | ||
|
|
@@ -65,11 +67,43 @@ | |
|
|
||
| AnyOutput = Union[ContinuousOutput, CategoricalOutput] | ||
|
|
||
| AnyEngineeredFeature = Union[ | ||
| _ENGINEERED_FEATURE_TYPES: List[Type[EngineeredFeature]] = [ | ||
| SumFeature, | ||
| MeanFeature, | ||
| WeightedSumFeature, | ||
| MolecularWeightedSumFeature, | ||
| ProductFeature, | ||
| CloneFeature, | ||
| ] | ||
|
|
||
| AnyEngineeredFeature = Union[tuple(_ENGINEERED_FEATURE_TYPES)] | ||
|
|
||
|
|
||
| def register_engineered_feature(data_model_cls: Type[EngineeredFeature]) -> None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe don't clutter the |
||
| """Register a custom engineered feature type so it is accepted in EngineeredFeatures. | ||
|
|
||
| This appends the type to the internal registry, rebuilds the | ||
| ``AnyEngineeredFeature`` union, and calls ``model_rebuild`` on | ||
| ``EngineeredFeatures`` so that Pydantic accepts the new type. | ||
|
|
||
| Args: | ||
| data_model_cls: A concrete subclass of ``EngineeredFeature``. | ||
| """ | ||
| global AnyEngineeredFeature | ||
| if data_model_cls in _ENGINEERED_FEATURE_TYPES: | ||
| return | ||
| _ENGINEERED_FEATURE_TYPES.append(data_model_cls) | ||
| AnyEngineeredFeature = Union[tuple(_ENGINEERED_FEATURE_TYPES)] | ||
|
|
||
| # Lazy import to avoid circular dependencies | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see above. if you had separate file for the implementation, would the import still be circular? |
||
| from bofire.data_models.domain.features import EngineeredFeatures | ||
|
|
||
| # Patch the Sequence[Union[...]] annotation on EngineeredFeatures.features | ||
| old = EngineeredFeatures.model_fields["features"].annotation | ||
| inner_args = typing.get_args(typing.get_args(old)[0]) | ||
| if data_model_cls not in inner_args: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this peace of code looks fragile. but currently i don't have a better idea. hmm... at least if this breaks at some point only the registration functinality is affected, not the existing modules. |
||
| new_inner = Union[tuple(list(inner_args) + [data_model_cls])] | ||
| new_ann = Sequence[new_inner] | ||
| EngineeredFeatures.__annotations__["features"] = new_ann | ||
| EngineeredFeatures.model_fields["features"].annotation = new_ann | ||
| EngineeredFeatures.model_rebuild(force=True) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we also need
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, correct. Should be fixed now. My Agent was overlooking this.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Haha. Nice try blaming the agent. :D |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| from functools import partial | ||
| from typing import Union | ||
| from typing import List, Optional, Type, Union | ||
|
|
||
| from bofire.data_models.priors.constraint import ( | ||
| GreaterThan, | ||
|
|
@@ -25,18 +25,199 @@ | |
| AbstractPrior = Prior | ||
| AbstractPriorConstraint = Union[PriorConstraint, Interval] | ||
|
|
||
| AnyPrior = Union[ | ||
| _PRIOR_TYPES: List[Type[Prior]] = [ | ||
| GammaPrior, | ||
| NormalPrior, | ||
| LKJPrior, | ||
| LogNormalPrior, | ||
| DimensionalityScaledLogNormalPrior, | ||
| ] | ||
|
|
||
| AnyPriorConstraint = Union[ | ||
| NonTransformedInterval, LogTransformedInterval, Positive, GreaterThan, LessThan | ||
| AnyPrior = Union[tuple(_PRIOR_TYPES)] | ||
|
|
||
| _PRIOR_CONSTRAINT_TYPES: List[Type] = [ | ||
| NonTransformedInterval, | ||
| LogTransformedInterval, | ||
| Positive, | ||
| GreaterThan, | ||
| LessThan, | ||
| ] | ||
|
|
||
| AnyPriorConstraint = Union[tuple(_PRIOR_CONSTRAINT_TYPES)] | ||
|
|
||
|
|
||
| def _patch_field(model_cls: type, field_name: str, new_union: type) -> None: | ||
|
jduerholt marked this conversation as resolved.
Outdated
|
||
| """Patch a Pydantic model field annotation, preserving Optional wrapping.""" | ||
| import typing | ||
|
|
||
| old = model_cls.model_fields[field_name].annotation | ||
| args = typing.get_args(old) | ||
| if type(None) in args: | ||
| new = Optional[new_union] | ||
| else: | ||
| new = new_union | ||
| model_cls.__annotations__[field_name] = new | ||
| model_cls.model_fields[field_name].annotation = new | ||
|
|
||
|
|
||
| def _rebuild_dependent_models() -> None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I prefer a separate file to keep the |
||
| """Rebuild all Pydantic models whose fields reference AnyPrior or AnyPriorConstraint.""" | ||
| # Lazy imports to avoid circular dependencies | ||
| from bofire.data_models.kernels.aggregation import ( | ||
| AdditiveKernel, | ||
| MultiplicativeKernel, | ||
| PolynomialFeatureInteractionKernel, | ||
| ScaleKernel, | ||
| ) | ||
| from bofire.data_models.kernels.categorical import ( | ||
| HammingDistanceKernel, | ||
| IndexKernel, | ||
| PositiveIndexKernel, | ||
| ) | ||
| from bofire.data_models.kernels.conditional import WedgeKernel | ||
| from bofire.data_models.kernels.continuous import ( | ||
| MaternKernel, | ||
| RBFKernel, | ||
| SphericalLinearKernel, | ||
| ) | ||
| from bofire.data_models.kernels.shape import WassersteinKernel | ||
| from bofire.data_models.surrogates.botorch_surrogates import BotorchSurrogates | ||
| from bofire.data_models.surrogates.linear import LinearSurrogate | ||
| from bofire.data_models.surrogates.mixed_single_task_gp import ( | ||
| MixedSingleTaskGPSurrogate, | ||
| ) | ||
| from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPSurrogate | ||
| from bofire.data_models.surrogates.polynomial import PolynomialSurrogate | ||
| from bofire.data_models.surrogates.robust_single_task_gp import ( | ||
| RobustSingleTaskGPSurrogate, | ||
| ) | ||
| from bofire.data_models.surrogates.shape import PiecewiseLinearGPSurrogate | ||
| from bofire.data_models.surrogates.single_task_gp import ( | ||
| SingleTaskGPHyperconfig, | ||
| SingleTaskGPSurrogate, | ||
| ) | ||
| from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate | ||
|
|
||
| # Patch AnyPrior fields | ||
| for model_cls, field_name in [ | ||
|
jduerholt marked this conversation as resolved.
Outdated
|
||
| (RBFKernel, "lengthscale_prior"), | ||
| (MaternKernel, "lengthscale_prior"), | ||
| (SphericalLinearKernel, "lengthscale_prior"), | ||
| (HammingDistanceKernel, "lengthscale_prior"), | ||
| (IndexKernel, "prior"), | ||
| (PositiveIndexKernel, "prior"), | ||
| (PositiveIndexKernel, "task_prior"), | ||
| (PositiveIndexKernel, "diag_prior"), | ||
| (WedgeKernel, "lengthscale_prior"), | ||
| (WedgeKernel, "angle_prior"), | ||
| (WedgeKernel, "radius_prior"), | ||
| (WassersteinKernel, "lengthscale_prior"), | ||
| (SingleTaskGPSurrogate, "noise_prior"), | ||
| (MultiTaskGPSurrogate, "noise_prior"), | ||
| (MixedSingleTaskGPSurrogate, "noise_prior"), | ||
| (TanimotoGPSurrogate, "noise_prior"), | ||
| (PiecewiseLinearGPSurrogate, "outputscale_prior"), | ||
| (PiecewiseLinearGPSurrogate, "noise_prior"), | ||
| (PolynomialSurrogate, "noise_prior"), | ||
| (LinearSurrogate, "noise_prior"), | ||
| (RobustSingleTaskGPSurrogate, "noise_prior"), | ||
| ]: | ||
| _patch_field(model_cls, field_name, AnyPrior) | ||
|
|
||
| # Patch AnyPriorConstraint fields | ||
| for model_cls, field_name in [ | ||
| (RBFKernel, "lengthscale_constraint"), | ||
| (MaternKernel, "lengthscale_constraint"), | ||
| (SphericalLinearKernel, "lengthscale_constraint"), | ||
| (HammingDistanceKernel, "lengthscale_constraint"), | ||
| (IndexKernel, "var_constraint"), | ||
| (PositiveIndexKernel, "var_constraint"), | ||
| (WedgeKernel, "lengthscale_constraint"), | ||
| (ScaleKernel, "outputscale_constraint"), | ||
| (SingleTaskGPHyperconfig, "lengthscale_constraint"), | ||
| (SingleTaskGPHyperconfig, "outputscale_constraint"), | ||
| ]: | ||
| _patch_field(model_cls, field_name, AnyPriorConstraint) | ||
|
|
||
| # Rebuild in dependency order: | ||
| # 1. Leaf kernel models | ||
| for cls in [ | ||
| RBFKernel, | ||
| MaternKernel, | ||
| SphericalLinearKernel, | ||
| HammingDistanceKernel, | ||
| IndexKernel, | ||
| PositiveIndexKernel, | ||
| WassersteinKernel, | ||
| ]: | ||
| cls.model_rebuild(force=True) | ||
|
|
||
| # 2. Conditional kernels (reference leaf kernels) | ||
| WedgeKernel.model_rebuild(force=True) | ||
|
|
||
| # 3. Aggregation kernels (embed leaf kernel types) | ||
| for cls in [ | ||
| AdditiveKernel, | ||
| MultiplicativeKernel, | ||
| ScaleKernel, | ||
| PolynomialFeatureInteractionKernel, | ||
| ]: | ||
| cls.model_rebuild(force=True) | ||
|
|
||
| # 4. Surrogate models | ||
| SingleTaskGPHyperconfig.model_rebuild(force=True) | ||
| for cls in [ | ||
| SingleTaskGPSurrogate, | ||
| MultiTaskGPSurrogate, | ||
| MixedSingleTaskGPSurrogate, | ||
| TanimotoGPSurrogate, | ||
| PiecewiseLinearGPSurrogate, | ||
| PolynomialSurrogate, | ||
| LinearSurrogate, | ||
| RobustSingleTaskGPSurrogate, | ||
| ]: | ||
| cls.model_rebuild(force=True) | ||
|
|
||
| # 5. BotorchSurrogates | ||
| BotorchSurrogates.model_rebuild(force=True) | ||
|
|
||
|
|
||
| def register_prior(data_model_cls: Type[Prior]) -> None: | ||
| """Register a custom prior type so it is accepted in AnyPrior fields. | ||
|
|
||
| This appends the type to the internal registry, rebuilds the | ||
| ``AnyPrior`` union, and calls ``model_rebuild`` on all dependent | ||
| Pydantic models (kernels, surrogates) so that the new type is accepted. | ||
|
|
||
| Args: | ||
| data_model_cls: A concrete subclass of ``Prior``. | ||
| """ | ||
| global AnyPrior | ||
| if data_model_cls in _PRIOR_TYPES: | ||
| return | ||
| _PRIOR_TYPES.append(data_model_cls) | ||
| AnyPrior = Union[tuple(_PRIOR_TYPES)] | ||
| _rebuild_dependent_models() | ||
|
|
||
|
|
||
| def register_prior_constraint(data_model_cls: Type) -> None: | ||
| """Register a custom prior constraint type so it is accepted in AnyPriorConstraint fields. | ||
|
|
||
| This appends the type to the internal registry, rebuilds the | ||
| ``AnyPriorConstraint`` union, and calls ``model_rebuild`` on all | ||
| dependent Pydantic models. | ||
|
|
||
| Args: | ||
| data_model_cls: A concrete subclass of ``PriorConstraint`` or ``Interval``. | ||
| """ | ||
| global AnyPriorConstraint | ||
| if data_model_cls in _PRIOR_CONSTRAINT_TYPES: | ||
| return | ||
| _PRIOR_CONSTRAINT_TYPES.append(data_model_cls) | ||
| AnyPriorConstraint = Union[tuple(_PRIOR_CONSTRAINT_TYPES)] | ||
| _rebuild_dependent_models() | ||
|
|
||
|
|
||
| # these are priors that are generally applicable | ||
| # and do not depend on problem specific extra parameters | ||
| AnyGeneralPrior = Union[GammaPrior, NormalPrior, LKJPrior, LogNormalPrior] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use
listinstead ofList. Don't importList.