Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions bofire/strategies/doe/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib.util
import re
import sys
import warnings
from copy import copy
from itertools import combinations
from typing import List, Optional, Tuple, Union, cast
Expand All @@ -22,7 +23,11 @@
NonlinearInequalityConstraint,
)
from bofire.data_models.domain.api import Domain, Inputs
from bofire.data_models.features.api import CategoricalInput, NumericalInput
from bofire.data_models.features.api import (
CategoricalInput,
DiscreteInput,
NumericalInput,
)
from bofire.data_models.features.continuous import ContinuousInput
from bofire.data_models.strategies.api import RandomStrategy as RandomStrategyDataModel
from bofire.strategies.doe.doe_problem import (
Expand Down Expand Up @@ -78,9 +83,20 @@ def formula_str_to_fully_continuous(
pattern = r"\b" + re.escape(cat_input.key) + r"\b"
formula = re.sub(pattern, "(" + f"{one_hot_terms}" + ")", formula)

return str(
Formula(formula)
formula = Formula(
formula
) # formula casting for expansion of terms like (a+b)*(c+d)
for _input in inputs.get([DiscreteInput]):
for k in range(
2, 99
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use something like `len(_input.values) + 1` instead of 99, which is arbitrary?

): # arbitrary upper bound on number of levels of discrete input
if (len(_input.values) <= k) and (_input.key + f" ** {k}" in formula.root):
warnings.warn(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So now it gives a warning but continues. I guess that is fine. Alternatively, we could raise a ValueError, I guess.

f"Discrete input {_input.key} with {len(_input.values)} levels cannot represent a term of order {k} or higher.",
UserWarning,
)
break
return str(formula)


def get_formula_from_string(
Expand All @@ -92,8 +108,7 @@ def get_formula_from_string(

Args:
model_type (str or Formula): A formula containing all model terms.
domain (Domain): A domain that nests necessary information on
how to translate a problem to a formula. Contains a problem.
inputs (Inputs, optional): The inputs to be used in the formula. Defaults to None. If the model_type is a string describing a model type (e.g. "linear"), inputs must be provided to determine the formula. If the model_type is already a formula, inputs are not necessary and ignored if provided.
rhs_only (bool): The function returns only the right hand side of the formula if set to True.

Returns:
Expand Down Expand Up @@ -239,8 +254,11 @@ def quadratic_terms(
A string describing the model that was given as string or keyword.

"""
_inputs = list(inputs.get([ContinuousInput])) + [
input for input in inputs.get([DiscreteInput]) if len(input.values) > 2
]

formula = "".join(["{" + input.key + "**2} + " for input in inputs])
formula = "".join(["{" + input.key + "**2} + " for input in _inputs])
return formula


Expand Down
98 changes: 90 additions & 8 deletions tests/bofire/strategies/doe/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,11 @@ def test_n_zero_eigvals_constrained():
# thus there was one more degree of freedom if quadratic terms were added.
# Here, discretes are sampled within their respective domain, thus discrete2==discrete2**2 always
# thus we have one degree of freedom less.
# discrete2**2 (discrete2 has only two levels) is no longer in the formula, so the counts for linear-and-quadratic and fully-quadratic are updated accordingly.
assert n_zero_eigvals(domain, "linear") == 1
assert n_zero_eigvals(domain, "linear-and-quadratic") == 2
assert n_zero_eigvals(domain, "linear-and-quadratic") == 1
assert n_zero_eigvals(domain, "linear-and-interactions") == 3
assert n_zero_eigvals(domain, "fully-quadratic") == 7
assert n_zero_eigvals(domain, "fully-quadratic") == 6

# TODO: NChooseK?

Expand Down Expand Up @@ -234,7 +235,7 @@ def test_number_of_model_terms():
formula = get_formula_from_string(
inputs=domain.inputs, model_type="linear-and-quadratic"
)
assert len(formula) == 11
assert len(formula) == 10 # discrete2 has only 2 levels, no quadratic term

formula = get_formula_from_string(
inputs=domain.inputs,
Expand All @@ -245,7 +246,7 @@ def test_number_of_model_terms():
formula = get_formula_from_string(
inputs=domain.inputs, model_type="fully-quadratic"
)
assert len(formula) == 21
assert len(formula) == 20 # discrete2 has only 2 levels, no quadratic term


def test_constraints_as_scipy_constraints():
Expand Down Expand Up @@ -750,7 +751,7 @@ def test_convert_formula_to_string():
def test_formula_discrete_handled_like_continuous():
domain_w_discrete = Domain.from_lists(
inputs=[ContinuousInput(key=f"x{i}", bounds=[0, 1]) for i in range(3)]
+ [DiscreteInput(key=f"x{i}", values=[0, 1]) for i in range(3, 5)],
+ [DiscreteInput(key=f"x{i}", values=[0, 1, 2]) for i in range(3, 5)],
outputs=[ContinuousOutput(key="y")],
)
domain_wo_discrete = Domain.from_lists(
Expand All @@ -774,6 +775,48 @@ def test_formula_discrete_handled_like_continuous():
assert formula_w_discrete == formula_wo_discrete


def test_formula_discrete_too_few_levels():
    """Check formulas for domains whose discrete inputs have only two values.

    Two domains with the keys x0..x4 are compared: one where x3 and x4 are
    discrete with values [0, 1], and one where all five inputs are continuous.
    For models without squared terms the formulas must be equal. For the
    quadratic models the discrete variant must equal the continuous formula
    string with " + x3 ** 2 + x4 ** 2" removed.
    """
    discrete_domain = Domain.from_lists(
        inputs=[
            *(ContinuousInput(key=f"x{i}", bounds=[0, 1]) for i in range(3)),
            *(DiscreteInput(key=f"x{i}", values=[0, 1]) for i in range(3, 5)),
        ],
        outputs=[ContinuousOutput(key="y")],
    )
    continuous_domain = Domain.from_lists(
        inputs=[ContinuousInput(key=f"x{i}", bounds=[0, 1]) for i in range(5)],
        outputs=[ContinuousOutput(key="y")],
    )

    # Model types without squared terms: both domains must give equal formulas.
    for model_type in ("linear", "linear-and-interactions"):
        with_discrete = get_formula_from_string(
            inputs=discrete_domain.inputs, model_type=model_type
        )
        all_continuous = get_formula_from_string(
            inputs=continuous_domain.inputs, model_type=model_type
        )
        assert with_discrete == all_continuous

    # Model types with squared terms: compare as strings after stripping the
    # squares of x3 and x4 from the all-continuous formula.
    dropped_squares = " + x3 ** 2 + x4 ** 2"
    for model_type in ("linear-and-quadratic", "fully-quadratic"):
        with_discrete = str(
            get_formula_from_string(
                inputs=discrete_domain.inputs, model_type=model_type
            )
        )
        all_continuous = str(
            get_formula_from_string(
                inputs=continuous_domain.inputs, model_type=model_type
            )
        )
        assert with_discrete == all_continuous.replace(dropped_squares, "")


def test_formula_str_to_fully_continuous():
# Create a small example problem with categorical, continuous, and discrete variables
inputs = Inputs(
Expand Down Expand Up @@ -863,6 +906,47 @@ def test_formula_str_to_fully_continuous():
), f"Expected: {expected_formula}\nGot: {continuous_formula}"


def test_formula_str_does_not_match_discrete_levels_emmits_warning():
    """Expect a UserWarning for a squared term on a two-valued discrete input.

    The inputs mix categorical, continuous and discrete features; the discrete
    feature "pressure" has the values [0, 1] while the formula contains
    "{ pressure ** 2 }".
    """
    features = [
        CategoricalInput(key="color", categories=["red", "blue", "green"]),
        ContinuousInput(key="color_intensity", bounds=(0.0, 1.0)),
        CategoricalInput(key="material", categories=["plastic", "metal"]),
        ContinuousInput(key="temperature", bounds=(20.0, 100.0)),
        DiscreteInput(key="pressure", values=[0, 1]),
    ]
    inputs = Inputs(features=features)

    # Custom formula with a color:material interaction and a squared
    # pressure term.
    custom_formula = "color + material + temperature + { pressure ** 2 } + color:material + color_intensity"
    expected_message = "Discrete input pressure with 2 levels cannot represent a term of order 2 or higher."

    with pytest.warns(UserWarning, match=expected_message):
        formula_str_to_fully_continuous(formula=custom_formula, inputs=inputs)


def test_formula_str_to_fully_continuous_only_categoricals():
# Create a small example problem with only categorical variables
inputs = Inputs(
Expand Down Expand Up @@ -946,6 +1030,4 @@ def only_continuous_inputs_formula_str_to_fully_continuous():


if __name__ == "__main__":
test_formula_str_to_fully_continuous()
test_formula_str_to_fully_continuous_only_categoricals()
only_continuous_inputs_formula_str_to_fully_continuous()
test_formula_str_does_not_match_discrete_levels_emmits_warning()
Loading