-
-
Notifications
You must be signed in to change notification settings - Fork 270
Expand file tree
/
Copy pathtest_split.py
More file actions
100 lines (86 loc) · 3.29 KB
/
test_split.py
File metadata and controls
100 lines (86 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# License: BSD 3-Clause
from __future__ import annotations
import inspect
import os
import shutil
import tempfile
from pathlib import Path
import numpy as np
from openml import OpenMLSplit
from openml.testing import TestBase
class OpenMLSplitTest(TestBase):
    """Offline tests for ``OpenMLSplit`` driven by a local ``datasplits.arff`` fixture.

    Every test parses the fixture through ``OpenMLSplit._from_arff_file`` from a
    private copy placed in a per-test temporary directory.
    """

    # Splitting not helpful, these tests don't rely on the server and take less
    # than 5 seconds + rebuilding the test would potentially be costly

    def setUp(self):
        """Copy the fixture ARFF file into a fresh temporary directory."""
        # Use a plain local name rather than rebinding the ``__file__`` dunder.
        test_module_path = inspect.getfile(OpenMLSplitTest)
        self.directory = os.path.dirname(test_module_path)
        source_arff = (
            Path(self.directory).parent
            / "files"
            / "org"
            / "openml"
            / "test"
            / "tasks"
            / "1882"
            / "datasplits.arff"
        )
        # Use a unique temp directory for each test to avoid race conditions
        # when running tests in parallel (see issue #1641)
        self._temp_dir = tempfile.TemporaryDirectory()
        self.arff_filepath = Path(self._temp_dir.name) / "datasplits.arff"
        try:
            shutil.copy(source_arff, self.arff_filepath)
        except Exception:
            # unittest does not run tearDown when setUp raises, so release the
            # temporary directory here before propagating the failure.
            self._temp_dir.cleanup()
            raise
        self.pd_filename = self.arff_filepath.with_suffix(".pkl.py3")

    def tearDown(self):
        """Remove the per-test temporary directory and everything inside it."""
        self._temp_dir.cleanup()

    def test_eq(self):
        """A split equals itself and differs once any compared field is altered."""
        split = OpenMLSplit._from_arff_file(self.arff_filepath)
        # Deliberate self-comparison: exercises the equality operator.
        assert split == split

        # A fresh copy is loaded before each mutation so that exactly one
        # field differs from the reference split at a time.
        split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
        split2.name = "a"
        assert split != split2

        split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
        split2.description = "a"
        assert split != split2

        # Extra key at the outer level of the nested split mapping.
        split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
        split2.split[10] = {}
        assert split != split2

        # Extra key one level down.
        split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
        split2.split[0][10] = {}
        assert split != split2

    def test_from_arff_file(self):
        """Check the nested structure and the train/test sizes of the parsed split."""
        split = OpenMLSplit._from_arff_file(self.arff_filepath)

        # Three levels of dicts (presumably repeat -> fold -> sample; confirm
        # against OpenMLSplit), ending in a (train, test) pair of arrays.
        assert isinstance(split.split, dict)
        assert isinstance(split.split[0], dict)
        assert isinstance(split.split[0][0], dict)

        # The leaf is reachable both by position and by attribute name.
        leaf = split.split[0][0][0]
        assert isinstance(leaf[0], np.ndarray)
        assert isinstance(leaf.train, np.ndarray)
        assert isinstance(leaf[1], np.ndarray)
        assert isinstance(leaf.test, np.ndarray)

        for i in range(10):
            for j in range(10):
                entry = split.split[i][j][0]
                n_train = entry.train.shape[0]
                n_test = entry.test.shape[0]
                assert n_train >= 808
                assert n_test >= 89
                # Train and test together must account for all 898 rows.
                assert n_train + n_test == 898

    def test_get_split(self):
        """``get`` returns the requested arrays and rejects unknown indices."""
        split = OpenMLSplit._from_arff_file(self.arff_filepath)
        train_split, test_split = split.get(fold=5, repeat=2)
        assert train_split.shape[0] == 808
        assert test_split.shape[0] == 90

        # Positional arguments are kept exactly as before; index 10 is out of
        # range in the first and then in the second position.
        with self.assertRaisesRegex(ValueError, "Repeat 10 not known"):
            split.get(10, 2)
        with self.assertRaisesRegex(ValueError, "Fold 10 not known"):
            split.get(2, 10)