Skip to content
Open
34 changes: 11 additions & 23 deletions tests/test_tasks/test_split.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# License: BSD 3-Clause
from __future__ import annotations

import contextlib
import inspect
import os
import shutil
import tempfile
from pathlib import Path

import numpy as np
import pytest

from openml import OpenMLSplit
from openml.testing import TestBase
Expand All @@ -18,8 +20,11 @@ class OpenMLSplitTest(TestBase):
# than 5 seconds + rebuilding the test would potentially be costly

def setUp(self):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't change anything here; keep it as it was.

Copy link
Copy Markdown
Collaborator

@geetu040 geetu040 Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#1694 (comment) is still unresolved

super().setUp()
self._temp_dir = tempfile.TemporaryDirectory()
__file__ = inspect.getfile(OpenMLSplitTest)
self.directory = os.path.dirname(__file__)
# This is for dataset
source_arff = (
Path(self.directory).parent
/ "files"
Expand All @@ -30,23 +35,16 @@ def setUp(self):
/ "1882"
/ "datasplits.arff"
)
# Use a unique temp directory for each test to avoid race conditions
# when running tests in parallel (see issue #1641)
self._temp_dir = tempfile.TemporaryDirectory()
self.arff_filepath = Path(self._temp_dir.name) / "datasplits.arff"
shutil.copy(source_arff, self.arff_filepath)
self.pd_filename = self.arff_filepath.with_suffix(".pkl.py3")

def tearDown(self):
# Clean up the entire temp directory
try:
self._temp_dir.cleanup()
except (OSError, FileNotFoundError):
pass
self._temp_dir.cleanup()

def test_eq(self):
split = OpenMLSplit._from_arff_file(self.arff_filepath)
assert split == split
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry about that. I think this should be kept, since it exercises the equality check (`==`).

assert split == split # noqa: PLR0124
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line can be removed.

Suggested change
assert split == split # noqa: PLR0124


split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
split2.name = "a"
Expand Down Expand Up @@ -88,17 +86,7 @@ def test_get_split(self):
train_split, test_split = split.get(fold=5, repeat=2)
assert train_split.shape[0] == 808
assert test_split.shape[0] == 90
self.assertRaisesRegex(
ValueError,
"Repeat 10 not known",
split.get,
10,
2,
)
self.assertRaisesRegex(
ValueError,
"Fold 10 not known",
split.get,
2,
10,
)
with pytest.raises(ValueError, match="Repeat 10 not known"):
split.get(10, 2)
with pytest.raises(ValueError, match="Fold 10 not known"):
split.get(2, 10)