-
Notifications
You must be signed in to change notification settings - Fork 122
feat: Add file based Hugging Face datasets to kedro-datasets
#1373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
acf3c3a
382ca46
63eaa54
272a7ac
aee993d
168a747
de65093
011c9c5
3fdc4f7
117706a
b222ccf
a2abdda
73b7cde
ff9131c
209bf24
013b3a8
4affa12
37e56d9
52d4f76
adbf9a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| # ArrowDataset | ||
|
|
||
| `ArrowDataset` loads and saves Hugging Face datasets in Arrow format using the `datasets` library. | ||
|
|
||
| ::: kedro_datasets.huggingface.ArrowDataset | ||
| options: | ||
| members: true | ||
| show_source: true |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| # CSVDataset | ||
|
|
||
| `CSVDataset` loads and saves Hugging Face datasets in CSV format using the `datasets` library. | ||
|
|
||
| ::: kedro_datasets.huggingface.CSVDataset | ||
| options: | ||
| members: true | ||
| show_source: true |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| # JSONDataset | ||
|
|
||
| `JSONDataset` loads and saves Hugging Face datasets in JSON format using the `datasets` library. | ||
|
|
||
| ::: kedro_datasets.huggingface.JSONDataset | ||
| options: | ||
| members: true | ||
| show_source: true |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| # ParquetDataset | ||
|
|
||
| `ParquetDataset` loads and saves Hugging Face datasets in Parquet format using the `datasets` library. | ||
|
|
||
| ::: kedro_datasets.huggingface.ParquetDataset | ||
| options: | ||
| members: true | ||
| show_source: true |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,210 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import os | ||
| from copy import deepcopy | ||
| from pathlib import PurePosixPath | ||
| from typing import Any, ClassVar, TypeAlias | ||
|
|
||
| import fsspec | ||
| from datasets import ( | ||
| Dataset, | ||
| DatasetDict, | ||
| IterableDataset, | ||
| IterableDatasetDict, | ||
| load_dataset, | ||
| ) | ||
| from kedro.io.core import ( | ||
| AbstractVersionedDataset, | ||
| DatasetError, | ||
| Version, | ||
| get_filepath_str, | ||
| get_protocol_and_path, | ||
| ) | ||
|
|
||
| DatasetLike: TypeAlias = Dataset | DatasetDict | IterableDataset | IterableDatasetDict | ||
|
|
||
|
|
||
| class FilesystemDataset(AbstractVersionedDataset[DatasetLike, DatasetLike]): | ||
| """Base class for Hugging Face dataset types persisted on a filesystem. | ||
|
|
||
| Not intended for direct use — use a format-specific subclass instead | ||
| (e.g. ``ArrowDataset``, ``ParquetDataset``). | ||
| """ | ||
|
|
||
| BUILDER: ClassVar[str] | ||
| EXTENSION: ClassVar[str] | ||
|
|
||
| def __init__( # noqa: PLR0913 | ||
| self, | ||
| *, | ||
| path: str | os.PathLike, | ||
| version: Version | None = None, | ||
| load_args: dict[str, Any] | None = None, | ||
| save_args: dict[str, Any] | None = None, | ||
| credentials: dict[str, Any] | None = None, | ||
| fs_args: dict[str, Any] | None = None, | ||
| metadata: dict[str, Any] | None = None, | ||
| ) -> None: | ||
| """Creates a new instance of ``FilesystemDataset``. | ||
|
|
||
| Args: | ||
| path: Path to a file or directory for persisting Hugging Face | ||
| datasets. Supports local paths, ``os.PathLike`` objects, | ||
| and remote URIs (e.g. ``s3://bucket/data``). | ||
| version: Optional versioning configuration | ||
| (see :class:`~kedro.io.core.Version`). | ||
| load_args: Additional keyword arguments passed to the | ||
| underlying load function. When loading a ``DatasetDict`` | ||
| from a directory, supply ``data_files`` as a mapping of | ||
| split name to filename (e.g. ``{"train": "train.csv"}``). | ||
| The keys must match the split names of the ``DatasetDict`` | ||
| being saved, and the filenames must use the correct | ||
| extension for the format (e.g. ``.csv`` for | ||
| ``CSVDataset``). | ||
| save_args: Additional keyword arguments passed to the | ||
| underlying save function. | ||
| credentials: Credentials for the underlying filesystem | ||
| (e.g. ``key``/``secret`` for S3). Passed to the | ||
| ``storage_options`` parameter in the underlying | ||
| ``datasets`` implementation. | ||
| fs_args: Extra arguments passed to the ``fsspec`` filesystem | ||
| initialiser. Passed to the ``storage_options`` parameter | ||
| in the underlying ``datasets`` implementation. | ||
| metadata: Any arbitrary metadata. This is ignored by Kedro | ||
| but may be consumed by users or external plugins. | ||
| """ | ||
| _fs_args = deepcopy(fs_args) or {} | ||
| _credentials = deepcopy(credentials) or {} | ||
|
|
||
| protocol, resolved_path = get_protocol_and_path(path, version) | ||
| self._protocol = protocol | ||
|
|
||
| if protocol == "file": | ||
| _fs_args.setdefault("auto_mkdir", True) | ||
|
|
||
| self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) | ||
|
|
||
| self._load_args = load_args or {} | ||
| self._save_args = save_args or {} | ||
| self.metadata = metadata | ||
|
|
||
| self._storage_options = {**_credentials, **_fs_args} or None | ||
|
|
||
| super().__init__( | ||
| filepath=PurePosixPath(resolved_path), | ||
| version=version, | ||
| exists_function=self._fs.exists, | ||
| glob_function=self._fs.glob, | ||
| ) | ||
|
|
||
| def load(self) -> DatasetLike: | ||
| load_path = get_filepath_str(self._get_load_path(), self._protocol) | ||
| return self._load_dataset(load_path) | ||
|
|
||
| def save(self, data: DatasetLike) -> None: | ||
| if isinstance(data, IterableDataset | IterableDatasetDict): | ||
| msg = ( | ||
| f"{type(self).__name__} got iterable dataset. " | ||
| "Before saving an iterable dataset " | ||
| "you must materialize it into a `Dataset` or `DatasetDict`." | ||
| ) | ||
| raise DatasetError(msg) | ||
|
|
||
| if not isinstance(data, Dataset | DatasetDict): | ||
| msg = ( | ||
| f"{type(self).__name__} only supports `datasets.Dataset`, " | ||
| "`datasets.DatasetDict`, " | ||
| f"Got {type(data)}" | ||
| ) | ||
| raise DatasetError(msg) | ||
|
|
||
| save_path = get_filepath_str(self._get_save_path(), self._protocol) | ||
|
|
||
| if isinstance(data, DatasetDict): | ||
| self._save_dataset_dict(data, save_path) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. QQ: Do you need a
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry I see below it's being read from
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a really good point actually... it should be read from However, the way its implemented right now would require a user to specify Concretely, it would have to look like this in yaml: reviews:
type: huggingface.CSVDataset
path: data/01_raw/reviews
load_args:
data_files:
labels: labels.csv
data: data.csv
save_args:
data_files:
labels: labels.csv
data: data.csvWhich isn't the most user friendly... hmm... I guess there are two options: swap the saving operation to read from In this second option, reviews:
type: huggingface.CSVDataset
path: data/01_raw/reviews
data_files:
labels: labels.csv
data: data.csvwould work and so would specifying reviews:
type: huggingface.CSVDataset
path: data/01_raw/reviews
data_files:
labels: labels.csv
data: data.csv
save_args:
data_files:
labels: labels.csv
data: data.csvwould throw an error (regardless of if filenames match or not). What do you think @ankatiyar? Or maybe there's another option you'd prefer. |
||
| else: | ||
| self._save_dataset(data, save_path) | ||
|
|
||
| self._invalidate_cache() | ||
|
|
||
| def _build_data_files(self) -> str | dict[str, str]: | ||
|
iwhalen marked this conversation as resolved.
|
||
| load_path = get_filepath_str(self._get_load_path(), self._protocol) | ||
|
|
||
| if "data_files" in self._load_args: | ||
| data_files = self._load_args["data_files"] | ||
| return { | ||
| split: os.path.join(load_path, filename) | ||
| for split, filename in data_files.items() | ||
| } | ||
|
|
||
| # Otherwise, just return the path to the Dataset to be loaded. | ||
| return load_path | ||
|
|
||
| def _load_dataset(self, load_path: str) -> DatasetLike: | ||
| data_files: str | dict[str, str] = self._build_data_files() | ||
|
|
||
| load_args = deepcopy(self._load_args) | ||
| load_args.pop("data_files", None) | ||
|
|
||
| return load_dataset( # nosec | ||
| self.BUILDER, | ||
| data_files=data_files, | ||
| storage_options=self._storage_options, | ||
| **load_args, | ||
| ) | ||
|
|
||
| def _save_dataset(self, data: Dataset, save_path: str) -> None: | ||
| saver = f"to_{self.BUILDER}" | ||
| getattr(data, saver)( | ||
| save_path, | ||
| storage_options=self._storage_options, | ||
| **self._save_args, | ||
| ) | ||
|
|
||
| def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None: | ||
| """Hugging Face only provides ``DatasetDict.save_to_disk`` for Arrow format. | ||
|
|
||
| As a result, we have to call ``to_<format>`` per split. | ||
| """ | ||
| if "data_files" in self._load_args: | ||
| data_files_keys = set(self._load_args["data_files"].keys()) | ||
| split_names = set(data.keys()) | ||
| if data_files_keys != split_names: | ||
| msg = ( | ||
| f"DatasetDict split names {sorted(split_names)} do not match " | ||
| f"``load_args['data_files']`` keys {sorted(data_files_keys)}. " | ||
| "The data_files keys must match the DatasetDict split names " | ||
| "so the saved files can be found on load." | ||
| ) | ||
| raise DatasetError(msg) | ||
| self._fs.mkdirs(save_path, exist_ok=True) | ||
| ext = self.EXTENSION | ||
| for split, split_ds in data.items(): | ||
|
iwhalen marked this conversation as resolved.
|
||
| split_path = f"{save_path}/{split}{ext}" | ||
| self._save_dataset(split_ds, split_path) | ||
|
|
||
| def _describe(self) -> dict[str, Any]: | ||
| return { | ||
| "path": self._filepath, | ||
| "file_format": self.BUILDER, | ||
| "protocol": self._protocol, | ||
| "version": self._version, | ||
| "load_args": self._load_args, | ||
| "save_args": self._save_args, | ||
| } | ||
|
|
||
| def _exists(self) -> bool: | ||
| try: | ||
| load_path = get_filepath_str(self._get_load_path(), self._protocol) | ||
| except DatasetError: | ||
| return False | ||
| return self._fs.exists(load_path) | ||
|
|
||
| def _release(self) -> None: | ||
| super()._release() | ||
| self._invalidate_cache() | ||
|
|
||
| def _invalidate_cache(self) -> None: | ||
| """Invalidate underlying filesystem caches.""" | ||
| path = get_filepath_str(self._filepath, self._protocol) | ||
| self._fs.invalidate_cache(path) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import Any, ClassVar | ||
|
|
||
| from datasets import Dataset, DatasetDict, load_from_disk | ||
| from kedro.io.core import DatasetError, get_filepath_str | ||
|
|
||
| from ._base import DatasetLike, FilesystemDataset | ||
|
|
||
|
|
||
class ArrowDataset(FilesystemDataset):
    """``ArrowDataset`` loads/saves Hugging Face ``Dataset`` and
    ``DatasetDict`` objects to/from disk in
    `Arrow <https://huggingface.co/docs/datasets/about_arrow>`_ format
    using ``save_to_disk`` / ``load_from_disk``.

    Saving ``IterableDataset`` or ``IterableDatasetDict`` objects is not
    supported and will raise a ``DatasetError``. Materialize the iterable
    dataset into a ``Dataset`` or ``DatasetDict`` before saving.

    Examples:
        Using the
        [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/):

        ```yaml
        reviews:
          type: huggingface.ArrowDataset
          path: data/01_raw/reviews
        ```

        Using the
        [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/):

        >>> from datasets import Dataset
        >>> from kedro_datasets.huggingface.arrow_dataset import (
        ...     ArrowDataset,
        ... )
        >>>
        >>> data = Dataset.from_dict(
        ...     {"col1": [1, 2, 3], "col2": ["a", "b", "c"]}
        ... )
        >>>
        >>> dataset = ArrowDataset(
        ...     path=tmp_path / "test_hf_dataset"
        ... )
        >>> dataset.save(data)
        >>> reloaded = dataset.load()
        >>> assert reloaded.to_dict() == data.to_dict()
    """

    BUILDER: ClassVar[str] = "arrow"
    EXTENSION: ClassVar[str] = ".arrow"

    def _load_dataset(self, load_path: str) -> DatasetLike:
        # Arrow is the native on-disk layout, so the generic builder path
        # is bypassed in favour of ``load_from_disk``.
        return load_from_disk(
            load_path,
            storage_options=self._storage_options,
            **self._load_args,
        )

    def _write_to_disk(self, data: Dataset | DatasetDict, save_path: str) -> None:
        # Shared writer: ``save_to_disk`` handles both a single ``Dataset``
        # and a ``DatasetDict`` with identical call signatures.
        data.save_to_disk(
            save_path,
            storage_options=self._storage_options,
            **self._save_args,
        )

    def _save_dataset(self, data: Dataset, save_path: str) -> None:
        self._write_to_disk(data, save_path)

    def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None:
        self._write_to_disk(data, save_path)

    def _exists(self) -> bool:
        try:
            load_path = get_filepath_str(self._get_load_path(), self._protocol)
        except DatasetError:
            return False

        # ``save_to_disk`` always produces a directory containing one of
        # these marker files; anything else is not a valid Arrow dataset.
        if not self._fs.isdir(load_path):
            return False
        markers = ("dataset_dict.json", "dataset_info.json")
        return any(self._fs.exists(f"{load_path}/{name}") for name in markers)
Uh oh!
There was an error while loading. Please reload this page.