diff --git a/.gitignore b/.gitignore index b72282ba3c..f49473c966 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,10 @@ build # VSCode .vscode +# Virtual environments +.venv +.env + # Testing .tox .coverage @@ -34,5 +38,6 @@ htmlcov # Folders that are temporarily created when building the documentation docs/_autosummary +docs/_build docs/examples docs/sdk \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index e24cc0ccb9..599e83dee8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added - API diagram in user guide +- `Metadata` and `MeasurableMetadata` classes providing optional information for BayBE + objects +- `Objective` now has a `metadata` attribute as well as a `description` property +- `Target` and `Parameter` now have a `metadata` attribute as well as `description` and + `unit` properties ### Fixed - `Campaign` no longer allows overlapping names between parameters and targets diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index c65d55f27d..6e62d74dea 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -31,4 +31,6 @@ - Fabian Liebig (Merck KGaA, Darmstadt, Germany):\ Benchmarking structure and persistence capabilities for benchmarking results - Alexander Wieczorek (Swiss Federal Institute for Materials Science and Technology, Dübendorf, Switzerland):\ - SHAP explainers for insights \ No newline at end of file + SHAP explainers for insights +- Tobias Plötz (Merck KGaA, Darmstadt, Germany):\ + Metadata system \ No newline at end of file diff --git a/baybe/objectives/base.py b/baybe/objectives/base.py index dc28932076..6412521f18 100644 --- a/baybe/objectives/base.py +++ b/baybe/objectives/base.py @@ -6,7 +6,7 @@ import cattrs import pandas as pd -from attrs import define +from attrs import define, field from baybe.serialization.core import ( converter, @@ -15,6 +15,7 @@ ) from baybe.serialization.mixin import SerialMixin 
from baybe.targets.base import Target +from baybe.utils.metadata import Metadata, to_metadata # TODO: Reactive slots in all classes once cached_property is supported: # https://github.com/python-attrs/attrs/issues/164 @@ -27,6 +28,18 @@ class Objective(ABC, SerialMixin): is_multi_output: ClassVar[bool] """Class variable indicating if the objective produces multiple outputs.""" + metadata: Metadata = field( + factory=Metadata, + converter=lambda x: to_metadata(x, Metadata), + kw_only=True, + ) + """Optional metadata containing description and other information.""" + + @property + def description(self) -> str | None: + """The description of the objective.""" + return self.metadata.description + @property @abstractmethod def targets(self) -> tuple[Target, ...]: diff --git a/baybe/parameters/__init__.py b/baybe/parameters/__init__.py index 1424f93364..93e62b6ee9 100644 --- a/baybe/parameters/__init__.py +++ b/baybe/parameters/__init__.py @@ -12,12 +12,14 @@ NumericalDiscreteParameter, ) from baybe.parameters.substance import SubstanceParameter +from baybe.utils.metadata import MeasurableMetadata __all__ = [ "CategoricalEncoding", "CategoricalParameter", "CustomDiscreteParameter", "CustomEncoding", + "MeasurableMetadata", "NumericalContinuousParameter", "NumericalDiscreteParameter", "SubstanceEncoding", diff --git a/baybe/parameters/base.py b/baybe/parameters/base.py index 72497104ba..68a7113232 100644 --- a/baybe/parameters/base.py +++ b/baybe/parameters/base.py @@ -22,6 +22,7 @@ unstructure_base, ) from baybe.utils.basic import to_tuple +from baybe.utils.metadata import MeasurableMetadata, to_metadata if TYPE_CHECKING: from baybe.searchspace.continuous import SubspaceContinuous @@ -48,6 +49,13 @@ class Parameter(ABC, SerialMixin): name: str = field(validator=(instance_of(str), min_len(1))) """The name of the parameter""" + metadata: MeasurableMetadata = field( + factory=MeasurableMetadata, + converter=lambda x: to_metadata(x, MeasurableMetadata), + kw_only=True, + ) + 
"""Optional metadata containing description, unit, and other information.""" + @abstractmethod def is_in_range(self, item: Any) -> bool: """Return whether an item is within the parameter range. @@ -88,6 +96,16 @@ def to_searchspace(self) -> SearchSpace: def summary(self) -> dict: """Return a custom summarization of the parameter.""" + @property + def description(self) -> str | None: + """The description of the parameter.""" + return self.metadata.description + + @property + def unit(self) -> str | None: + """The unit of measurement for the parameter.""" + return self.metadata.unit + @define(frozen=True, slots=False) class DiscreteParameter(Parameter, ABC): diff --git a/baybe/parameters/enum.py b/baybe/parameters/enum.py index 5f0bbe1c19..fc309f2804 100644 --- a/baybe/parameters/enum.py +++ b/baybe/parameters/enum.py @@ -27,7 +27,7 @@ class CustomEncoding(ParameterEncoding): class SubstanceEncoding(ParameterEncoding): """Available encodings for substance parameters from `scikit-fingerprints`_ package. - .. _scikit-fingerprints: https://scikit-fingerprints.github.io/scikit-fingerprints/ + .. 
_scikit-fingerprints: https://scikit-fingerprints.readthedocs.io/ """ ATOMPAIR = "ATOMPAIR" diff --git a/baybe/targets/__init__.py b/baybe/targets/__init__.py index 3e4712f8b3..b2487c8e14 100644 --- a/baybe/targets/__init__.py +++ b/baybe/targets/__init__.py @@ -3,9 +3,11 @@ from baybe.targets.binary import BinaryTarget from baybe.targets.enum import TargetMode, TargetTransformation from baybe.targets.numerical import NumericalTarget +from baybe.utils.metadata import MeasurableMetadata __all__ = [ "BinaryTarget", + "MeasurableMetadata", "NumericalTarget", "TargetMode", "TargetTransformation", diff --git a/baybe/targets/base.py b/baybe/targets/base.py index 272b7cd771..e3e3a6f84c 100644 --- a/baybe/targets/base.py +++ b/baybe/targets/base.py @@ -16,6 +16,7 @@ get_base_structure_hook, unstructure_base, ) +from baybe.utils.metadata import MeasurableMetadata, to_metadata if TYPE_CHECKING: from baybe.objectives import SingleTargetObjective @@ -31,6 +32,23 @@ class Target(ABC, SerialMixin): name: str = field() """The name of the target.""" + metadata: MeasurableMetadata = field( + factory=MeasurableMetadata, + converter=lambda x: to_metadata(x, MeasurableMetadata), + kw_only=True, + ) + """Optional metadata containing description, unit, and other information.""" + + @property + def description(self) -> str | None: + """The description of the target.""" + return self.metadata.description + + @property + def unit(self) -> str | None: + """The unit of measurement for the target.""" + return self.metadata.unit + def to_objective(self) -> SingleTargetObjective: """Create a single-task objective from the target.""" from baybe.objectives.single import SingleTargetObjective diff --git a/baybe/utils/metadata.py b/baybe/utils/metadata.py new file mode 100644 index 0000000000..31976ccba0 --- /dev/null +++ b/baybe/utils/metadata.py @@ -0,0 +1,118 @@ +"""Generic metadata system for BayBE objects.""" + +from __future__ import annotations + +from typing import Any, TypeVar + +import 
cattrs +from attrs import AttrsInstance, define, field, fields +from attrs.validators import deep_mapping, instance_of +from attrs.validators import optional as optional_v +from typing_extensions import override + +from baybe.serialization import SerialMixin, converter +from baybe.utils.basic import classproperty + +_TMetaData = TypeVar("_TMetaData", bound="Metadata") + + +@define(frozen=True) +class Metadata(SerialMixin): + """Metadata class providing basic information for BayBE objects.""" + + description: str | None = field( + default=None, validator=optional_v(instance_of(str)) + ) + """A description of the object.""" + + misc: dict[str, Any] = field( + factory=dict, + validator=deep_mapping( + mapping_validator=instance_of(dict), + key_validator=instance_of(str), + # FIXME: https://github.com/python-attrs/attrs/issues/1246 + value_validator=lambda *x: None, + ), + kw_only=True, + ) + """Additional user-defined metadata.""" + + @misc.validator + def _validate_misc(self, _, value: dict[str, Any]) -> None: + if inv := set(value).intersection(self._explicit_fields): + raise ValueError( + f"Miscellaneous metadata cannot contain the following fields: {inv}. " + f"Use the corresponding attributes instead." 
+ ) + + @classproperty + def _explicit_fields(cls: type[AttrsInstance]) -> set[str]: + """The explicit metadata fields.""" # noqa: D401 + flds = fields(cls) + return {fld.name for fld in flds if fld.name != flds.misc.name} + + @property + def is_empty(self) -> bool: + """Check if metadata contains any meaningful information.""" + return self.description is None and not self.misc + + +@define(frozen=True) +class MeasurableMetadata(Metadata): + """Class providing metadata for BayBE parameters and targets.""" + + unit: str | None = field(default=None, validator=optional_v(instance_of(str))) + """The unit of measurement for the parameter or target.""" + + @override + @property + def is_empty(self) -> bool: + """Check if metadata contains any meaningful information.""" + return super().is_empty and self.unit is None + + +def to_metadata( + value: dict[str, Any] | _TMetaData, cls: type[_TMetaData], / +) -> _TMetaData: + """Convert a dictionary to :class:`Metadata` (with :class:`Metadata` passthrough). + + Args: + value: The metadata input. + cls: The specific :class:`Metadata` subclass to convert to. + + Returns: + The created metadata instance of the requested :class:`Metadata` subclass. + + Raises: + TypeError: If the input is not a dictionary or of the specified + :class:`Metadata` type. + """ + if isinstance(value, cls): + return value + + if not isinstance(value, dict): + raise TypeError( + f"The input must be a dictionary or a '{cls.__name__}' instance. 
" + f"Got: {type(value)}" + ) + + # Separate known fields from unknown ones + return converter.structure(value, cls) + + +@converter.register_structure_hook +def _separate_metadata_fields(dct: dict[str, Any], cls: type[Metadata]) -> Metadata: + """Separate known fields from miscellaneous metadata.""" + dct = dct.copy() + explicit = {fld: dct.pop(fld, None) for fld in cls._explicit_fields} + return cls(**explicit, misc=dct) + + +@converter.register_unstructure_hook +def _flatten_misc_metadata(metadata: Metadata) -> dict[str, Any]: + """Flatten the metadata for serialization.""" + cls = type(metadata) + fn = cattrs.gen.make_dict_unstructure_fn(cls, converter) + dct = fn(metadata) + dct = dct | dct.pop(fields(Metadata).misc.name) + return dct diff --git a/docs/conf.py b/docs/conf.py index e3c0bdf7e5..9473628ae4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -263,7 +263,7 @@ "python": ("https://docs.python.org/3", None), "pandas": ("https://pandas.pydata.org/docs/", None), "polars": ("https://docs.pola.rs/api/python/stable/", None), - "skfp": ("https://scikit-fingerprints.github.io/scikit-fingerprints/", None), + "skfp": ("https://scikit-fingerprints.readthedocs.io/latest/", None), "sklearn": ("https://scikit-learn.org/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), "torch": ("https://pytorch.org/docs/main/", None), diff --git a/docs/userguide/parameters.md b/docs/userguide/parameters.md index 12d3b2ef88..ddbebc085b 100644 --- a/docs/userguide/parameters.md +++ b/docs/userguide/parameters.md @@ -5,7 +5,7 @@ [`TaskParameter`]: baybe.parameters.categorical.TaskParameter [`CustomDiscreteParameter`]: baybe.parameters.custom.CustomDiscreteParameter [`SubstanceEncoding`]: baybe.parameters.enum.SubstanceEncoding -[scikit-fingerprints]: https://scikit-fingerprints.github.io/scikit-fingerprints/ +[scikit-fingerprints]: https://scikit-fingerprints.readthedocs.io # Parameters diff --git a/tests/hypothesis_strategies/metadata.py 
b/tests/hypothesis_strategies/metadata.py new file mode 100644 index 0000000000..e3f96c68d4 --- /dev/null +++ b/tests/hypothesis_strategies/metadata.py @@ -0,0 +1,40 @@ +"""Hypothesis strategies for metadata.""" + +import hypothesis.strategies as st +from hypothesis import assume + +from baybe.utils.metadata import MeasurableMetadata, Metadata + +_descriptions = st.one_of(st.none(), st.text(min_size=0)) +"""A strategy generating metadata descriptions.""" + + +@st.composite +def _miscs(draw: st.DrawFn, cls: type[Metadata]): + """Generates miscellaneous metadata for various metadata classes.""" + misc = draw( + st.dictionaries( + st.text(min_size=0), + st.one_of(st.text(), st.integers(), st.floats(allow_nan=False)), + max_size=5, + ) + ) + assume(not cls._explicit_fields.intersection(misc)) + return misc + + +@st.composite +def metadata(draw: st.DrawFn): + """Generate :class:`baybe.utils.metadata.Metadata`.""" + description = draw(_descriptions) + misc = draw(_miscs(Metadata)) + return Metadata(description=description, misc=misc) + + +@st.composite +def measurable_metadata(draw: st.DrawFn): + """Generate :class:`baybe.utils.metadata.MeasurableMetadata`.""" + description = draw(_descriptions) + unit = draw(st.one_of(st.none(), st.text(min_size=0))) + misc = draw(_miscs(MeasurableMetadata)) + return MeasurableMetadata(description=description, unit=unit, misc=misc) diff --git a/tests/hypothesis_strategies/objectives.py b/tests/hypothesis_strategies/objectives.py index ea0d9e613c..5c46b6ce04 100644 --- a/tests/hypothesis_strategies/objectives.py +++ b/tests/hypothesis_strategies/objectives.py @@ -10,6 +10,7 @@ from baybe.targets.numerical import _VALID_TRANSFORMATIONS from ..hypothesis_strategies.basic import finite_floats +from ..hypothesis_strategies.metadata import measurable_metadata, metadata from ..hypothesis_strategies.targets import numerical_targets from ..hypothesis_strategies.utils import intervals as st_intervals @@ -20,9 +21,12 @@ ) -def 
single_target_objectives(): +@st.composite +def single_target_objectives(draw: st.DrawFn): """Generate :class:`baybe.objectives.single.SingleTargetObjective`.""" - return st.builds(SingleTargetObjective, target=numerical_targets()) + target = draw(numerical_targets()) + objective_metadata = draw(metadata()) + return SingleTargetObjective(target=target, metadata=objective_metadata) @st.composite @@ -37,7 +41,10 @@ def desirability_objectives(draw: st.DrawFn): ) ) scalarizer = draw(st.sampled_from(Scalarizer)) - return DesirabilityObjective(targets, weights, scalarizer) + objective_metadata = draw(metadata()) + return DesirabilityObjective( + targets, weights, scalarizer, metadata=objective_metadata + ) @st.composite @@ -55,8 +62,14 @@ def _pareto_targets(draw: st.DrawFn): transformation = None bounds = None + target_metadata = draw(measurable_metadata()) + return NumericalTarget( - name=name, mode=mode, bounds=bounds, transformation=transformation + name=name, + mode=mode, + bounds=bounds, + transformation=transformation, + metadata=target_metadata, ) @@ -68,4 +81,6 @@ def _pareto_targets(draw: st.DrawFn): @st.composite def pareto_objectives(draw: st.DrawFn): """Generate :class:`baybe.objectives.pareto.ParetoObjective`.""" - return ParetoObjective(draw(_pareto_target_lists)) + targets = draw(_pareto_target_lists) + objective_metadata = draw(metadata()) + return ParetoObjective(targets, metadata=objective_metadata) diff --git a/tests/hypothesis_strategies/parameters.py b/tests/hypothesis_strategies/parameters.py index 3a08531b3e..2d200e32bd 100644 --- a/tests/hypothesis_strategies/parameters.py +++ b/tests/hypothesis_strategies/parameters.py @@ -20,6 +20,7 @@ from baybe.utils.numerical import DTypeFloatNumpy from ..hypothesis_strategies.basic import finite_floats +from .metadata import measurable_metadata from .utils import intervals decorrelations = st.one_of( @@ -114,7 +115,10 @@ def numerical_discrete_parameters( exclude_max=True, ) ) - return 
NumericalDiscreteParameter(name=name, values=values, tolerance=tolerance) + param_metadata = draw(measurable_metadata()) + return NumericalDiscreteParameter( + name=name, values=values, tolerance=tolerance, metadata=param_metadata + ) @st.composite @@ -122,7 +126,10 @@ def numerical_continuous_parameters(draw: st.DrawFn): """Generate :class:`baybe.parameters.numerical.NumericalContinuousParameter`.""" name = draw(parameter_names) bounds = draw(intervals(exclude_half_bounded=True, exclude_fully_unbounded=True)) - return NumericalContinuousParameter(name=name, bounds=bounds) + param_metadata = draw(measurable_metadata()) + return NumericalContinuousParameter( + name=name, bounds=bounds, metadata=param_metadata + ) @st.composite @@ -132,9 +139,13 @@ def categorical_parameters(draw: st.DrawFn): values = draw(categories) encoding = draw(st.sampled_from(CategoricalEncoding)) active_values = draw(_active_values(values)) - + param_metadata = draw(measurable_metadata()) return CategoricalParameter( - name=name, values=values, encoding=encoding, active_values=active_values + name=name, + values=values, + encoding=encoding, + active_values=active_values, + metadata=param_metadata, ) @@ -144,7 +155,10 @@ def task_parameters(draw: st.DrawFn): name = draw(parameter_names) values = draw(categories) active_values = draw(_active_values(values)) - return TaskParameter(name=name, values=values, active_values=active_values) + param_metadata = draw(measurable_metadata()) + return TaskParameter( + name=name, values=values, active_values=active_values, metadata=param_metadata + ) @st.composite @@ -161,12 +175,15 @@ def substance_parameters(draw: st.DrawFn): encodings.remove(SubstanceEncoding.RDKIT) encoding = draw(st.sampled_from(encodings)) + param_metadata = draw(measurable_metadata()) + return SubstanceParameter( name=name, data=data, decorrelate=decorrelate, encoding=encoding, active_values=active_values, + metadata=param_metadata, ) @@ -177,9 +194,13 @@ def custom_parameters(draw: 
st.DrawFn): data = draw(custom_descriptors()) decorrelate = draw(decorrelations) active_values = draw(_active_values(data.index.values)) - + param_metadata = draw(measurable_metadata()) return CustomDiscreteParameter( - name=name, data=data, decorrelate=decorrelate, active_values=active_values + name=name, + data=data, + decorrelate=decorrelate, + active_values=active_values, + metadata=param_metadata, ) diff --git a/tests/hypothesis_strategies/targets.py b/tests/hypothesis_strategies/targets.py index 68536d5abc..11699f94c5 100644 --- a/tests/hypothesis_strategies/targets.py +++ b/tests/hypothesis_strategies/targets.py @@ -7,6 +7,7 @@ from baybe.targets.numerical import _VALID_TRANSFORMATIONS, NumericalTarget from baybe.utils.interval import Interval +from .metadata import measurable_metadata from .utils import intervals as st_intervals target_name = st.text(min_size=1) @@ -36,9 +37,13 @@ def numerical_targets( or transformation is not None, ) bounds = draw(bounds_strategy) - + target_metadata = draw(measurable_metadata()) return NumericalTarget( - name=name, mode=mode, bounds=bounds, transformation=transformation + name=name, + mode=mode, + bounds=bounds, + transformation=transformation, + metadata=target_metadata, ) @@ -53,7 +58,13 @@ def binary_targets(draw: st.DrawFn): """A strategy that generates binary targets.""" name = draw(target_name) choices = draw(st.lists(choice_values, min_size=2, max_size=2, unique=True)) - return BinaryTarget(name, success_value=choices[0], failure_value=choices[1]) + target_metadata = draw(measurable_metadata()) + return BinaryTarget( + name, + success_value=choices[0], + failure_value=choices[1], + metadata=target_metadata, + ) targets = st.one_of([binary_targets(), numerical_targets()]) diff --git a/tests/integration/test_metadata_integration.py b/tests/integration/test_metadata_integration.py new file mode 100644 index 0000000000..46de266b35 --- /dev/null +++ b/tests/integration/test_metadata_integration.py @@ -0,0 +1,66 @@ 
+"""Integration tests for metadata with BayBE components.""" + +import pytest +from pytest import param + +from baybe.objectives.single import SingleTargetObjective +from baybe.parameters.numerical import NumericalDiscreteParameter +from baybe.targets.enum import TargetMode +from baybe.targets.numerical import NumericalTarget +from baybe.utils.metadata import MeasurableMetadata, Metadata, to_metadata + +TMetadata = Metadata | dict | None + + +def make_parameter(metadata: TMetadata = None) -> NumericalDiscreteParameter: + return NumericalDiscreteParameter( + "p", (1, 2), metadata=metadata or MeasurableMetadata() + ) + + +def make_target(metadata: TMetadata = None) -> NumericalTarget: + return NumericalTarget( + "yield", TargetMode.MAX, metadata=metadata or MeasurableMetadata() + ) + + +def make_objective(metadata: TMetadata = None) -> SingleTargetObjective: + return SingleTargetObjective( + target=make_target(), metadata=metadata or MeasurableMetadata() + ) + + +@pytest.mark.parametrize( + ("constructor", "metadata_cls"), + [ + param(make_parameter, MeasurableMetadata, id="parameter"), + param(make_target, MeasurableMetadata, id="target"), + param(make_objective, Metadata, id="objective"), + ], +) +class TestMetadataIntegration: + """Tests for metadata integration with BayBE objects.""" + + @pytest.mark.parametrize("as_dict", [True, False]) + def test_with_metadata(self, constructor, metadata_cls, as_dict: bool): + """BayBE objects accept, ingest, and surface metadata.""" + meta = dict(description="test", unit="m", other="value") + container = constructor( + metadata=meta if as_dict else to_metadata(meta, metadata_cls) + ) + + assert container.description == "test" + if metadata_cls is MeasurableMetadata: + assert container.metadata.misc == {"other": "value"} + assert container.unit == "m" + else: + assert container.metadata.misc == {"unit": "m", "other": "value"} + + def test_without_metadata(self, metadata_cls, constructor): + """BayBE objects without metadata 
have empty metadata and `None` properties.""" + container = constructor().metadata + + assert container.is_empty + assert container.description is None + if metadata_cls is MeasurableMetadata: + assert container.unit is None diff --git a/tests/serialization/test_metadata_serialization.py b/tests/serialization/test_metadata_serialization.py new file mode 100644 index 0000000000..8457b6b70a --- /dev/null +++ b/tests/serialization/test_metadata_serialization.py @@ -0,0 +1,43 @@ +"""Test serialization of metadata.""" + +import hypothesis.strategies as st +import pytest +from hypothesis import given +from pytest import param + +from baybe.utils.metadata import MeasurableMetadata, Metadata, to_metadata +from tests.hypothesis_strategies.metadata import measurable_metadata, metadata + + +@pytest.mark.parametrize( + "metadata_strategy", + [ + param(metadata(), id="Metadata"), + param(measurable_metadata(), id="MeasurableMetadata"), + ], +) +@given(data=st.data()) +def test_metadata_roundtrip(data, metadata_strategy): + """A serialization roundtrip yields an equivalent object.""" + meta = data.draw(metadata_strategy) + cls = type(meta) + string = meta.to_json() + meta2 = cls.from_json(string) + assert meta == meta2, (meta, meta2) + + +@pytest.mark.parametrize("cls", [Metadata, MeasurableMetadata]) +def test_field_separation(cls: type[Metadata]): + """Field separation adapts to the specific ``Metadata`` subclass and works + regardless of the conversion route. 
+ """ # noqa: D205 + dct = {"description": "test", "unit": "m", "key": "value"} + if cls is Metadata: + metadata = cls(description="test", misc={"unit": "m", "key": "value"}) + elif issubclass(cls, MeasurableMetadata): + metadata = cls(description="test", unit="m", misc={"key": "value"}) + else: + raise ValueError(f"Unsupported class: {cls}") + via_converter = to_metadata(dct, cls) + via_from_dict = cls.from_dict(dct) + assert metadata == via_converter == via_from_dict diff --git a/tests/test_metadata.py b/tests/test_metadata.py new file mode 100644 index 0000000000..736c47c47d --- /dev/null +++ b/tests/test_metadata.py @@ -0,0 +1,102 @@ +"""Tests for metadata functionality.""" + +import pytest +from pytest import param + +from baybe.utils.metadata import MeasurableMetadata, Metadata, to_metadata + + +class TestMetadata: + """Tests for `Metadata` class.""" + + def test_metadata_creation_basic(self): + """All metadata attributes are properly populated.""" + meta = MeasurableMetadata( + description="test", unit="kg", misc={"custom_field": "value"} + ) + assert meta.description == "test" + assert meta.unit == "kg" + assert meta.misc == {"custom_field": "value"} + + def test_metadata_creation_defaults(self): + """Metadata can be created with no content.""" + meta = MeasurableMetadata() + assert meta.description is None + assert meta.unit is None + assert meta.misc == {} + + @pytest.mark.parametrize( + ("metadata", "expected"), + [ + param({}, True, id="empty"), + param({"description": "test"}, False, id="description_only"), + param({"unit": "kg"}, False, id="unit_only"), + param({"misc": {"key": "value"}}, False, id="misc_only"), + ], + ) + @pytest.mark.parametrize("metadata_cls", [Metadata, MeasurableMetadata]) + def test_metadata_is_empty_detection(self, metadata_cls, metadata, expected): + """The is_empty property correctly identifies empty metadata.""" + if metadata_cls is Metadata and "unit" in metadata: + pytest.skip("Metadata class has no 'unit' attribute.") + 
assert metadata_cls(**metadata).is_empty == expected + + +class TestMetadataConverter: + """Tests for `to_metadata` function.""" + + def test_convert_metadata_instance(self): + """The converter passes through Metadata instances unchanged.""" + meta = MeasurableMetadata(description="test") + result = to_metadata(meta, MeasurableMetadata) + assert result is meta + + @pytest.mark.parametrize( + ("input_dict", "expected_desc", "expected_unit", "expected_misc"), + [ + param( + {"description": "test desc", "unit": "kg"}, + "test desc", + "kg", + {}, + id="known_fields_only", + ), + param( + {"description": "test", "custom_field": "value"}, + "test", + None, + {"custom_field": "value"}, + id="description_and_misc", + ), + param( + {"unit": "m", "other": 42}, + None, + "m", + {"other": 42}, + id="unit_and_misc", + ), + param( + {"custom1": "val1", "custom2": 123}, + None, + None, + {"custom1": "val1", "custom2": 123}, + id="misc_only", + ), + param( + {}, + None, + None, + {}, + id="empty_dict", + ), + ], + ) + def test_convert_dict_valid( + self, input_dict, expected_desc, expected_unit, expected_misc + ): + """Conversion from dict properly separates known and unknown attributes.""" + result = to_metadata(input_dict, MeasurableMetadata) + assert isinstance(result, MeasurableMetadata) + assert result.description == expected_desc + assert result.unit == expected_unit + assert result.misc == expected_misc diff --git a/tests/validation/test_metadata_validation.py b/tests/validation/test_metadata_validation.py new file mode 100644 index 0000000000..c283f39d5e --- /dev/null +++ b/tests/validation/test_metadata_validation.py @@ -0,0 +1,67 @@ +"""Validation tests for metadata.""" + +import pytest +from pytest import param + +from baybe.utils.metadata import MeasurableMetadata, Metadata, to_metadata + + +@pytest.mark.parametrize( + ("description", "misc", "error", "match"), + [ + param(0, {}, TypeError, "must be ", id="desc-non-str"), + param(None, 0, TypeError, "must be ", 
id="misc-non-dict"), + param( + None, + {0: 0}, + TypeError, + "must be ", + id="misc-non-str-keys", + ), + param( + None, + {"description": 0}, + ValueError, + "fields: {'description'}", + id="desc_in_misc", + ), + ], +) +def test_invalid_arguments_for_metadata(description, misc, error, match): + """Providing invalid arguments to base Metadata class raises an error.""" + with pytest.raises(error, match=match): + Metadata(description, misc=misc) + + +@pytest.mark.parametrize( + ("description", "unit", "misc", "error", "match"), + [ + param(None, 0, {}, TypeError, "must be ", id="unit-non-str"), + param( + None, None, {"unit": 0}, ValueError, "fields: {'unit'}", id="unit_in_misc" + ), + ], +) +def test_invalid_arguments_for_measurable_metadata( + description, unit, misc, error, match +): + """Providing invalid arguments raises an error.""" + with pytest.raises(error, match=match): + MeasurableMetadata(description, unit, misc=misc) + + +@pytest.mark.parametrize( + "invalid_input", + [ + param("string", id="string_input"), + param(123, id="number_input"), + param(["list"], id="list_input"), + param(object(), id="object_input"), + ], +) +def test_invalid_input_conversion(invalid_input): + """Converting invalid inputs raises an error.""" + with pytest.raises( + TypeError, match="must be a dictionary or a 'Metadata' instance." + ): + to_metadata(invalid_input, Metadata)