Skip to content
Open
43 changes: 19 additions & 24 deletions tests/test_tasks/test_split.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# License: BSD 3-Clause
from __future__ import annotations

import contextlib
import inspect
import os
import shutil
import tempfile
from pathlib import Path

import numpy as np
import pytest

from openml import OpenMLSplit
from openml.testing import TestBase
Expand All @@ -16,11 +19,13 @@ class OpenMLSplitTest(TestBase):
# than 5 seconds + rebuilding the test would potentially be costly

def setUp(self):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't change anything here, keep it as it was

Copy link
Copy Markdown
Collaborator

@geetu040 geetu040 Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#1694 (comment) is still unresolved

super().setUp()
self.test_dir = tempfile.mkdtemp()
__file__ = inspect.getfile(OpenMLSplitTest)
self.directory = os.path.dirname(__file__)
self.directory = Path(__file__).parent
# This is for dataset
self.arff_filepath = (
Path(self.directory).parent
original_arff_filepath = (
self.directory.parent
/ "files"
/ "org"
/ "openml"
Expand All @@ -29,18 +34,18 @@ def setUp(self):
/ "1882"
/ "datasplits.arff"
)
self.arff_filepath = Path(self.test_dir) / "datasplits.arff"
shutil.copy(original_arff_filepath, self.arff_filepath)
self.pd_filename = self.arff_filepath.with_suffix(".pkl.py3")

def tearDown(self):
try:
os.remove(self.pd_filename)
except (OSError, FileNotFoundError):
# Replaced bare except. Not sure why these exceptions are acceptable.
pass
with contextlib.suppress(OSError):
shutil.rmtree(self.test_dir)
super().tearDown()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
super().tearDown()
def tearDown(self):
self._temp_dir.cleanup()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tearDown can simply be written as:

    def tearDown(self):
        self._temp_dir.cleanup()


def test_eq(self):
split = OpenMLSplit._from_arff_file(self.arff_filepath)
assert split == split
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry about that, I think this should be kept, since it tests eq

assert split == split # noqa: PLR0124
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line can be removed

Suggested change
assert split == split # noqa: PLR0124


split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
split2.name = "a"
Expand Down Expand Up @@ -82,17 +87,7 @@ def test_get_split(self):
train_split, test_split = split.get(fold=5, repeat=2)
assert train_split.shape[0] == 808
assert test_split.shape[0] == 90
self.assertRaisesRegex(
ValueError,
"Repeat 10 not known",
split.get,
10,
2,
)
self.assertRaisesRegex(
ValueError,
"Fold 10 not known",
split.get,
2,
10,
)
with pytest.raises(ValueError, match="Repeat 10 not known"):
split.get(10, 2)
with pytest.raises(ValueError, match="Fold 10 not known"):
split.get(2, 10)