diff --git a/lit_nlp/components/curves_test.py b/lit_nlp/components/curves_test.py index cef377c4..1b7dd194 100644 --- a/lit_nlp/components/curves_test.py +++ b/lit_nlp/components/curves_test.py @@ -51,7 +51,7 @@ def input_spec(self) -> lit_types.Spec: def output_spec(self) -> lit_types.Spec: return { 'pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label'), - 'aux_pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label') + 'aux_pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label'), } def predict_minibatch( @@ -64,10 +64,9 @@ def predict_example(ex: lit_types.JsonDict) -> tuple[float, float, float]: return TEST_DATA[x].prediction for example in inputs: - output.append({ - 'pred': predict_example(example), - 'aux_pred': [1 / 3, 1 / 3, 1 / 3] - }) + output.append( + {'pred': predict_example(example), 'aux_pred': [1 / 3, 1 / 3, 1 / 3]} + ) return output @@ -148,6 +147,43 @@ def test_model_output_is_missing_in_config(self): config={'Label': 'red'}, ) + @parameterized.named_parameters( + dict( + testcase_name='red', + label='red', + exp_roc=[(0.0, 0.0), (0.0, 0.5), (1.0, 0.5), (1.0, 1.0)], + exp_pr=[(0.5, 0.5), (2 / 3, 1.0), (1.0, 0.5), (1.0, 0.0)], + ), + dict( + testcase_name='blue', + label='blue', + exp_roc=[(0.0, 0.0), (0.0, 1.0), (1.0, 1.0)], + exp_pr=[ + (0.3333333333333333, 1.0), + (0.5, 1.0), + (1.0, 1.0), + (1.0, 0.0), + ], + ), + ) + def test_interpreter_honors_user_selected_label( + self, label: str, exp_roc: _Curve, exp_pr: _Curve + ): + """Tests a happy scenario when a user specifies the class label.""" + curves_data = self.ci.run( + inputs=self.dataset.examples, + model=self.model, + dataset=self.dataset, + config={ + curves.TARGET_LABEL_KEY: label, + curves.TARGET_PREDICTION_KEY: 'pred', + }, + ) + self.assertIn(curves.ROC_DATA, curves_data) + self.assertIn(curves.PR_DATA, curves_data) + self.assertEqual(curves_data[curves.ROC_DATA], exp_roc) + self.assertEqual(curves_data[curves.PR_DATA], exp_pr) + def test_config_spec(self): 
"""Tests that the interpreter config has correct fields of correct type.""" spec = self.ci.config_spec() diff --git a/pyproject.toml b/pyproject.toml index 2a27db38..66536476 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "rouge-score>=0.1.2", "sacrebleu>=2.3.1", "saliency>=0.1.3", - "scikit-learn>=1.0.2", + "scikit-learn>=1.6.1", "scipy>=1.10.1", "shap>=0.42.0,<0.46.0", "six>=1.16.0", diff --git a/requirements.txt b/requirements.txt index 9c4707b0..caabea60 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,7 +31,7 @@ requests>=2.31.0 rouge-score>=0.1.2 sacrebleu>=2.3.1 saliency>=0.1.3 -scikit-learn>=1.0.2 +scikit-learn>=1.6.1 scipy>=1.10.1 shap>=0.42.0,<0.46.0 six>=1.16.0