diff --git a/lale/lib/sklearn/logistic_regression.py b/lale/lib/sklearn/logistic_regression.py index e20428691..40547adc1 100644 --- a/lale/lib/sklearn/logistic_regression.py +++ b/lale/lib/sklearn/logistic_regression.py @@ -618,7 +618,7 @@ def score(self, X, y, sample_weight=None): break -if lale.operators.sklearn_version >= version.Version("1.7"): +if lale.operators.sklearn_version >= version.Version("1.8"): LogisticRegression = typing.cast( lale.operators.PlannedIndividualOp, LogisticRegression.customize_schema( diff --git a/lale/operators.py b/lale/operators.py index 4b31384f0..cf6706f24 100644 --- a/lale/operators.py +++ b/lale/operators.py @@ -349,6 +349,21 @@ def __or__(self, other: Union[Any, "Operator"]) -> "OperatorChoice": def __ror__(self, other: Union[Any, "Operator"]) -> "OperatorChoice": return make_choice(other, self) + def __sklearn_tags__(self): + """Provide sklearn compatibility for >=1.6""" + if hasattr(self, "_impl_instance"): + impl = self._impl_instance() + if hasattr(impl, "__sklearn_tags__"): + return impl.__sklearn_tags__() + elif hasattr(impl, "_more_tags"): + return impl._more_tags() + try: + from sklearn.utils._tags import default_tags + + return default_tags(self) + except ImportError: + return {} + def name(self) -> str: """Get the name of this operator instance.""" return self._name diff --git a/setup.py b/setup.py index 90d66f219..50f08fa2a 100644 --- a/setup.py +++ b/setup.py @@ -45,9 +45,9 @@ "numpy", "black>=22.1.0", "hyperopt>=0.2,<=0.2.7", - "jsonschema<=4.25.1", + "jsonschema<=5", "jsonsubschema>=0.0.6", - "scikit-learn>=1.0.0,<1.7.0", + "scikit-learn>=1.0.0,<1.8.0", "scipy", "pandas", "packaging",