-
-
Notifications
You must be signed in to change notification settings - Fork 270
[MNT] Fix race condition in OpenMLSplit tests during parallel execution #1641 #1694
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
447aff1
b29157d
6d0b51d
981213d
d962e74
ebe11c3
f6bac1f
758577e
7f00928
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -1,13 +1,15 @@ | ||||||||
| # License: BSD 3-Clause | ||||||||
| from __future__ import annotations | ||||||||
|
|
||||||||
| import contextlib | ||||||||
| import inspect | ||||||||
| import os | ||||||||
| import shutil | ||||||||
| import tempfile | ||||||||
| from pathlib import Path | ||||||||
|
|
||||||||
| import numpy as np | ||||||||
| import pytest | ||||||||
|
|
||||||||
| from openml import OpenMLSplit | ||||||||
| from openml.testing import TestBase | ||||||||
|
|
@@ -18,9 +20,12 @@ class OpenMLSplitTest(TestBase): | |||||||
| # than 5 seconds + rebuilding the test would potentially be costly | ||||||||
|
|
||||||||
| def setUp(self): | ||||||||
| super().setUp() | ||||||||
| self.test_dir = tempfile.mkdtemp() | ||||||||
| __file__ = inspect.getfile(OpenMLSplitTest) | ||||||||
| self.directory = os.path.dirname(__file__) | ||||||||
| source_arff = ( | ||||||||
| # This is for dataset | ||||||||
| original_arff_filepath = ( | ||||||||
| Path(self.directory).parent | ||||||||
| / "files" | ||||||||
| / "org" | ||||||||
|
|
@@ -30,23 +35,21 @@ def setUp(self): | |||||||
| / "1882" | ||||||||
| / "datasplits.arff" | ||||||||
| ) | ||||||||
| # Use a unique temp directory for each test to avoid race conditions | ||||||||
| # when running tests in parallel (see issue #1641) | ||||||||
| self._temp_dir = tempfile.TemporaryDirectory() | ||||||||
| self.arff_filepath = Path(self._temp_dir.name) / "datasplits.arff" | ||||||||
| shutil.copy(source_arff, self.arff_filepath) | ||||||||
| self.arff_filepath = Path(self.test_dir) / "datasplits.arff" | ||||||||
| shutil.copy(original_arff_filepath, self.arff_filepath) | ||||||||
| self.pd_filename = self.arff_filepath.with_suffix(".pkl.py3") | ||||||||
|
|
||||||||
| def tearDown(self): | ||||||||
| # Clean up the entire temp directory | ||||||||
| try: | ||||||||
| self._temp_dir.cleanup() | ||||||||
| except (OSError, FileNotFoundError): | ||||||||
| shutil.rmtree(self.test_dir) | ||||||||
| except OSError: | ||||||||
| # Replaced bare except. Not sure why these exceptions are acceptable. | ||||||||
| pass | ||||||||
| super().tearDown() | ||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
def tearDown(self):
self._temp_dir.cleanup() |
||||||||
|
|
||||||||
| def test_eq(self): | ||||||||
| split = OpenMLSplit._from_arff_file(self.arff_filepath) | ||||||||
| assert split == split | ||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry about that, I think this should be kept, since it tests eq |
||||||||
| assert split == split # noqa: PLR0124 | ||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this line can be removed
Suggested change
|
||||||||
|
|
||||||||
| split2 = OpenMLSplit._from_arff_file(self.arff_filepath) | ||||||||
| split2.name = "a" | ||||||||
|
|
@@ -88,17 +91,7 @@ def test_get_split(self): | |||||||
| train_split, test_split = split.get(fold=5, repeat=2) | ||||||||
| assert train_split.shape[0] == 808 | ||||||||
| assert test_split.shape[0] == 90 | ||||||||
| self.assertRaisesRegex( | ||||||||
| ValueError, | ||||||||
| "Repeat 10 not known", | ||||||||
| split.get, | ||||||||
| 10, | ||||||||
| 2, | ||||||||
| ) | ||||||||
| self.assertRaisesRegex( | ||||||||
| ValueError, | ||||||||
| "Fold 10 not known", | ||||||||
| split.get, | ||||||||
| 2, | ||||||||
| 10, | ||||||||
| ) | ||||||||
| with pytest.raises(ValueError, match="Repeat 10 not known"): | ||||||||
| split.get(10, 2) | ||||||||
| with pytest.raises(ValueError, match="Fold 10 not known"): | ||||||||
| split.get(2, 10) | ||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't change anything here, keep it as it was
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#1694 (comment) is still unresolved