-
-
Notifications
You must be signed in to change notification settings - Fork 695
remove sklearn dependency from cohenkappa score calculation logic and applied custom calculation and updated tests #3731
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
avishkarsonni
wants to merge
8
commits into
pytorch:master
Choose a base branch
from
avishkarsonni:master
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
9c1b2ab
remove sklearn dependency from cohenkappa score calculation logic and…
avishkarsonni d6782e7
updated the docstring with the correct changes
avishkarsonni c5eefb0
Merge branch 'pytorch:master' into master
avishkarsonni add7192
changed the implementation to ConfusionMatrix for calculation of Cohe…
avishkarsonni 8808d09
added the conversion of GPU tensors to CPU and then call the .double(…
avishkarsonni e80d750
Merge branch 'pytorch:master' into master
avishkarsonni 16bcc77
Changed the approach to the match the dtype of confision matrix to th…
avishkarsonni 39cc54b
update ConfusionMatrix import path in CohenKappa class docstring
avishkarsonni File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,6 @@ | ||
| import os | ||
| from unittest.mock import patch | ||
|
|
||
| import pytest | ||
| import sklearn | ||
| import torch | ||
| from sklearn.metrics import cohen_kappa_score | ||
|
|
||
|
|
@@ -14,55 +12,15 @@ | |
| torch.manual_seed(12) | ||
|
|
||
|
|
||
| @pytest.fixture() | ||
| def mock_no_sklearn(): | ||
| with patch.dict("sys.modules", {"sklearn.metrics": None}): | ||
| yield sklearn | ||
|
|
||
|
|
||
| def test_no_sklearn(mock_no_sklearn): | ||
| with pytest.raises(ModuleNotFoundError, match=r"This contrib module requires scikit-learn to be installed."): | ||
| CohenKappa() | ||
|
|
||
|
|
||
| def test_no_update(): | ||
| ck = CohenKappa() | ||
|
|
||
| with pytest.raises( | ||
| NotComputableError, match=r"EpochMetric must have at least one example before it can be computed" | ||
| NotComputableError, match=r"CohenKappa must have at least one example before it can be computed" | ||
| ): | ||
| ck.compute() | ||
|
|
||
|
|
||
| def test_input_types(): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same question here, why do you remove these tests? |
||
| ck = CohenKappa() | ||
| ck.reset() | ||
| output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long)) | ||
| ck.update(output1) | ||
|
|
||
| with pytest.raises(ValueError, match=r"Incoherent types between input y_pred and stored predictions"): | ||
| ck.update((torch.randint(0, 5, size=(4, 3)), torch.randint(0, 2, size=(4, 3)))) | ||
|
|
||
| with pytest.raises(ValueError, match=r"Incoherent types between input y and stored targets"): | ||
| ck.update((torch.rand(4, 3), torch.randint(0, 2, size=(4, 3)).to(torch.int32))) | ||
|
|
||
| with pytest.raises(ValueError, match=r"Incoherent types between input y_pred and stored predictions"): | ||
| ck.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5)).long())) | ||
|
|
||
|
|
||
| def test_check_shape(): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you remove these tests? |
||
| ck = CohenKappa() | ||
|
|
||
| with pytest.raises(ValueError, match=r"Predictions should be of shape"): | ||
| ck._check_shape((torch.tensor(0), torch.tensor(0))) | ||
|
|
||
| with pytest.raises(ValueError, match=r"Predictions should be of shape"): | ||
| ck._check_shape((torch.rand(4, 3, 1), torch.rand(4, 3))) | ||
|
|
||
| with pytest.raises(ValueError, match=r"Targets should be of shape"): | ||
| ck._check_shape((torch.rand(4, 3), torch.rand(4, 3, 1))) | ||
|
|
||
|
|
||
| def test_cohen_kappa_wrong_weights_type(): | ||
| with pytest.raises(ValueError, match=r"Kappa Weighting type must be"): | ||
| ck = CohenKappa(weights=7) | ||
|
|
@@ -329,3 +287,70 @@ def _test_distrib_xla_nprocs(index): | |
| def test_distrib_xla_nprocs(xmp_executor): | ||
| n = int(os.environ["NUM_TPU_WORKERS"]) | ||
| xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n) | ||
|
|
||
|
|
||
| # --- num_classes path tests --- | ||
|
|
||
|
|
||
| def test_num_classes_no_update(): | ||
| ck = CohenKappa(num_classes=3) | ||
| with pytest.raises( | ||
| NotComputableError, match=r"CohenKappa must have at least one example before it can be computed" | ||
| ): | ||
| ck.compute() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("weights", [None, "linear", "quadratic"]) | ||
| def test_num_classes_matches_dynamic(weights, available_device): | ||
| torch.manual_seed(42) | ||
| y_pred = torch.randint(0, 4, size=(60,)).long() | ||
| y = torch.randint(0, 4, size=(60,)).long() | ||
| batch_size = 10 | ||
|
|
||
| ck_dynamic = CohenKappa(weights=weights, device=available_device) | ||
| ck_fixed = CohenKappa(weights=weights, device=available_device, num_classes=4) | ||
|
|
||
| for ck in (ck_dynamic, ck_fixed): | ||
| ck.reset() | ||
| for i in range(60 // batch_size): | ||
| idx = i * batch_size | ||
| ck.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) | ||
|
|
||
| assert ck_dynamic.compute() == pytest.approx(ck_fixed.compute()) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("weights", [None, "linear", "quadratic"]) | ||
| def test_num_classes_single_batch(weights, available_device): | ||
| torch.manual_seed(0) | ||
| y_pred = torch.randint(0, 3, size=(30,)).long() | ||
| y = torch.randint(0, 3, size=(30,)).long() | ||
|
|
||
| ck = CohenKappa(weights=weights, device=available_device, num_classes=3) | ||
| ck.reset() | ||
| ck.update((y_pred, y)) | ||
| res = ck.compute() | ||
|
|
||
| assert isinstance(res, float) | ||
| assert cohen_kappa_score(y.numpy(), y_pred.numpy(), weights=weights) == pytest.approx(res) | ||
|
|
||
|
|
||
| def test_num_classes_multilabel_inputs(): | ||
| ck = CohenKappa(num_classes=4) | ||
| with pytest.raises(ValueError, match=r"multilabel-indicator is not supported"): | ||
| ck.reset() | ||
| ck.update((torch.randint(0, 2, size=(10, 4)).long(), torch.randint(0, 2, size=(10, 4)).long())) | ||
| ck.compute() | ||
|
|
||
|
|
||
| def test_num_classes_squeeze_n1(): | ||
| torch.manual_seed(7) | ||
| y_pred = torch.randint(0, 2, size=(20, 1)).long() | ||
| y = torch.randint(0, 2, size=(20, 1)).long() | ||
|
|
||
| ck = CohenKappa(num_classes=2) | ||
| ck.reset() | ||
| ck.update((y_pred, y)) | ||
| res = ck.compute() | ||
|
|
||
| assert isinstance(res, float) | ||
| assert cohen_kappa_score(y.squeeze().numpy(), y_pred.squeeze().numpy()) == pytest.approx(res) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't we fetch double dtype from
cm._double_dtype?