Skip to content

Commit 3bc275b

Browse files
TheoGuyard, mathurinm, Badr-MOUFAD
authored
API abstraction layer to fetch datasets (#37)
Co-authored-by: mathurinm <[email protected]> Co-authored-by: Badr-MOUFAD <[email protected]>
1 parent db34488 commit 3bc275b

10 files changed

Lines changed: 563 additions & 431 deletions

File tree

README.rst

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,39 @@
11
|image0| |image1|
22

3-
A python util to fetch datasets from the LIBSVM website.
3+
A Python utility to fetch datasets from different databases.
44

5+
Currently supported databases are:
6+
7+
- LIBSVM (libsvm_)
58

69
Getting design matrix and target variable is as easy as:
710

811
::
912

10-
from libsvmdata import fetch_libsvm
11-
X, y = fetch_libsvm("news20.binary")
13+
from libsvmdata import fetch_dataset
14+
X, y = fetch_dataset("news20.binary")
15+
16+
Currently supported datasets are in ``libsvmdata.supported`` and can be displayed as:
1217

18+
::
1319

14-
Currently supported datasets are in ``libsvmdata.supported``.
20+
from libsvmdata import print_supported_datasets
21+
print_supported_datasets()
1522

23+
There is no need to specify the database name.
1624

17-
The datasets are saved in a subfolder ``libsvm`` inside ``libsvmdata.datasets.DATA_HOME``, whose value is:
25+
Files are saved under ``DATA_HOME/<database_name>``, where the value of ``DATA_HOME`` is:
1826

19-
- the environment variable LIBSVMDATA_HOME if it exists,
27+
- the environment variable ``LIBSVMDATA_HOME`` if it exists,
2028

21-
- else, the environment variable XDG_DATA_HOME if it exists,
29+
- else, the environment variable ``XDG_DATA_HOME`` if it exists,
2230

23-
- else, $HOME/data.
31+
- else, ``$HOME/data``.
2432

2533

2634

2735
.. |image0| image:: https://github.com/mathurinm/libsvmdata/actions/workflows/build.yml/badge.svg?branch=main
2836
:target: https://github.com/mathurinm/libsvmdata/actions/workflows/build.yml
2937
.. |image1| image:: https://codecov.io/gh/mathurinm/libsvmdata/branch/main/graphs/badge.svg?branch=main
3038
:target: https://codecov.io/gh/mathurinm/libsvmdata
39+
.. _libsvm: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/

libsvmdata/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from .datasets import fetch_libsvm, download_libsvm, NAMES # noqa
1+
from libsvmdata.datasets import fetch_libsvm, download_libsvm
2+
from libsvmdata.core import fetch_dataset, print_supported_datasets, ALL_DATASETS
23

3-
supported = list(NAMES.keys()) # noqa
4+
supported = list(ALL_DATASETS.keys())
45

56
__version__ = '0.5dev0'

libsvmdata/abstraction.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import os
2+
import re
3+
import numpy as np
4+
from abc import ABC, abstractmethod
5+
from download import download
6+
from pathlib import Path
7+
from scipy import sparse
8+
9+
10+
def _get_data_home(subdir_name=""):
11+
"""
12+
Defines the data home folder. The top priority is the environment
13+
variable $LIBSVMDATA_HOME which is specific to this package. Otherwise, we
14+
seek for the variable $XDG_DATA_HOME. Finally, the fallback is $HOME/data.
15+
"""
16+
data_home = os.environ.get("LIBSVMDATA_HOME", None)
17+
if data_home is None:
18+
data_home = os.environ.get("XDG_DATA_HOME", None)
19+
if data_home is None:
20+
data_home = Path.home() / "data"
21+
return data_home / subdir_name
22+
23+
24+
class AbstractDataset(ABC):
    """Base class defining a dataset along with its fetching methods."""

    # In the derived class, __init__() must set the following attributes:
    dataset_name = None  # dataset name
    dataset_file = None  # dataset file (with potential extensions)
    dataset_dir = None  # subdirectory name (see _get_data_home())
    dataset_url = None  # dataset download url

    @abstractmethod
    def __init__(self):
        """
        In the derived class, this function must define the class attributes.
        It can also be used to pass additional information required in the
        function _load_file_and_save_data() of the derived class.
        """
        pass

    @abstractmethod
    def _load_file_and_save_data(self, raw_dataset_path, ext_dataset_path):
        """
        In the derived class, this function is responsible for the
        transformation of the raw dataset file into two .npy/.npz files
        containing the feature matrix X and the response vector/matrix y.
        These files must be named <self.dataset_name>_X.<npz/npy> and
        <self.dataset_name>_y.<npz/npy>. This function is also responsible
        for removing the raw dataset file when needed.
        """
        pass

    def _load_data(self, ext_dataset_path):
        """Load data from the extracted .npz/.npy files."""
        # X and y may each be stored either sparse (.npz) or dense (.npy):
        # try the sparse container first and fall back to the dense one.
        try:
            X = sparse.load_npz(str(ext_dataset_path) + "_X.npz")
        except FileNotFoundError:
            X = np.load(str(ext_dataset_path) + "_X.npy")

        try:
            y = sparse.load_npz(str(ext_dataset_path) + "_y.npz")
        except FileNotFoundError:
            y = np.load(str(ext_dataset_path) + "_y.npy")

        return X, y

    def get_X_y(self, replace=False, verbose=False):
        """
        Load a dataset as matrix X and vector y. If X and y already exist as
        .npz and/or .npy files, they are not redownloaded, unless
        replace=True.
        """
        raw_dataset_path = self.dataset_dir / self.dataset_file
        ext_dataset_path = self.dataset_dir / self.dataset_name

        # Check if the dataset already exists. The dataset name is escaped
        # because it may contain regex metacharacters (e.g. "news20.binary"),
        # and fullmatch anchors the pattern so unrelated files are never
        # counted here nor deleted below.
        if self.dataset_dir.exists():
            regex = re.compile(
                rf"{re.escape(self.dataset_name)}_(X|y)\.(npz|npy)"
            )
            files = os.listdir(self.dataset_dir)
            found = [f for f in files if regex.fullmatch(f)]
            exists = len(found) == 2
        else:
            found = []
            exists = False

        if replace or not exists:
            # Remove existing dataset files if there are any
            if raw_dataset_path.exists():
                raw_dataset_path.unlink()
            for file in found:
                Path(self.dataset_dir / file).unlink()

            # Download the raw dataset file
            if verbose:
                print("Downloading...")
            download(
                self.dataset_url,
                raw_dataset_path,
                progressbar=verbose,
                replace=replace,
                verbose=verbose,
            )

            if verbose:
                print("Loading file and saving data...")
            X, y = self._load_file_and_save_data(
                raw_dataset_path,
                ext_dataset_path
            )

        else:
            if verbose:
                print("Loading data...")
            X, y = self._load_data(ext_dataset_path)

        return X, y

libsvmdata/core.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from libsvmdata.libsvm import DATASETS as libsvm_datasets
2+
3+
# Registry of every supported database, keyed by its display name.
ALL_DATABASES = {"LIBSVM": libsvm_datasets}

# Flat mapping from dataset name to its dataset object, across all databases.
ALL_DATASETS = {
    ds.dataset_name: ds
    for db in ALL_DATABASES.values()
    for ds in db
}
10+
11+
12+
def fetch_dataset(dataset_name, replace=False, verbose=False):
    """
    Load a dataset. It is downloaded only if not present or when replace=True.

    Parameters
    ----------
    dataset_name : string
        Dataset name.

    replace : bool, default=False
        Whether to re-download the dataset if it is already downloaded.

    verbose : bool, default=False
        Whether or not to print information about dataset loading.

    Returns
    -------
    X : np.ndarray or scipy.sparse.csc_matrix
        Design matrix, as 2D array or column sparse format depending on the
        dataset.

    y : 1D or 2D np.ndarray
        Design vector (or matrix in multiclass setting).

    Raises
    ------
    ValueError
        If `dataset_name` is not among the supported datasets.
    """
    # Membership is tested on the dict directly; `.keys()` is redundant.
    if dataset_name not in ALL_DATASETS:
        raise ValueError(
            f"Unsupported dataset `{dataset_name}`. Supported datasets can be "
            "displayed using the `libsvmdata.print_supported_datasets` "
            "function."
        )

    dataset = ALL_DATASETS[dataset_name]

    X, y = dataset.get_X_y(replace=replace, verbose=verbose)

    return X, y
50+
51+
52+
def print_supported_datasets():
    """Print, for each database, the names of its supported datasets."""
    print("Supported datasets")
    for db_name, db_datasets in ALL_DATABASES.items():
        print(f"- {db_name}: ")
        names = [ds.dataset_name for ds in db_datasets]
        print(", ".join(names))

0 commit comments

Comments
 (0)