Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions bofire/data_models/domain/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
AnyConstraint,
ConstraintNotFulfilledError,
InterpointConstraint,
LinearEqualityConstraint,
LinearInequalityConstraint,
NChooseKConstraint,
NonlinearConstraint,
ProductConstraint,
)
from bofire.data_models.domain.constraints import Constraints
from bofire.data_models.domain.features import Inputs, Outputs
Expand Down Expand Up @@ -138,6 +142,60 @@ def validate_constraints(self):
c.validate_inputs(self.inputs)
return self

def is_nchoosek_pruning_applicable(self) -> bool:
"""Check if greedy pruning can be used for NChooseK constraints.

Based on the BONSAI algorithm (https://arxiv.org/abs/2602.07144).
Pruning is applicable when:
1. There is at least one NChooseK constraint in the domain.
2. No feature involved in any NChooseK constraint appears in any
nonlinear (Product, Nonlinear) or interpoint constraint. Overlap
with linear equality/inequality constraints is allowed and handled
via QP projection + local acquisition function optimization.

Returns:
bool: True if pruning can be safely applied.
"""
nchoosek_constraints = self.constraints.get(NChooseKConstraint)
if len(nchoosek_constraints) == 0:
return False

# Collect features from constraints that cannot be handled by QP
blocking_constraint_features: set[str] = set()
for c in self.constraints.get(
includes=[ProductConstraint, NonlinearConstraint, InterpointConstraint]
):
blocking_constraint_features.update(c.features)

# Check that no NChooseK feature overlaps with blocking constraints
for c in nchoosek_constraints:
assert isinstance(c, NChooseKConstraint)
if blocking_constraint_features.intersection(c.features):
return False

return True

def has_nchoosek_linear_overlap(self) -> bool:
"""Check if any NChooseK feature also appears in a linear constraint.

Used to determine whether QP projection is needed during pruning.

Returns:
bool: True if there is overlap between NChooseK and linear constraints.
"""
nchoosek_features: set[str] = set()
for c in self.constraints.get(NChooseKConstraint):
assert isinstance(c, NChooseKConstraint)
nchoosek_features.update(c.features)

linear_features: set[str] = set()
for c in self.constraints.get(
includes=[LinearEqualityConstraint, LinearInequalityConstraint]
):
linear_features.update(c.features)

return bool(nchoosek_features.intersection(linear_features))

# TODO: tidy this up
def get_nchoosek_combinations(self, exhaustive: bool = False):
"""Get all possible NChooseK combinations
Expand Down
23 changes: 17 additions & 6 deletions bofire/strategies/predictives/acqf_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,13 +469,14 @@ def _get_optimizer_options(self, domain: Domain) -> Dict[str, int]:

"""
assert self.batch_limit is not None
pruning_applicable = domain.is_nchoosek_pruning_applicable()
constraint_types = [ProductConstraint]
if not pruning_applicable:
constraint_types.append(NChooseKConstraint)
return {
"batch_limit": (
self.batch_limit
if len(
domain.constraints.get([NChooseKConstraint, ProductConstraint]),
)
== 0
if len(domain.constraints.get(constraint_types)) == 0
else 1
),
"maxiter": self.maxiter,
Expand All @@ -489,9 +490,11 @@ def _determine_optimizer(self, domain: Domain, n_acqfs) -> OptimizerEnum:
)
if n_categorical_combinations == 1:
return OptimizerEnum.OPTIMIZE_ACQF
exclude_nchoosek = domain.is_nchoosek_pruning_applicable()
if (
n_categorical_combinations <= ALTERNATING_OPTIMIZER_THRESHOLD
or len(get_nonlinear_constraints(domain)) > 0
or len(get_nonlinear_constraints(domain, exclude_nchoosek=exclude_nchoosek))
> 0
):
return OptimizerEnum.OPTIMIZE_ACQF_MIXED
return OptimizerEnum.OPTIMIZE_ACQF_MIXED_ALTERNATING
Expand All @@ -517,7 +520,15 @@ def _get_arguments_for_optimizer(
equality_constraints = get_linear_constraints(
domain, constraint=LinearEqualityConstraint
)
if len(nonlinear_constraints := get_nonlinear_constraints(domain)) == 0:
exclude_nchoosek = domain.is_nchoosek_pruning_applicable()
if (
len(
nonlinear_constraints := get_nonlinear_constraints(
domain, exclude_nchoosek=exclude_nchoosek
)
)
== 0
):
ic_generator = None
ic_gen_kwargs = {}
else:
Expand Down
Loading
Loading