-
-
Notifications
You must be signed in to change notification settings - Fork 270
Expand file tree
/
Copy pathtest_split.py
More file actions
97 lines (83 loc) · 3.31 KB
/
test_split.py
File metadata and controls
97 lines (83 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# License: BSD 3-Clause
from __future__ import annotations
import contextlib
import inspect
import os
import shutil
import tempfile
from pathlib import Path
import numpy as np
import pytest
from openml import OpenMLSplit
from openml.testing import TestBase
class OpenMLSplitTest(TestBase):
    """Tests for :class:`openml.OpenMLSplit` built from a local ARFF fixture.

    The fixture is the ``datasplits.arff`` file of task 1882, copied into a
    private temporary directory for every test so tests cannot interfere.
    """

    # Splitting not helpful, these tests don't rely on the server and take less
    # than 5 seconds + rebuilding the test would potentially be costly

    def setUp(self):
        """Copy the task-1882 ``datasplits.arff`` fixture into a fresh temp dir.

        Sets ``self.test_dir``, ``self.directory`` (directory of this test
        module, as a ``str``), ``self.arff_filepath`` and ``self.pd_filename``.
        """
        super().setUp()
        self.test_dir = tempfile.mkdtemp()
        # Register removal immediately: if anything below raises, unittest does
        # not call tearDown, but cleanups registered here still run. Removal is
        # best-effort, so errors while deleting the temp dir are ignored.
        self.addCleanup(shutil.rmtree, self.test_dir, ignore_errors=True)

        self.directory = os.path.dirname(inspect.getfile(OpenMLSplitTest))
        # This is for dataset
        original_arff_filepath = (
            Path(self.directory).parent
            / "files"
            / "org"
            / "openml"
            / "test"
            / "tasks"
            / "1882"
            / "datasplits.arff"
        )
        self.arff_filepath = Path(self.test_dir) / "datasplits.arff"
        shutil.copy(original_arff_filepath, self.arff_filepath)
        self.pd_filename = self.arff_filepath.with_suffix(".pkl.py3")

    def test_eq(self):
        """Splits compare equal to themselves and differ after any mutation."""
        split = OpenMLSplit._from_arff_file(self.arff_filepath)
        assert split == split  # noqa: PLR0124

        # Each case reloads a fresh split and changes exactly one aspect of it.
        split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
        split2.name = "a"
        assert split != split2

        split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
        split2.description = "a"
        assert split != split2

        # Extra repeat key.
        split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
        split2.split[10] = {}
        assert split != split2

        # Extra fold key inside repeat 0.
        split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
        split2.split[0][10] = {}
        assert split != split2

    def test_from_arff_file(self):
        """Parsed splits have the nested repeat -> fold -> sample layout and sizes."""
        split = OpenMLSplit._from_arff_file(self.arff_filepath)
        assert isinstance(split.split, dict)
        assert isinstance(split.split[0], dict)
        assert isinstance(split.split[0][0], dict)

        first = split.split[0][0][0]
        assert isinstance(first[0], np.ndarray)
        assert isinstance(first.train, np.ndarray)
        assert isinstance(first[1], np.ndarray)
        assert isinstance(first.test, np.ndarray)
        # Positional and attribute access must refer to the same index arrays.
        assert np.array_equal(first[0], first.train)
        assert np.array_equal(first[1], first.test)

        # The fixture has 10 repeats x 10 folds over 898 instances.
        for repeat in range(10):
            for fold in range(10):
                train, test = split.split[repeat][fold][0]
                assert train.shape[0] >= 808
                assert test.shape[0] >= 89
                assert train.shape[0] + test.shape[0] == 898

    def test_get_split(self):
        """``get`` returns (train, test) and rejects unknown repeat/fold ids."""
        split = OpenMLSplit._from_arff_file(self.arff_filepath)
        train_split, test_split = split.get(fold=5, repeat=2)
        assert train_split.shape[0] == 808
        assert test_split.shape[0] == 90
        with pytest.raises(ValueError, match="Repeat 10 not known"):
            split.get(repeat=10, fold=2)
        with pytest.raises(ValueError, match="Fold 10 not known"):
            split.get(repeat=2, fold=10)