From acf3c3a1a087af748824f757402c2d9fd8b9072e Mon Sep 17 00:00:00 2001 From: iwhalen Date: Sun, 5 Apr 2026 14:51:50 -0500 Subject: [PATCH 01/21] Add huggingface.LocalHFDataset. Signed-off-by: iwhalen --- kedro-datasets/RELEASE.md | 11 + .../kedro_datasets/huggingface/__init__.py | 5 +- .../huggingface/hugging_face_dataset.py | 284 ++++++++++++++- .../huggingface/test_hugging_face_dataset.py | 338 ++++++++++++++++++ 4 files changed, 633 insertions(+), 5 deletions(-) diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index 3be87e2f8..bb2b7137e 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -44,6 +44,17 @@ Many thanks to the following Kedroids for contributing PRs to this release: ## Major features and improvements +- Added `huggingface.LocalHFDataset` to handle saving and loading from Hugging Face datasets on a filesystem. + +## Bug fixes and other changes +## Community contributions + +[iwhalen](https://github.com/iwhalen) + +# Release 9.3.0 + +## Major features and improvements + - Kedro-Datasets is now compatible with pandas 3.0. - Added `ibis-materialize` and `ibis-singlestoredb` extras for the backends added in Ibis 12.0. - Added "upsert" save mode to `ibis.TableDataset` (available on backends that support `MERGE INTO` since Ibis 12.0). diff --git a/kedro-datasets/kedro_datasets/huggingface/__init__.py b/kedro-datasets/kedro_datasets/huggingface/__init__.py index be777ccec..b9f906c46 100644 --- a/kedro-datasets/kedro_datasets/huggingface/__init__.py +++ b/kedro-datasets/kedro_datasets/huggingface/__init__.py @@ -5,11 +5,12 @@ import lazy_loader as lazy try: - from .hugging_face_dataset import HFDataset + from .hugging_face_dataset import HFDataset, LocalHFDataset except (ImportError, RuntimeError): # For documentation builds that might fail due to dependency issues # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 HFDataset: Any + LocalHFDataset: Any try: from .transformer_pipeline_dataset import HFTransformerPipelineDataset @@ -21,7 +22,7 @@ __getattr__, __dir__, __all__ = lazy.attach( __name__, submod_attrs={ - "hugging_face_dataset": ["HFDataset"], + "hugging_face_dataset": ["HFDataset", "LocalHFDataset"], "transformer_pipeline_dataset": ["HFTransformerPipelineDataset"], }, ) diff --git a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py index 7a451dba0..737697060 100644 --- a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py @@ -1,13 +1,33 @@ from __future__ import annotations +import os +from copy import deepcopy +from pathlib import PurePosixPath from typing import Any -from datasets import load_dataset +import fsspec +from datasets import ( + Dataset, + DatasetDict, + IterableDataset, + IterableDatasetDict, + load_dataset, + load_from_disk, +) from huggingface_hub import HfApi from kedro.io import AbstractDataset +from kedro.io.core import ( + AbstractVersionedDataset, + DatasetError, + Version, + get_filepath_str, + get_protocol_and_path, +) +DatasetLike = Dataset | DatasetDict | IterableDataset | IterableDatasetDict -class HFDataset(AbstractDataset): + +class HFDataset(AbstractDataset[None, DatasetLike]): """``HFDataset`` loads Hugging Face datasets using the `datasets `_ library. 
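 
     A minimal catalog entry might look like the following (the entry and
     dataset names here are purely illustrative; ``dataset_name`` and
     ``dataset_kwargs`` are the arguments accepted by ``__init__`` below):
 
     ```yaml
     yelp_reviews:
       type: huggingface.HFDataset
       dataset_name: yelp_review_full
     ```
 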
@@ -45,7 +65,7 @@ def __init__( self._dataset_kwargs = dataset_kwargs or {} self.metadata = metadata - def load(self): + def load(self) -> DatasetLike: # TODO: Replace suppression with the solution from here: https://github.com/kedro-org/kedro-plugins/issues/1131 return load_dataset(self.dataset_name, **self._dataset_kwargs) # nosec @@ -65,3 +85,261 @@ def _describe(self) -> dict[str, Any]: def list_datasets(): api = HfApi() return list(api.list_datasets()) + + +class LocalHFDataset(AbstractVersionedDataset[DatasetLike, DatasetLike]): + """``LocalHFDataset`` loads/saves Hugging Face ``Dataset``, + ``DatasetDict``, ``IterableDataset``, and + ``IterableDatasetDict`` objects to/from disk using an + underlying filesystem (e.g.: local, S3, GCS). Iterable + variants are materialized before saving. + + Examples: + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) + with a ``datasets.Dataset`` + + ```yaml + reviews: + type: huggingface.LocalHFDataset + path: data/01_raw/reviews.arrow + ``` + + By default, data will be loaded and saved from + (Arrow)[https://huggingface.co/docs/datasets/about_arrow] format. + + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) + with a ``datasets.DatasetDict`` in JSON format: + + ```yaml + review_dict: + type: huggingface.LocalHFDataset + path: data/01_raw/review_dict/ + file_format: json + ``` + + This saves each individual ``datasets.Dataset`` into separate files + in the directory in JSON format. + + The ``file_format`` accepts the following arguments: + + - ``arrow`` + - ``parquet`` + - ``json`` + - ``csv`` + - ``lance`` + - ``hdf5`` + + For more on saving and loading from a filesystem with the Datasets + library, see + [here](https://huggingface.co/docs/datasets/v4.8.4/en/loading#local-and-remote-files). + + Using the + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) + with a ``datasets.Dataset``: + + >>> from datasets import Dataset + >>> from kedro_datasets.huggingface.hugging_face_dataset import ( + ... LocalHFDataset, + ... ) + >>> + >>> data = Dataset.from_dict( + ... {"col1": [1, 2, 3], "col2": ["a", "b", "c"]} + ... ) + >>> + >>> dataset = LocalHFDataset( + ... path=tmp_path / "test_hf_dataset.arrow" + ... ) + >>> dataset.save(data) + >>> reloaded = dataset.load() + >>> assert reloaded.to_dict() == data.to_dict() + + Using the + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) + with a ``datasets.DatasetDict``: + + >>> from datasets import Dataset, DatasetDict + >>> from kedro_datasets.huggingface.hugging_face_dataset import ( + ... LocalHFDataset, + ... ) + >>> + >>> data = DatasetDict({ + ... "train": Dataset.from_dict( + ... {"col1": [1, 2], "col2": ["a", "b"]} + ... ), + ... "test": Dataset.from_dict( + ... {"col1": [3], "col2": ["c"]} + ... ), + ... }) + >>> + >>> dataset = LocalHFDataset( + ... path=tmp_path / "test_hf_dataset_dict" + ... 
) + >>> dataset.save(data) + >>> reloaded = dataset.load() + >>> assert list(reloaded.keys()) == ["train", "test"] + + """ + + _SUPPORTED_FORMATS = {"arrow", "parquet", "json", "csv", "lance", "hdf5"} + _FORMAT_EXTENSIONS = { + "arrow": ".arrow", + "parquet": ".parquet", + "json": ".json", + "csv": ".csv", + "lance": ".lance", + "hdf5": ".h5", + } + + def __init__( # noqa: PLR0913 + self, + *, + path: str | os.PathLike, + file_format: str = "arrow", + version: Version | None = None, + load_args: dict[str, Any] | None = None, + save_args: dict[str, Any] | None = None, + credentials: dict[str, Any] | None = None, + fs_args: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + if file_format not in self._SUPPORTED_FORMATS: + msg = ( + f"Unsupported file_format '{file_format}'. " + f"Must be one of {sorted(self._SUPPORTED_FORMATS)}." + ) + raise DatasetError(msg) + + self._file_format = file_format + _fs_args = deepcopy(fs_args) or {} + _credentials = deepcopy(credentials) or {} + + protocol, resolved_path = get_protocol_and_path(path, version) + self._protocol = protocol + + if protocol == "file": + _fs_args.setdefault("auto_mkdir", True) + + self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) + + self._load_args = load_args or {} + self._save_args = save_args or {} + self.metadata = metadata + + # storage_options passed to HF's load/save methods + self._storage_options = {**_credentials, **_fs_args} or None + + super().__init__( + filepath=PurePosixPath(resolved_path), + version=version, + exists_function=self._fs.exists, + glob_function=self._fs.glob, + ) + + def _load(self) -> DatasetLike: + load_path = get_filepath_str(self._get_load_path(), self._protocol) + + if self._file_format == "arrow": + return load_from_disk( + load_path, + storage_options=self._storage_options, + **self._load_args, + ) + + ext = self._FORMAT_EXTENSIONS[self._file_format] + loader = getattr(Dataset, f"from_{self._file_format}") + + if self._fs.isdir(load_path): + paths = { + PurePosixPath(p).stem: p for p in self._fs.glob(f"{load_path}/*{ext}") + } + return DatasetDict( + { + split: loader(path, **self._load_args) + for split, path in paths.items() + } + ) + + return loader(load_path, **self._load_args) + + def _save(self, data: DatasetLike) -> None: + if not isinstance( + data, + Dataset | DatasetDict | IterableDataset | IterableDatasetDict, + ): + msg = ( + "LocalHFDataset only supports `datasets.Dataset`, " + "`datasets.DatasetDict`, " + "`datasets.IterableDataset`, and " + "`datasets.IterableDatasetDict` instances. 
" + f"Got {type(data)}" + ) + raise DatasetError(msg) + + if isinstance(data, IterableDatasetDict): + data = DatasetDict({k: Dataset.from_list(list(v)) for k, v in data.items()}) + elif isinstance(data, IterableDataset): + data = Dataset.from_list(list(data)) + + save_path = get_filepath_str(self._get_save_path(), self._protocol) + + if self._file_format == "arrow": + data.save_to_disk( + save_path, + storage_options=self._storage_options, + **self._save_args, + ) + elif isinstance(data, DatasetDict): + self._fs.mkdirs(save_path, exist_ok=True) + ext = self._FORMAT_EXTENSIONS[self._file_format] + saver = f"to_{self._file_format}" + for split, split_ds in data.items(): + split_path = f"{save_path}/{split}{ext}" + getattr(split_ds, saver)( + split_path, + storage_options=self._storage_options, + **self._save_args, + ) + else: + saver = f"to_{self._file_format}" + getattr(data, saver)( + save_path, + storage_options=self._storage_options, + **self._save_args, + ) + + self._invalidate_cache() + + def _describe(self) -> dict[str, Any]: + return { + "path": self._filepath, + "file_format": self._file_format, + "protocol": self._protocol, + "version": self._version, + "load_args": self._load_args, + "save_args": self._save_args, + } + + def _exists(self) -> bool: + try: + load_path = get_filepath_str(self._get_load_path(), self._protocol) + except DatasetError: + return False + + if self._file_format == "arrow": + return self._fs.isdir(load_path) and ( + self._fs.exists(f"{load_path}/dataset_dict.json") + or self._fs.exists(f"{load_path}/dataset_info.json") + ) + + return self._fs.exists(load_path) + + def _release(self) -> None: + super()._release() + self._invalidate_cache() + + def _invalidate_cache(self) -> None: + """Invalidate underlying filesystem caches.""" + path = get_filepath_str(self._filepath, self._protocol) + self._fs.invalidate_cache(path) diff --git a/kedro-datasets/tests/huggingface/test_hugging_face_dataset.py b/kedro-datasets/tests/huggingface/test_hugging_face_dataset.py index 909362ec2..66babe0a9 100644 --- a/kedro-datasets/tests/huggingface/test_hugging_face_dataset.py +++ b/kedro-datasets/tests/huggingface/test_hugging_face_dataset.py @@ -1,7 +1,16 @@ +from pathlib import PurePosixPath + import pytest +from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict +from fsspec.implementations.http import HTTPFileSystem +from fsspec.implementations.local import LocalFileSystem +from gcsfs import GCSFileSystem from huggingface_hub import HfApi +from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version +from s3fs.core import S3FileSystem from kedro_datasets.huggingface import HFDataset +from kedro_datasets.huggingface.hugging_face_dataset import LocalHFDataset @pytest.fixture @@ -31,3 +40,332 @@ def test_list_datasets(self, mocker): datasets = HFDataset.list_datasets() assert datasets == expected_datasets + + +@pytest.fixture +def path_local_hf(tmp_path): + return (tmp_path / "test_hf_dataset").as_posix() + + +@pytest.fixture +def local_hf_dataset(path_local_hf, save_args, load_args, fs_args): + return LocalHFDataset( + path=path_local_hf, + save_args=save_args, + load_args=load_args, + fs_args=fs_args, + ) + + +@pytest.fixture +def versioned_local_hf_dataset(path_local_hf, load_version, save_version): + return LocalHFDataset( + path=path_local_hf, version=Version(load_version, save_version) + ) + + +@pytest.fixture +def dataset(): + return Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) + + +@pytest.fixture +def dataset_dict(): + 
return DatasetDict( + { + "train": Dataset.from_dict({"col1": [1, 2], "col2": ["a", "b"]}), + "test": Dataset.from_dict({"col1": [3], "col2": ["c"]}), + } + ) + + +@pytest.fixture +def iterable_dataset(): + return Dataset.from_dict( + {"col1": [1, 2, 3], "col2": ["a", "b", "c"]} + ).to_iterable_dataset() + + +@pytest.fixture +def iterable_dataset_dict(): + return IterableDatasetDict( + { + "train": Dataset.from_dict( + {"col1": [1, 2], "col2": ["a", "b"]} + ).to_iterable_dataset(), + "test": Dataset.from_dict( + {"col1": [3], "col2": ["c"]} + ).to_iterable_dataset(), + } + ) + + +@pytest.fixture +def parquet_local_hf_dataset(path_local_hf): + return LocalHFDataset(path=path_local_hf, file_format="parquet") + + +class TestLocalHFDataset: + def test_save_and_load_dataset(self, local_hf_dataset, dataset): + """Test saving and reloading a Dataset.""" + local_hf_dataset.save(dataset) + reloaded = local_hf_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_dataset_dict(self, local_hf_dataset, dataset_dict): + """Test saving and reloading a DatasetDict.""" + local_hf_dataset.save(dataset_dict) + reloaded = local_hf_dataset.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + for split in dataset_dict: + assert reloaded[split].to_dict() == dataset_dict[split].to_dict() + + def test_exists(self, local_hf_dataset, dataset): + """Test `exists` method for both existing and nonexistent dataset.""" + assert not local_hf_dataset.exists() + local_hf_dataset.save(dataset) + assert local_hf_dataset.exists() + + def test_exists_dataset_dict(self, local_hf_dataset, dataset_dict): + """Test `exists` method for DatasetDict (checks dataset_dict.json marker).""" + assert not local_hf_dataset.exists() + local_hf_dataset.save(dataset_dict) + assert local_hf_dataset.exists() + + def test_save_and_load_iterable_dataset(self, local_hf_dataset, iterable_dataset): + """Test saving an IterableDataset materializes and round-trips.""" + local_hf_dataset.save(iterable_dataset) + reloaded = local_hf_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == { + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + } + + def test_save_and_load_iterable_dataset_dict( + self, local_hf_dataset, iterable_dataset_dict + ): + """Test saving an IterableDatasetDict materializes and round-trips.""" + local_hf_dataset.save(iterable_dataset_dict) + reloaded = local_hf_dataset.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + assert reloaded["train"].to_dict() == { + "col1": [1, 2], + "col2": ["a", "b"], + } + + def test_save_and_load_parquet(self, parquet_local_hf_dataset, dataset): + """Test saving and reloading a Dataset as parquet.""" + parquet_local_hf_dataset.save(dataset) + reloaded = parquet_local_hf_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_parquet_dataset_dict( + self, parquet_local_hf_dataset, dataset_dict + ): + """Test saving and reloading a DatasetDict as parquet.""" + parquet_local_hf_dataset.save(dataset_dict) + reloaded = parquet_local_hf_dataset.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + for split in dataset_dict: + assert reloaded[split].to_dict() == dataset_dict[split].to_dict() + + def test_exists_parquet(self, parquet_local_hf_dataset, dataset): + """Test `exists` for 
parquet format.""" + assert not parquet_local_hf_dataset.exists() + parquet_local_hf_dataset.save(dataset) + assert parquet_local_hf_dataset.exists() + + def test_save_and_load_json_dataset_dict(self, tmp_path, dataset_dict): + """Test saving and reloading a DatasetDict as JSON.""" + path = (tmp_path / "test_json_dd").as_posix() + ds = LocalHFDataset(path=path, file_format="json") + ds.save(dataset_dict) + reloaded = ds.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + for split in dataset_dict: + assert reloaded[split].to_dict() == dataset_dict[split].to_dict() + + def test_invalid_file_format(self): + """Test that an unsupported file_format raises DatasetError.""" + pattern = r"Unsupported file_format" + with pytest.raises(DatasetError, match=pattern): + LocalHFDataset(path="test", file_format="xml") + + @pytest.mark.parametrize("save_args", [{"num_shards": 2}], indirect=True) + def test_save_extra_params(self, local_hf_dataset, save_args): + """Test overriding the default save arguments.""" + for key, value in save_args.items(): + assert local_hf_dataset._save_args[key] == value + + @pytest.mark.parametrize("load_args", [{"keep_in_memory": True}], indirect=True) + def test_load_extra_params(self, local_hf_dataset, load_args): + """Test overriding the default load arguments.""" + for key, value in load_args.items(): + assert local_hf_dataset._load_args[key] == value + + def test_load_missing_dataset(self, local_hf_dataset): + """Check the error when trying to load missing dataset.""" + pattern = r"Failed while loading data from dataset kedro_datasets.huggingface.hugging_face_dataset.LocalHFDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + local_hf_dataset.load() + + def test_save_invalid_type(self, local_hf_dataset): + """Check the error when saving an unsupported type.""" + pattern = r"LocalHFDataset only supports .datasets.Dataset., .datasets.DatasetDict., .datasets.IterableDataset., and .datasets.IterableDatasetDict. instances." 
+ with pytest.raises(DatasetError, match=pattern): + local_hf_dataset.save({"not": "a dataset"}) + + @pytest.mark.parametrize( + "path,instance_type", + [ + ("s3://bucket/hf_data", S3FileSystem), + ("file:///tmp/hf_data", LocalFileSystem), + ("/tmp/hf_data", LocalFileSystem), + ("gcs://bucket/hf_data", GCSFileSystem), + ("https://example.com/hf_data", HTTPFileSystem), + ], + ) + def test_protocol_usage(self, path, instance_type): + dataset = LocalHFDataset(path=path) + assert isinstance(dataset._fs, instance_type) + + resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] + + assert str(dataset._filepath) == resolved + assert isinstance(dataset._filepath, PurePosixPath) + + def test_pathlike_path(self, tmp_path, dataset): + """Test that os.PathLike paths are supported.""" + path = tmp_path / "test_hf_pathlike" + ds = LocalHFDataset(path=path) + ds.save(dataset) + reloaded = ds.load() + assert reloaded.to_dict() == dataset.to_dict() + + def test_catalog_release(self, mocker): + fs_mock = mocker.patch("fsspec.filesystem").return_value + path = "test_hf" + dataset = LocalHFDataset(path=path) + dataset.release() + fs_mock.invalidate_cache.assert_called_once_with(path) + + +class TestLocalHFDatasetVersioned: + def test_version_str_repr(self, load_version, save_version): + """Test that version is in string representation of the class instance + when applicable.""" + path = "test_hf_dataset" + ds = LocalHFDataset(path=path) + ds_versioned = LocalHFDataset( + path=path, version=Version(load_version, save_version) + ) + assert path in str(ds) + assert "version" not in str(ds) + + assert path in str(ds_versioned) + ver_str = f"version=Version(load={load_version}, save='{save_version}')" + assert ver_str in str(ds_versioned) + assert "LocalHFDataset" in str(ds_versioned) + assert "LocalHFDataset" in str(ds) + assert "protocol" in str(ds_versioned) + assert "protocol" in str(ds) + + def test_save_and_load(self, versioned_local_hf_dataset, dataset): + """Test that saved and reloaded data matches the original one for + the versioned dataset.""" + versioned_local_hf_dataset.save(dataset) + reloaded = versioned_local_hf_dataset.load() + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_dataset_dict(self, versioned_local_hf_dataset, dataset_dict): + """Test versioned save and reload with DatasetDict.""" + versioned_local_hf_dataset.save(dataset_dict) + reloaded = versioned_local_hf_dataset.load() + assert isinstance(reloaded, DatasetDict) + for split in dataset_dict: + assert reloaded[split].to_dict() == dataset_dict[split].to_dict() + + def test_save_and_load_iterable_dataset( + self, versioned_local_hf_dataset, iterable_dataset + ): + """Test versioned save with IterableDataset.""" + versioned_local_hf_dataset.save(iterable_dataset) + reloaded = versioned_local_hf_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == { + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + } + + def test_save_and_load_iterable_dataset_dict( + self, versioned_local_hf_dataset, iterable_dataset_dict + ): + """Test versioned save with IterableDatasetDict.""" + versioned_local_hf_dataset.save(iterable_dataset_dict) + reloaded = versioned_local_hf_dataset.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + + def test_no_versions(self, versioned_local_hf_dataset): + """Check the error if no versions are available for load.""" + pattern = r"Did not find any versions for 
kedro_datasets.huggingface.hugging_face_dataset.LocalHFDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_local_hf_dataset.load() + + def test_exists(self, versioned_local_hf_dataset, dataset): + """Test `exists` method invocation for versioned dataset.""" + assert not versioned_local_hf_dataset.exists() + versioned_local_hf_dataset.save(dataset) + assert versioned_local_hf_dataset.exists() + + def test_prevent_overwrite(self, versioned_local_hf_dataset, dataset): + """Check the error when attempting to override the dataset if the + corresponding version already exists.""" + versioned_local_hf_dataset.save(dataset) + pattern = ( + r"Save path \'.+\' for kedro_datasets.huggingface.hugging_face_dataset.LocalHFDataset\(.+\) must " + r"not exist if versioning is enabled\." + ) + with pytest.raises(DatasetError, match=pattern): + versioned_local_hf_dataset.save(dataset) + + @pytest.mark.parametrize( + "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True + ) + @pytest.mark.parametrize( + "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True + ) + def test_save_version_warning( + self, versioned_local_hf_dataset, load_version, save_version, dataset + ): + """Check the warning when saving to the path that differs from + the subsequent load path.""" + pattern = ( + f"Save version '{save_version}' did not match " + f"load version '{load_version}' for " + r"kedro_datasets.huggingface.hugging_face_dataset.LocalHFDataset\(.+\)" + ) + with pytest.warns(UserWarning, match=pattern): + versioned_local_hf_dataset.save(dataset) + + def test_http_filesystem_no_versioning(self): + pattern = "Versioning is not supported for HTTP protocols." + + with pytest.raises(DatasetError, match=pattern): + LocalHFDataset( + path="https://example.com/hf_data", + version=Version(None, None), + ) + + def test_save_invalid_type_versioned(self, versioned_local_hf_dataset): + """Check the error when saving an unsupported type through versioned dataset.""" + pattern = r"LocalHFDataset only supports .datasets.Dataset., .datasets.DatasetDict., .datasets.IterableDataset., and .datasets.IterableDatasetDict. instances." + with pytest.raises(DatasetError, match=pattern): + versioned_local_hf_dataset.save("not a dataset") From 382ca46c939709dccee3230315cf4297eb9e1b2f Mon Sep 17 00:00:00 2001 From: iwhalen Date: Sun, 5 Apr 2026 15:07:06 -0500 Subject: [PATCH 02/21] Add docs. Signed-off-by: iwhalen --- .../huggingface.LocalHFDataset.md | 9 +++++ .../huggingface/hugging_face_dataset.py | 12 ++----- kedro-datasets/mkdocs.yml | 6 ++-- .../jsonschema/kedro-catalog-1.0.0.json | 33 +++++++++++++++++++ 4 files changed, 49 insertions(+), 11 deletions(-) create mode 100644 kedro-datasets/docs/api/kedro_datasets/huggingface.LocalHFDataset.md diff --git a/kedro-datasets/docs/api/kedro_datasets/huggingface.LocalHFDataset.md b/kedro-datasets/docs/api/kedro_datasets/huggingface.LocalHFDataset.md new file mode 100644 index 000000000..8e466a5cb --- /dev/null +++ b/kedro-datasets/docs/api/kedro_datasets/huggingface.LocalHFDataset.md @@ -0,0 +1,9 @@ +# LocalHFDataset + +`LocalHFDataset` saves and loads Hugging Face datasets from a filesystem + using the `datasets` library. 
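+
+A minimal catalog entry might look like this (the entry name and path are
+illustrative; `file_format` defaults to `arrow`):
+
+```yaml
+reviews:
+  type: huggingface.LocalHFDataset
+  path: data/01_raw/reviews
+  file_format: parquet
+```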
+ +::: kedro_datasets.huggingface.LocalHFDataset + options: + members: true + show_source: true diff --git a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py index 737697060..01314e5a2 100644 --- a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py @@ -106,7 +106,7 @@ class LocalHFDataset(AbstractVersionedDataset[DatasetLike, DatasetLike]): ``` By default, data will be loaded and saved from - (Arrow)[https://huggingface.co/docs/datasets/about_arrow] format. + [Arrow](https://huggingface.co/docs/datasets/about_arrow) format. Using the [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) @@ -122,14 +122,8 @@ class LocalHFDataset(AbstractVersionedDataset[DatasetLike, DatasetLike]): This saves each individual ``datasets.Dataset`` into separate files in the directory in JSON format. - The ``file_format`` accepts the following arguments: - - - ``arrow`` - - ``parquet`` - - ``json`` - - ``csv`` - - ``lance`` - - ``hdf5`` + The ``file_format`` accepts `arrow`, `parquet`, `json`, `csv`, `lance`, + and `hdf5` as arguments. For more on saving and loading from a filesystem with the Datasets library, see diff --git a/kedro-datasets/mkdocs.yml b/kedro-datasets/mkdocs.yml index 1b74d5bfd..61a8fc3d0 100644 --- a/kedro-datasets/mkdocs.yml +++ b/kedro-datasets/mkdocs.yml @@ -140,8 +140,9 @@ plugins: Machine Learning and AI: - api/kedro_datasets/tensorflow.TensorFlowModelDataset.md: TensorFlow model storage - - api/kedro_datasets/huggingface.HFDataset.md: HuggingFace datasets integration - - api/kedro_datasets/huggingface.HFTransformerPipelineDataset.md: HuggingFace transformer pipelines + - api/kedro_datasets/huggingface.HFDataset.md: Hugging Face remote datasets integration + - api/kedro_datasets/huggingface.HFTransformerPipelineDataset.md: Hugging Face transformer pipelines + - api/kedro_datasets/huggingface.LocalHFDataset.md: Hugging Face local datasets integration Visualization and Plotting: - api/kedro_datasets/matplotlib.MatplotlibDataset.md: Matplotlib figure storage @@ -266,6 +267,7 @@ nav: - Huggingface: - huggingface.HFDataset: api/kedro_datasets/huggingface.HFDataset.md - huggingface.HFTransformerPipelineDataset: api/kedro_datasets/huggingface.HFTransformerPipelineDataset.md + - huggingface.LocalHFDataset: api/kedro_datasets/huggingface.LocalHFDataset.md - Ibis: - ibis.FileDataset: api/kedro_datasets/ibis.FileDataset.md - ibis.TableDataset: api/kedro_datasets/ibis.TableDataset.md diff --git a/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json b/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json index 74cbe3dfa..315ef6629 100644 --- a/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json +++ b/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json @@ -19,6 +19,7 @@ "holoviews.HoloviewsWriter", "huggingface.HFDataset", "huggingface.HFTransformerPipelineDataset", + "huggingface.LocalHFDataset", "ibis.FileDataset", "ibis.TableDataset", "json.JSONDataset", @@ -451,6 +452,38 @@ } } }, + { + "if": { + "properties": { + "type": { + "const": "huggingface.LocalHFDataset" + } + } + }, + "then": { + "required": ["path"], + "properties": { + "path": { + "type": "string", + "description": "Path to a directory or file for persisting Hugging Face datasets. Supports local and remote filesystems (e.g. s3://)." 
+ }, + "file_format": { + "type": "string", + "enum": ["arrow", "parquet", "json", "csv", "lance", "hdf5"], + "default": "arrow", + "description": "The file format to use for saving and loading. Defaults to 'arrow'." + }, + "load_args": { + "type": "object", + "description": "Additional arguments passed to the load method." + }, + "save_args": { + "type": "object", + "description": "Additional arguments passed to the save method." + } + } + } + }, { "if": { "properties": { From 63eaa540ec050600d15a467e5bee285184f67d47 Mon Sep 17 00:00:00 2001 From: iwhalen Date: Sun, 5 Apr 2026 15:23:49 -0500 Subject: [PATCH 03/21] Add TypeAlias to DatasetLike. Signed-off-by: iwhalen --- .../kedro_datasets/huggingface/hugging_face_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py index 01314e5a2..b56e55b2c 100644 --- a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py @@ -3,7 +3,7 @@ import os from copy import deepcopy from pathlib import PurePosixPath -from typing import Any +from typing import Any, TypeAlias import fsspec from datasets import ( @@ -24,7 +24,7 @@ get_protocol_and_path, ) -DatasetLike = Dataset | DatasetDict | IterableDataset | IterableDatasetDict +DatasetLike: TypeAlias = Dataset | DatasetDict | IterableDataset | IterableDatasetDict class HFDataset(AbstractDataset[None, DatasetLike]): From 272a7ac45fafce6a49ccaf616ce94c893fccb062 Mon Sep 17 00:00:00 2001 From: iwhalen Date: Sat, 11 Apr 2026 11:24:27 -0500 Subject: [PATCH 04/21] Break new HF datasets into multiple files, address PR comments. Signed-off-by: iwhalen --- kedro-datasets/RELEASE.md | 6 +- .../kedro_datasets/huggingface/__init__.py | 41 ++- .../kedro_datasets/huggingface/_base.py | 203 +++++++++++ .../huggingface/arrow_dataset.py | 83 +++++ .../kedro_datasets/huggingface/csv_dataset.py | 46 +++ .../huggingface/hdf5_dataset.py | 39 ++ .../huggingface/hugging_face_dataset.py | 264 -------------- .../huggingface/json_dataset.py | 46 +++ .../huggingface/lance_dataset.py | 39 ++ .../huggingface/parquet_dataset.py | 47 +++ .../jsonschema/kedro-catalog-1.0.0.json | 22 +- kedro-datasets/tests/huggingface/conftest.py | 39 ++ .../tests/huggingface/test_arrow_dataset.py | 256 +++++++++++++ .../tests/huggingface/test_csv_dataset.py | 182 ++++++++++ .../tests/huggingface/test_hdf5_dataset.py | 74 ++++ .../huggingface/test_hugging_face_dataset.py | 338 ------------------ .../tests/huggingface/test_json_dataset.py | 182 ++++++++++ .../tests/huggingface/test_lance_dataset.py | 74 ++++ .../tests/huggingface/test_parquet_dataset.py | 184 ++++++++++ 19 files changed, 1550 insertions(+), 615 deletions(-) create mode 100644 kedro-datasets/kedro_datasets/huggingface/_base.py create mode 100644 kedro-datasets/kedro_datasets/huggingface/arrow_dataset.py create mode 100644 kedro-datasets/kedro_datasets/huggingface/csv_dataset.py create mode 100644 kedro-datasets/kedro_datasets/huggingface/hdf5_dataset.py create mode 100644 kedro-datasets/kedro_datasets/huggingface/json_dataset.py create mode 100644 kedro-datasets/kedro_datasets/huggingface/lance_dataset.py create mode 100644 kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py create mode 100644 kedro-datasets/tests/huggingface/test_arrow_dataset.py create mode 100644 kedro-datasets/tests/huggingface/test_csv_dataset.py create mode 100644 
kedro-datasets/tests/huggingface/test_hdf5_dataset.py create mode 100644 kedro-datasets/tests/huggingface/test_json_dataset.py create mode 100644 kedro-datasets/tests/huggingface/test_lance_dataset.py create mode 100644 kedro-datasets/tests/huggingface/test_parquet_dataset.py diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index bb2b7137e..54e3d75f7 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -44,12 +44,14 @@ Many thanks to the following Kedroids for contributing PRs to this release: ## Major features and improvements -- Added `huggingface.LocalHFDataset` to handle saving and loading from Hugging Face datasets on a filesystem. +- Replaced `huggingface.LocalHFDataset` with per-format classes: `ArrowDataset`, `ParquetDataset`, `JSONDataset`, `CSVDataset`, `LanceDataset`, `HDF5Dataset`. ## Bug fixes and other changes ## Community contributions -[iwhalen](https://github.com/iwhalen) +Many thanks to the following Kedroids for contributing PRs to this release: + +- [iwhalen](https://github.com/iwhalen) # Release 9.3.0 diff --git a/kedro-datasets/kedro_datasets/huggingface/__init__.py b/kedro-datasets/kedro_datasets/huggingface/__init__.py index b9f906c46..f2327b485 100644 --- a/kedro-datasets/kedro_datasets/huggingface/__init__.py +++ b/kedro-datasets/kedro_datasets/huggingface/__init__.py @@ -5,12 +5,41 @@ import lazy_loader as lazy try: - from .hugging_face_dataset import HFDataset, LocalHFDataset + from .hugging_face_dataset import HFDataset except (ImportError, RuntimeError): # For documentation builds that might fail due to dependency issues # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 HFDataset: Any - LocalHFDataset: Any + +try: + from .arrow_dataset import ArrowDataset +except (ImportError, RuntimeError): + ArrowDataset: Any + +try: + from .parquet_dataset import ParquetDataset +except (ImportError, RuntimeError): + ParquetDataset: Any + +try: + from .json_dataset import JSONDataset +except (ImportError, RuntimeError): + JSONDataset: Any + +try: + from .csv_dataset import CSVDataset +except (ImportError, RuntimeError): + CSVDataset: Any + +try: + from .lance_dataset import LanceDataset +except (ImportError, RuntimeError): + LanceDataset: Any + +try: + from .hdf5_dataset import HDF5Dataset +except (ImportError, RuntimeError): + HDF5Dataset: Any try: from .transformer_pipeline_dataset import HFTransformerPipelineDataset @@ -22,7 +51,13 @@ __getattr__, __dir__, __all__ = lazy.attach( __name__, submod_attrs={ - "hugging_face_dataset": ["HFDataset", "LocalHFDataset"], + "hugging_face_dataset": ["HFDataset"], + "arrow_dataset": ["ArrowDataset"], + "parquet_dataset": ["ParquetDataset"], + "json_dataset": ["JSONDataset"], + "csv_dataset": ["CSVDataset"], + "lance_dataset": ["LanceDataset"], + "hdf5_dataset": ["HDF5Dataset"], "transformer_pipeline_dataset": ["HFTransformerPipelineDataset"], }, ) diff --git a/kedro-datasets/kedro_datasets/huggingface/_base.py b/kedro-datasets/kedro_datasets/huggingface/_base.py new file mode 100644 index 000000000..59877084b --- /dev/null +++ b/kedro-datasets/kedro_datasets/huggingface/_base.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import os +from copy import deepcopy +from pathlib import PurePosixPath +from typing import Any, ClassVar, TypeAlias + +import fsspec +from datasets import ( + Dataset, + DatasetDict, + IterableDataset, + IterableDatasetDict, + load_dataset, +) +from kedro.io.core import ( + AbstractVersionedDataset, + DatasetError, + Version, + 
get_filepath_str, + get_protocol_and_path, +) + +DatasetLike: TypeAlias = Dataset | DatasetDict | IterableDataset | IterableDatasetDict + + +class FilesystemDataset(AbstractVersionedDataset[DatasetLike, DatasetLike]): + """Base class for Hugging Face dataset types persisted on a filesystem. + + Not intended for direct use — use a format-specific subclass instead + (e.g. ``ArrowDataset``, ``ParquetDataset``). + """ + + BUILDER: ClassVar[str] + EXTENSION: ClassVar[str] + + def __init__( # noqa: PLR0913 + self, + *, + path: str | os.PathLike, + version: Version | None = None, + load_args: dict[str, Any] | None = None, + save_args: dict[str, Any] | None = None, + credentials: dict[str, Any] | None = None, + fs_args: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + """Creates a new instance of ``FilesystemDataset``. + + Args: + path: Path to a file or directory for persisting Hugging Face + datasets. Supports local paths, ``os.PathLike`` objects, + and remote URIs (e.g. ``s3://bucket/data``). + version: Optional versioning configuration + (see :class:`~kedro.io.core.Version`). + load_args: Additional keyword arguments passed to the + underlying load function. + save_args: Additional keyword arguments passed to the + underlying save function. + credentials: Credentials for the underlying filesystem + (e.g. ``key``/``secret`` for S3). Passed to the + ``storage_options`` parameter in the underlying + ``datasets`` implementation. + fs_args: Extra arguments passed to the ``fsspec`` filesystem + initialiser. Passed to the ``storage_options`` parameter + in the underlying ``datasets`` implementation. + metadata: Any arbitrary metadata. This is ignored by Kedro + but may be consumed by users or external plugins. + """ + _fs_args = deepcopy(fs_args) or {} + _credentials = deepcopy(credentials) or {} + + protocol, resolved_path = get_protocol_and_path(path, version) + self._protocol = protocol + + if protocol == "file": + _fs_args.setdefault("auto_mkdir", True) + + self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) + + self._load_args = load_args or {} + self._save_args = save_args or {} + self.metadata = metadata + + self._storage_options = {**_credentials, **_fs_args} or None + + super().__init__( + filepath=PurePosixPath(resolved_path), + version=version, + exists_function=self._fs.exists, + glob_function=self._fs.glob, + ) + + def load(self) -> DatasetLike: + load_path = get_filepath_str(self._get_load_path(), self._protocol) + return self._load_dataset(load_path) + + def save(self, data: DatasetLike) -> None: + if not isinstance( + data, + Dataset | DatasetDict | IterableDataset | IterableDatasetDict, + ): + msg = ( + f"{type(self).__name__} only supports `datasets.Dataset`, " + "`datasets.DatasetDict`, " + "`datasets.IterableDataset`, and " + "`datasets.IterableDatasetDict` instances. 
" + f"Got {type(data)}" + ) + raise DatasetError(msg) + + if isinstance(data, IterableDatasetDict): + data = DatasetDict({k: Dataset.from_list(list(v)) for k, v in data.items()}) + elif isinstance(data, IterableDataset): + data = Dataset.from_list(list(data)) + + save_path = get_filepath_str(self._get_save_path(), self._protocol) + + if isinstance(data, DatasetDict): + self._save_dataset_dict(data, save_path) + else: + self._save_dataset(data, save_path) + + self._invalidate_cache() + + def _load_dataset(self, load_path: str) -> DatasetLike: + if self._fs.isdir(load_path): + ext = self.EXTENSION + data_files = { + PurePosixPath(p).stem: p for p in self._fs.glob(f"{load_path}/*{ext}") + } + # Note: nosec is fine here since we're always loading from a filesystem. + # Bandit throws an exception because it wants a revision number, + # which is only relevatn when loading from the Hub. + return load_dataset( + self.BUILDER, data_files=data_files, **self._load_args + ) # nosec + + result = load_dataset( # nosec + self.BUILDER, + data_files=load_path, + storage_options=self._storage_options, + **self._load_args, + ) + + # load_dataset wraps a single file in a DatasetDict with one + # split (typically "train"). When the caller didn't ask for a + # specific split, unwrap it so a single file round-trips as a + # Dataset, not a DatasetDict. + if ( + "split" not in self._load_args + and isinstance(result, DatasetDict) + and len(result) == 1 + ): + return next(iter(result.values())) + + return result + + def _save_dataset(self, data: Dataset, save_path: str) -> None: + saver = f"to_{self.BUILDER}" + getattr(data, saver)( + save_path, + storage_options=self._storage_options, + **self._save_args, + ) + + def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None: + self._fs.mkdirs(save_path, exist_ok=True) + ext = self.EXTENSION + saver = f"to_{self.BUILDER}" + for split, split_ds in data.items(): + split_path = f"{save_path}/{split}{ext}" + getattr(split_ds, saver)( + split_path, + storage_options=self._storage_options, + **self._save_args, + ) + + def _describe(self) -> dict[str, Any]: + return { + "path": self._filepath, + "file_format": self.BUILDER, + "protocol": self._protocol, + "version": self._version, + "load_args": self._load_args, + "save_args": self._save_args, + } + + def _exists(self) -> bool: + try: + load_path = get_filepath_str(self._get_load_path(), self._protocol) + except DatasetError: + return False + return self._fs.exists(load_path) + + def _release(self) -> None: + super()._release() + self._invalidate_cache() + + def _invalidate_cache(self) -> None: + """Invalidate underlying filesystem caches.""" + path = get_filepath_str(self._filepath, self._protocol) + self._fs.invalidate_cache(path) diff --git a/kedro-datasets/kedro_datasets/huggingface/arrow_dataset.py b/kedro-datasets/kedro_datasets/huggingface/arrow_dataset.py new file mode 100644 index 000000000..c040b38e0 --- /dev/null +++ b/kedro-datasets/kedro_datasets/huggingface/arrow_dataset.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Any, ClassVar + +from datasets import Dataset, DatasetDict, load_from_disk +from kedro.io.core import DatasetError, get_filepath_str + +from ._base import DatasetLike, FilesystemDataset + + +class ArrowDataset(FilesystemDataset): + """``ArrowDataset`` loads/saves Hugging Face ``Dataset`` and + ``DatasetDict`` objects to/from disk in + `Arrow `_ format + using ``save_to_disk`` / ``load_from_disk``. 
+ + Iterable variants (``IterableDataset``, ``IterableDatasetDict``) + are materialised before saving. + + Examples: + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): + + ```yaml + reviews: + type: huggingface.ArrowDataset + path: data/01_raw/reviews + ``` + + Using the + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/): + + >>> from datasets import Dataset + >>> from kedro_datasets.huggingface.arrow_dataset import ( + ... ArrowDataset, + ... ) + >>> + >>> data = Dataset.from_dict( + ... {"col1": [1, 2, 3], "col2": ["a", "b", "c"]} + ... ) + >>> + >>> dataset = ArrowDataset( + ... path=tmp_path / "test_hf_dataset" + ... ) + >>> dataset.save(data) + >>> reloaded = dataset.load() + >>> assert reloaded.to_dict() == data.to_dict() + """ + + BUILDER: ClassVar[str] = "arrow" + EXTENSION: ClassVar[str] = ".arrow" + + def _load_dataset(self, load_path: str) -> DatasetLike: + return load_from_disk( + load_path, + storage_options=self._storage_options, + **self._load_args, + ) + + def _save_dataset(self, data: Dataset, save_path: str) -> None: + data.save_to_disk( + save_path, + storage_options=self._storage_options, + **self._save_args, + ) + + def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None: + data.save_to_disk( + save_path, + storage_options=self._storage_options, + **self._save_args, + ) + + def _exists(self) -> bool: + try: + load_path = get_filepath_str(self._get_load_path(), self._protocol) + except DatasetError: + return False + + return self._fs.isdir(load_path) and ( + self._fs.exists(f"{load_path}/dataset_dict.json") + or self._fs.exists(f"{load_path}/dataset_info.json") + ) diff --git a/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py b/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py new file mode 100644 index 000000000..fc77c2725 --- /dev/null +++ b/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import ClassVar + +from ._base import FilesystemDataset + + +class CSVDataset(FilesystemDataset): + """``CSVDataset`` loads/saves Hugging Face ``Dataset`` and + ``DatasetDict`` objects to/from CSV files. + + Iterable variants (``IterableDataset``, ``IterableDatasetDict``) + are materialised before saving. + + Examples: + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): + + ```yaml + reviews: + type: huggingface.CSVDataset + path: data/01_raw/reviews.csv + ``` + + Using the + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/): + + >>> from datasets import Dataset + >>> from kedro_datasets.huggingface.csv_dataset import ( + ... CSVDataset, + ... ) + >>> + >>> data = Dataset.from_dict( + ... {"col1": [1, 2, 3], "col2": ["a", "b", "c"]} + ... ) + >>> + >>> dataset = CSVDataset( + ... path=tmp_path / "test_hf_dataset.csv" + ... 
) + >>> dataset.save(data) + >>> reloaded = dataset.load() + >>> assert reloaded.to_dict() == data.to_dict() + """ + + BUILDER: ClassVar[str] = "csv" + EXTENSION: ClassVar[str] = ".csv" diff --git a/kedro-datasets/kedro_datasets/huggingface/hdf5_dataset.py b/kedro-datasets/kedro_datasets/huggingface/hdf5_dataset.py new file mode 100644 index 000000000..fb5b3efad --- /dev/null +++ b/kedro-datasets/kedro_datasets/huggingface/hdf5_dataset.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import ClassVar + +from datasets import Dataset, DatasetDict +from kedro.io.core import DatasetError + +from ._base import FilesystemDataset + + +class HDF5Dataset(FilesystemDataset): + """``HDF5Dataset`` loads Hugging Face ``Dataset`` and + ``DatasetDict`` objects from + `HDF5 `_ files. + + Saving is **not** supported because the ``datasets`` library does + not provide a ``to_hdf5`` method. + + Examples: + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): + + ```yaml + reviews: + type: huggingface.HDF5Dataset + path: data/01_raw/reviews.h5 + ``` + """ + + BUILDER: ClassVar[str] = "hdf5" + EXTENSION: ClassVar[str] = ".h5" + + def _save_dataset(self, data: Dataset, save_path: str) -> None: + msg = "Saving in hdf5 format is not supported by the Hugging Face datasets library." + raise DatasetError(msg) + + def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None: + msg = "Saving in hdf5 format is not supported by the Hugging Face datasets library." + raise DatasetError(msg) diff --git a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py index b56e55b2c..f7796830e 100644 --- a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py @@ -1,28 +1,16 @@ from __future__ import annotations -import os -from copy import deepcopy -from pathlib import PurePosixPath from typing import Any, TypeAlias -import fsspec from datasets import ( Dataset, DatasetDict, IterableDataset, IterableDatasetDict, load_dataset, - load_from_disk, ) from huggingface_hub import HfApi from kedro.io import AbstractDataset -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - get_filepath_str, - get_protocol_and_path, -) DatasetLike: TypeAlias = Dataset | DatasetDict | IterableDataset | IterableDatasetDict @@ -85,255 +73,3 @@ def _describe(self) -> dict[str, Any]: def list_datasets(): api = HfApi() return list(api.list_datasets()) - - -class LocalHFDataset(AbstractVersionedDataset[DatasetLike, DatasetLike]): - """``LocalHFDataset`` loads/saves Hugging Face ``Dataset``, - ``DatasetDict``, ``IterableDataset``, and - ``IterableDatasetDict`` objects to/from disk using an - underlying filesystem (e.g.: local, S3, GCS). Iterable - variants are materialized before saving. - - Examples: - Using the - [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) - with a ``datasets.Dataset`` - - ```yaml - reviews: - type: huggingface.LocalHFDataset - path: data/01_raw/reviews.arrow - ``` - - By default, data will be loaded and saved from - [Arrow](https://huggingface.co/docs/datasets/about_arrow) format. 
- - Using the - [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) - with a ``datasets.DatasetDict`` in JSON format: - - ```yaml - review_dict: - type: huggingface.LocalHFDataset - path: data/01_raw/review_dict/ - file_format: json - ``` - - This saves each individual ``datasets.Dataset`` into separate files - in the directory in JSON format. - - The ``file_format`` accepts `arrow`, `parquet`, `json`, `csv`, `lance`, - and `hdf5` as arguments. - - For more on saving and loading from a filesystem with the Datasets - library, see - [here](https://huggingface.co/docs/datasets/v4.8.4/en/loading#local-and-remote-files). - - Using the - [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) - with a ``datasets.Dataset``: - - >>> from datasets import Dataset - >>> from kedro_datasets.huggingface.hugging_face_dataset import ( - ... LocalHFDataset, - ... ) - >>> - >>> data = Dataset.from_dict( - ... {"col1": [1, 2, 3], "col2": ["a", "b", "c"]} - ... ) - >>> - >>> dataset = LocalHFDataset( - ... path=tmp_path / "test_hf_dataset.arrow" - ... ) - >>> dataset.save(data) - >>> reloaded = dataset.load() - >>> assert reloaded.to_dict() == data.to_dict() - - Using the - [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) - with a ``datasets.DatasetDict``: - - >>> from datasets import Dataset, DatasetDict - >>> from kedro_datasets.huggingface.hugging_face_dataset import ( - ... LocalHFDataset, - ... ) - >>> - >>> data = DatasetDict({ - ... "train": Dataset.from_dict( - ... {"col1": [1, 2], "col2": ["a", "b"]} - ... ), - ... "test": Dataset.from_dict( - ... {"col1": [3], "col2": ["c"]} - ... ), - ... }) - >>> - >>> dataset = LocalHFDataset( - ... path=tmp_path / "test_hf_dataset_dict" - ... ) - >>> dataset.save(data) - >>> reloaded = dataset.load() - >>> assert list(reloaded.keys()) == ["train", "test"] - - """ - - _SUPPORTED_FORMATS = {"arrow", "parquet", "json", "csv", "lance", "hdf5"} - _FORMAT_EXTENSIONS = { - "arrow": ".arrow", - "parquet": ".parquet", - "json": ".json", - "csv": ".csv", - "lance": ".lance", - "hdf5": ".h5", - } - - def __init__( # noqa: PLR0913 - self, - *, - path: str | os.PathLike, - file_format: str = "arrow", - version: Version | None = None, - load_args: dict[str, Any] | None = None, - save_args: dict[str, Any] | None = None, - credentials: dict[str, Any] | None = None, - fs_args: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - ) -> None: - if file_format not in self._SUPPORTED_FORMATS: - msg = ( - f"Unsupported file_format '{file_format}'. " - f"Must be one of {sorted(self._SUPPORTED_FORMATS)}." 
- ) - raise DatasetError(msg) - - self._file_format = file_format - _fs_args = deepcopy(fs_args) or {} - _credentials = deepcopy(credentials) or {} - - protocol, resolved_path = get_protocol_and_path(path, version) - self._protocol = protocol - - if protocol == "file": - _fs_args.setdefault("auto_mkdir", True) - - self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - - self._load_args = load_args or {} - self._save_args = save_args or {} - self.metadata = metadata - - # storage_options passed to HF's load/save methods - self._storage_options = {**_credentials, **_fs_args} or None - - super().__init__( - filepath=PurePosixPath(resolved_path), - version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, - ) - - def _load(self) -> DatasetLike: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - - if self._file_format == "arrow": - return load_from_disk( - load_path, - storage_options=self._storage_options, - **self._load_args, - ) - - ext = self._FORMAT_EXTENSIONS[self._file_format] - loader = getattr(Dataset, f"from_{self._file_format}") - - if self._fs.isdir(load_path): - paths = { - PurePosixPath(p).stem: p for p in self._fs.glob(f"{load_path}/*{ext}") - } - return DatasetDict( - { - split: loader(path, **self._load_args) - for split, path in paths.items() - } - ) - - return loader(load_path, **self._load_args) - - def _save(self, data: DatasetLike) -> None: - if not isinstance( - data, - Dataset | DatasetDict | IterableDataset | IterableDatasetDict, - ): - msg = ( - "LocalHFDataset only supports `datasets.Dataset`, " - "`datasets.DatasetDict`, " - "`datasets.IterableDataset`, and " - "`datasets.IterableDatasetDict` instances. " - f"Got {type(data)}" - ) - raise DatasetError(msg) - - if isinstance(data, IterableDatasetDict): - data = DatasetDict({k: Dataset.from_list(list(v)) for k, v in data.items()}) - elif isinstance(data, IterableDataset): - data = Dataset.from_list(list(data)) - - save_path = get_filepath_str(self._get_save_path(), self._protocol) - - if self._file_format == "arrow": - data.save_to_disk( - save_path, - storage_options=self._storage_options, - **self._save_args, - ) - elif isinstance(data, DatasetDict): - self._fs.mkdirs(save_path, exist_ok=True) - ext = self._FORMAT_EXTENSIONS[self._file_format] - saver = f"to_{self._file_format}" - for split, split_ds in data.items(): - split_path = f"{save_path}/{split}{ext}" - getattr(split_ds, saver)( - split_path, - storage_options=self._storage_options, - **self._save_args, - ) - else: - saver = f"to_{self._file_format}" - getattr(data, saver)( - save_path, - storage_options=self._storage_options, - **self._save_args, - ) - - self._invalidate_cache() - - def _describe(self) -> dict[str, Any]: - return { - "path": self._filepath, - "file_format": self._file_format, - "protocol": self._protocol, - "version": self._version, - "load_args": self._load_args, - "save_args": self._save_args, - } - - def _exists(self) -> bool: - try: - load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DatasetError: - return False - - if self._file_format == "arrow": - return self._fs.isdir(load_path) and ( - self._fs.exists(f"{load_path}/dataset_dict.json") - or self._fs.exists(f"{load_path}/dataset_info.json") - ) - - return self._fs.exists(load_path) - - def _release(self) -> None: - super()._release() - self._invalidate_cache() - - def _invalidate_cache(self) -> None: - """Invalidate underlying filesystem caches.""" - path = get_filepath_str(self._filepath, 
self._protocol) - self._fs.invalidate_cache(path) diff --git a/kedro-datasets/kedro_datasets/huggingface/json_dataset.py b/kedro-datasets/kedro_datasets/huggingface/json_dataset.py new file mode 100644 index 000000000..c3782d58a --- /dev/null +++ b/kedro-datasets/kedro_datasets/huggingface/json_dataset.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import ClassVar + +from ._base import FilesystemDataset + + +class JSONDataset(FilesystemDataset): + """``JSONDataset`` loads/saves Hugging Face ``Dataset`` and + ``DatasetDict`` objects to/from JSON files. + + Iterable variants (``IterableDataset``, ``IterableDatasetDict``) + are materialised before saving. + + Examples: + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): + + ```yaml + reviews: + type: huggingface.JSONDataset + path: data/01_raw/reviews.json + ``` + + Using the + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/): + + >>> from datasets import Dataset + >>> from kedro_datasets.huggingface.json_dataset import ( + ... JSONDataset, + ... ) + >>> + >>> data = Dataset.from_dict( + ... {"col1": [1, 2, 3], "col2": ["a", "b", "c"]} + ... ) + >>> + >>> dataset = JSONDataset( + ... path=tmp_path / "test_hf_dataset.json" + ... ) + >>> dataset.save(data) + >>> reloaded = dataset.load() + >>> assert reloaded.to_dict() == data.to_dict() + """ + + BUILDER: ClassVar[str] = "json" + EXTENSION: ClassVar[str] = ".json" diff --git a/kedro-datasets/kedro_datasets/huggingface/lance_dataset.py b/kedro-datasets/kedro_datasets/huggingface/lance_dataset.py new file mode 100644 index 000000000..c8cb6c139 --- /dev/null +++ b/kedro-datasets/kedro_datasets/huggingface/lance_dataset.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import ClassVar + +from datasets import Dataset, DatasetDict +from kedro.io.core import DatasetError + +from ._base import FilesystemDataset + + +class LanceDataset(FilesystemDataset): + """``LanceDataset`` loads Hugging Face ``Dataset`` and + ``DatasetDict`` objects from `Lance `_ + files. + + Saving is **not** supported because the ``datasets`` library does + not provide a ``to_lance`` method. + + Examples: + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): + + ```yaml + reviews: + type: huggingface.LanceDataset + path: data/01_raw/reviews.lance + ``` + """ + + BUILDER: ClassVar[str] = "lance" + EXTENSION: ClassVar[str] = ".lance" + + def _save_dataset(self, data: Dataset, save_path: str) -> None: + msg = "Saving in lance format is not supported by the Hugging Face datasets library." + raise DatasetError(msg) + + def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None: + msg = "Saving in lance format is not supported by the Hugging Face datasets library." + raise DatasetError(msg) diff --git a/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py b/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py new file mode 100644 index 000000000..9a7cf8921 --- /dev/null +++ b/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import ClassVar + +from ._base import FilesystemDataset + + +class ParquetDataset(FilesystemDataset): + """``ParquetDataset`` loads/saves Hugging Face ``Dataset`` and + ``DatasetDict`` objects to/from + `Parquet `_ files. + + Iterable variants (``IterableDataset``, ``IterableDatasetDict``) + are materialised before saving. 
+ + Examples: + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): + + ```yaml + reviews: + type: huggingface.ParquetDataset + path: data/01_raw/reviews.parquet + ``` + + Using the + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/): + + >>> from datasets import Dataset + >>> from kedro_datasets.huggingface.parquet_dataset import ( + ... ParquetDataset, + ... ) + >>> + >>> data = Dataset.from_dict( + ... {"col1": [1, 2, 3], "col2": ["a", "b", "c"]} + ... ) + >>> + >>> dataset = ParquetDataset( + ... path=tmp_path / "test_hf_dataset.parquet" + ... ) + >>> dataset.save(data) + >>> reloaded = dataset.load() + >>> assert reloaded.to_dict() == data.to_dict() + """ + + BUILDER: ClassVar[str] = "parquet" + EXTENSION: ClassVar[str] = ".parquet" diff --git a/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json b/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json index 315ef6629..519692e5c 100644 --- a/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json +++ b/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json @@ -17,9 +17,14 @@ "email.EmailMessageDataset", "geopandas.GenericDataset", "holoviews.HoloviewsWriter", + "huggingface.ArrowDataset", + "huggingface.CSVDataset", + "huggingface.HDF5Dataset", "huggingface.HFDataset", "huggingface.HFTransformerPipelineDataset", - "huggingface.LocalHFDataset", + "huggingface.JSONDataset", + "huggingface.LanceDataset", + "huggingface.ParquetDataset", "ibis.FileDataset", "ibis.TableDataset", "json.JSONDataset", @@ -456,7 +461,14 @@ "if": { "properties": { "type": { - "const": "huggingface.LocalHFDataset" + "enum": [ + "huggingface.ArrowDataset", + "huggingface.ParquetDataset", + "huggingface.JSONDataset", + "huggingface.CSVDataset", + "huggingface.LanceDataset", + "huggingface.HDF5Dataset" + ] } } }, @@ -467,12 +479,6 @@ "type": "string", "description": "Path to a directory or file for persisting Hugging Face datasets. Supports local and remote filesystems (e.g. s3://)." }, - "file_format": { - "type": "string", - "enum": ["arrow", "parquet", "json", "csv", "lance", "hdf5"], - "default": "arrow", - "description": "The file format to use for saving and loading. Defaults to 'arrow'." - }, "load_args": { "type": "object", "description": "Additional arguments passed to the load method." diff --git a/kedro-datasets/tests/huggingface/conftest.py b/kedro-datasets/tests/huggingface/conftest.py index 8630b0dbb..694f767b7 100644 --- a/kedro-datasets/tests/huggingface/conftest.py +++ b/kedro-datasets/tests/huggingface/conftest.py @@ -4,3 +4,42 @@ discover them automatically. 
More info here: https://docs.pytest.org/en/latest/fixture.html """ + +import pytest +from datasets import Dataset, DatasetDict, IterableDatasetDict + + +@pytest.fixture +def dataset(): + return Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) + + +@pytest.fixture +def dataset_dict(): + return DatasetDict( + { + "train": Dataset.from_dict({"col1": [1, 2], "col2": ["a", "b"]}), + "test": Dataset.from_dict({"col1": [3], "col2": ["c"]}), + } + ) + + +@pytest.fixture +def iterable_dataset(): + return Dataset.from_dict( + {"col1": [1, 2, 3], "col2": ["a", "b", "c"]} + ).to_iterable_dataset() + + +@pytest.fixture +def iterable_dataset_dict(): + return IterableDatasetDict( + { + "train": Dataset.from_dict( + {"col1": [1, 2], "col2": ["a", "b"]} + ).to_iterable_dataset(), + "test": Dataset.from_dict( + {"col1": [3], "col2": ["c"]} + ).to_iterable_dataset(), + } + ) diff --git a/kedro-datasets/tests/huggingface/test_arrow_dataset.py b/kedro-datasets/tests/huggingface/test_arrow_dataset.py new file mode 100644 index 000000000..4198e86b8 --- /dev/null +++ b/kedro-datasets/tests/huggingface/test_arrow_dataset.py @@ -0,0 +1,256 @@ +from pathlib import PurePosixPath + +import pytest +from datasets import Dataset, DatasetDict +from fsspec.implementations.http import HTTPFileSystem +from fsspec.implementations.local import LocalFileSystem +from gcsfs import GCSFileSystem +from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version +from s3fs.core import S3FileSystem + +from kedro_datasets.huggingface.arrow_dataset import ArrowDataset + + +@pytest.fixture +def path_arrow(tmp_path): + return (tmp_path / "test_hf_dataset").as_posix() + + +@pytest.fixture +def arrow_dataset(path_arrow, save_args, load_args, fs_args): + return ArrowDataset( + path=path_arrow, + save_args=save_args, + load_args=load_args, + fs_args=fs_args, + ) + + +@pytest.fixture +def versioned_arrow_dataset(path_arrow, load_version, save_version): + return ArrowDataset(path=path_arrow, version=Version(load_version, save_version)) + + +class TestArrowDataset: + def test_save_and_load_dataset(self, arrow_dataset, dataset): + """Test saving and reloading a Dataset.""" + arrow_dataset.save(dataset) + reloaded = arrow_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_dataset_dict(self, arrow_dataset, dataset_dict): + """Test saving and reloading a DatasetDict.""" + arrow_dataset.save(dataset_dict) + reloaded = arrow_dataset.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + for split in dataset_dict: + assert reloaded[split].to_dict() == dataset_dict[split].to_dict() + + def test_exists(self, arrow_dataset, dataset): + """Test `exists` method for both existing and nonexistent dataset.""" + assert not arrow_dataset.exists() + arrow_dataset.save(dataset) + assert arrow_dataset.exists() + + def test_exists_dataset_dict(self, arrow_dataset, dataset_dict): + """Test `exists` method for DatasetDict (checks dataset_dict.json marker).""" + assert not arrow_dataset.exists() + arrow_dataset.save(dataset_dict) + assert arrow_dataset.exists() + + def test_save_and_load_iterable_dataset(self, arrow_dataset, iterable_dataset): + """Test saving an IterableDataset materializes and round-trips.""" + arrow_dataset.save(iterable_dataset) + reloaded = arrow_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == { + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + } + + def 
test_save_and_load_iterable_dataset_dict( + self, arrow_dataset, iterable_dataset_dict + ): + """Test saving an IterableDatasetDict materializes and round-trips.""" + arrow_dataset.save(iterable_dataset_dict) + reloaded = arrow_dataset.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + assert reloaded["train"].to_dict() == { + "col1": [1, 2], + "col2": ["a", "b"], + } + + @pytest.mark.parametrize("save_args", [{"num_shards": 2}], indirect=True) + def test_save_extra_params(self, arrow_dataset, save_args): + """Test overriding the default save arguments.""" + for key, value in save_args.items(): + assert arrow_dataset._save_args[key] == value + + @pytest.mark.parametrize("load_args", [{"keep_in_memory": True}], indirect=True) + def test_load_extra_params(self, arrow_dataset, load_args): + """Test overriding the default load arguments.""" + for key, value in load_args.items(): + assert arrow_dataset._load_args[key] == value + + def test_load_missing_dataset(self, arrow_dataset): + """Check the error when trying to load missing dataset.""" + pattern = r"Failed while loading data from dataset kedro_datasets.huggingface.arrow_dataset.ArrowDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + arrow_dataset.load() + + def test_save_invalid_type(self, arrow_dataset): + """Check the error when saving an unsupported type.""" + pattern = r"ArrowDataset only supports .datasets.Dataset., .datasets.DatasetDict., .datasets.IterableDataset., and .datasets.IterableDatasetDict. instances." + with pytest.raises(DatasetError, match=pattern): + arrow_dataset.save({"not": "a dataset"}) + + @pytest.mark.parametrize( + "path,instance_type", + [ + ("s3://bucket/hf_data", S3FileSystem), + ("file:///tmp/hf_data", LocalFileSystem), + ("/tmp/hf_data", LocalFileSystem), + ("gcs://bucket/hf_data", GCSFileSystem), + ("https://example.com/hf_data", HTTPFileSystem), + ], + ) + def test_protocol_usage(self, path, instance_type): + dataset = ArrowDataset(path=path) + assert isinstance(dataset._fs, instance_type) + + resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] + + assert str(dataset._filepath) == resolved + assert isinstance(dataset._filepath, PurePosixPath) + + def test_pathlike_path(self, tmp_path, dataset): + """Test that os.PathLike paths are supported.""" + path = tmp_path / "test_hf_pathlike" + ds = ArrowDataset(path=path) + ds.save(dataset) + reloaded = ds.load() + assert reloaded.to_dict() == dataset.to_dict() + + def test_catalog_release(self, mocker): + fs_mock = mocker.patch("fsspec.filesystem").return_value + path = "test_hf" + dataset = ArrowDataset(path=path) + dataset.release() + fs_mock.invalidate_cache.assert_called_once_with(path) + + +class TestArrowDatasetVersioned: + def test_version_str_repr(self, load_version, save_version): + """Test that version is in string representation of the class instance + when applicable.""" + path = "test_hf_dataset" + ds = ArrowDataset(path=path) + ds_versioned = ArrowDataset( + path=path, version=Version(load_version, save_version) + ) + assert path in str(ds) + assert "version" not in str(ds) + + assert path in str(ds_versioned) + ver_str = f"version=Version(load={load_version}, save='{save_version}')" + assert ver_str in str(ds_versioned) + assert "ArrowDataset" in str(ds_versioned) + assert "ArrowDataset" in str(ds) + assert "protocol" in str(ds_versioned) + assert "protocol" in str(ds) + + def test_save_and_load(self, versioned_arrow_dataset, dataset): + """Test that saved and reloaded data matches 
the original one for + the versioned dataset.""" + versioned_arrow_dataset.save(dataset) + reloaded = versioned_arrow_dataset.load() + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_dataset_dict(self, versioned_arrow_dataset, dataset_dict): + """Test versioned save and reload with DatasetDict.""" + versioned_arrow_dataset.save(dataset_dict) + reloaded = versioned_arrow_dataset.load() + assert isinstance(reloaded, DatasetDict) + for split in dataset_dict: + assert reloaded[split].to_dict() == dataset_dict[split].to_dict() + + def test_save_and_load_iterable_dataset( + self, versioned_arrow_dataset, iterable_dataset + ): + """Test versioned save with IterableDataset.""" + versioned_arrow_dataset.save(iterable_dataset) + reloaded = versioned_arrow_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == { + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + } + + def test_save_and_load_iterable_dataset_dict( + self, versioned_arrow_dataset, iterable_dataset_dict + ): + """Test versioned save with IterableDatasetDict.""" + versioned_arrow_dataset.save(iterable_dataset_dict) + reloaded = versioned_arrow_dataset.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + + def test_no_versions(self, versioned_arrow_dataset): + """Check the error if no versions are available for load.""" + pattern = r"Did not find any versions for kedro_datasets.huggingface.arrow_dataset.ArrowDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_arrow_dataset.load() + + def test_exists(self, versioned_arrow_dataset, dataset): + """Test `exists` method invocation for versioned dataset.""" + assert not versioned_arrow_dataset.exists() + versioned_arrow_dataset.save(dataset) + assert versioned_arrow_dataset.exists() + + def test_prevent_overwrite(self, versioned_arrow_dataset, dataset): + """Check the error when attempting to override the dataset if the + corresponding version already exists.""" + versioned_arrow_dataset.save(dataset) + pattern = ( + r"Save path \'.+\' for kedro_datasets.huggingface.arrow_dataset.ArrowDataset\(.+\) must " + r"not exist if versioning is enabled\." + ) + with pytest.raises(DatasetError, match=pattern): + versioned_arrow_dataset.save(dataset) + + @pytest.mark.parametrize( + "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True + ) + @pytest.mark.parametrize( + "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True + ) + def test_save_version_warning( + self, versioned_arrow_dataset, load_version, save_version, dataset + ): + """Check the warning when saving to the path that differs from + the subsequent load path.""" + pattern = ( + f"Save version '{save_version}' did not match " + f"load version '{load_version}' for " + r"kedro_datasets.huggingface.arrow_dataset.ArrowDataset\(.+\)" + ) + with pytest.warns(UserWarning, match=pattern): + versioned_arrow_dataset.save(dataset) + + def test_http_filesystem_no_versioning(self): + pattern = "Versioning is not supported for HTTP protocols." + + with pytest.raises(DatasetError, match=pattern): + ArrowDataset( + path="https://example.com/hf_data", + version=Version(None, None), + ) + + def test_save_invalid_type_versioned(self, versioned_arrow_dataset): + """Check the error when saving an unsupported type through versioned dataset.""" + pattern = r"ArrowDataset only supports .datasets.Dataset., .datasets.DatasetDict., .datasets.IterableDataset., and .datasets.IterableDatasetDict. instances." 
+ with pytest.raises(DatasetError, match=pattern): + versioned_arrow_dataset.save("not a dataset") diff --git a/kedro-datasets/tests/huggingface/test_csv_dataset.py b/kedro-datasets/tests/huggingface/test_csv_dataset.py new file mode 100644 index 000000000..700b7307e --- /dev/null +++ b/kedro-datasets/tests/huggingface/test_csv_dataset.py @@ -0,0 +1,182 @@ +from pathlib import PurePosixPath + +import pytest +from datasets import Dataset, DatasetDict +from fsspec.implementations.http import HTTPFileSystem +from fsspec.implementations.local import LocalFileSystem +from gcsfs import GCSFileSystem +from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version +from s3fs.core import S3FileSystem + +from kedro_datasets.huggingface.csv_dataset import CSVDataset + + +@pytest.fixture +def path_csv(tmp_path): + return (tmp_path / "test.csv").as_posix() + + +@pytest.fixture +def path_csv_dir(tmp_path): + return (tmp_path / "test_csv_dd").as_posix() + + +@pytest.fixture +def csv_dataset(path_csv, save_args, load_args, fs_args): + return CSVDataset( + path=path_csv, + save_args=save_args, + load_args=load_args, + fs_args=fs_args, + ) + + +@pytest.fixture +def csv_dataset_dir(path_csv_dir): + return CSVDataset(path=path_csv_dir) + + +@pytest.fixture +def versioned_csv_dataset(path_csv, load_version, save_version): + return CSVDataset(path=path_csv, version=Version(load_version, save_version)) + + +class TestCSVDataset: + def test_save_and_load_dataset(self, csv_dataset, dataset): + """A single-file load returns a Dataset (auto-unwrapped).""" + csv_dataset.save(dataset) + reloaded = csv_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_dataset_with_split(self, path_csv, dataset): + """With split in load_args, the explicit split is respected.""" + ds = CSVDataset(path=path_csv, load_args={"split": "train"}) + ds.save(dataset) + reloaded = ds.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_dataset_dict(self, csv_dataset_dir, dataset_dict): + csv_dataset_dir.save(dataset_dict) + reloaded = csv_dataset_dir.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + for split in dataset_dict: + assert reloaded[split].to_dict() == dataset_dict[split].to_dict() + + def test_save_and_load_iterable_dataset(self, csv_dataset, iterable_dataset): + csv_dataset.save(iterable_dataset) + reloaded = csv_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == { + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + } + + def test_save_and_load_iterable_dataset_dict( + self, csv_dataset_dir, iterable_dataset_dict + ): + csv_dataset_dir.save(iterable_dataset_dict) + reloaded = csv_dataset_dir.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + assert reloaded["train"].to_dict() == { + "col1": [1, 2], + "col2": ["a", "b"], + } + + def test_exists(self, csv_dataset, dataset): + assert not csv_dataset.exists() + csv_dataset.save(dataset) + assert csv_dataset.exists() + + def test_load_missing_dataset(self, csv_dataset): + pattern = r"Failed while loading data from dataset kedro_datasets.huggingface.csv_dataset.CSVDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + csv_dataset.load() + + def test_save_invalid_type(self, csv_dataset): + pattern = r"CSVDataset only supports" + with pytest.raises(DatasetError, match=pattern): + 
csv_dataset.save({"not": "a dataset"}) + + @pytest.mark.parametrize( + "path,instance_type", + [ + ("s3://bucket/data.csv", S3FileSystem), + ("file:///tmp/data.csv", LocalFileSystem), + ("/tmp/data.csv", LocalFileSystem), + ("gcs://bucket/data.csv", GCSFileSystem), + ("https://example.com/data.csv", HTTPFileSystem), + ], + ) + def test_protocol_usage(self, path, instance_type): + ds = CSVDataset(path=path) + assert isinstance(ds._fs, instance_type) + resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] + assert str(ds._filepath) == resolved + assert isinstance(ds._filepath, PurePosixPath) + + def test_pathlike_path(self, tmp_path, dataset): + path = tmp_path / "test_pathlike.csv" + ds = CSVDataset(path=path) + ds.save(dataset) + reloaded = ds.load() + assert reloaded.to_dict() == dataset.to_dict() + + def test_catalog_release(self, mocker): + fs_mock = mocker.patch("fsspec.filesystem").return_value + path = "test.csv" + ds = CSVDataset(path=path) + ds.release() + fs_mock.invalidate_cache.assert_called_once_with(path) + + +class TestCSVDatasetVersioned: + def test_version_str_repr(self, load_version, save_version): + path = "test.csv" + ds = CSVDataset(path=path) + ds_versioned = CSVDataset( + path=path, version=Version(load_version, save_version) + ) + assert path in str(ds) + assert "version" not in str(ds) + + assert path in str(ds_versioned) + ver_str = f"version=Version(load={load_version}, save='{save_version}')" + assert ver_str in str(ds_versioned) + assert "CSVDataset" in str(ds_versioned) + + def test_save_and_load(self, versioned_csv_dataset, dataset): + versioned_csv_dataset.save(dataset) + reloaded = versioned_csv_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_no_versions(self, versioned_csv_dataset): + pattern = r"Did not find any versions for kedro_datasets.huggingface.csv_dataset.CSVDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.load() + + def test_exists(self, versioned_csv_dataset, dataset): + assert not versioned_csv_dataset.exists() + versioned_csv_dataset.save(dataset) + assert versioned_csv_dataset.exists() + + def test_prevent_overwrite(self, versioned_csv_dataset, dataset): + versioned_csv_dataset.save(dataset) + pattern = ( + r"Save path \'.+\' for kedro_datasets.huggingface.csv_dataset.CSVDataset\(.+\) must " + r"not exist if versioning is enabled\." + ) + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.save(dataset) + + def test_http_filesystem_no_versioning(self): + pattern = "Versioning is not supported for HTTP protocols." 
+ with pytest.raises(DatasetError, match=pattern): + CSVDataset( + path="https://example.com/data.csv", + version=Version(None, None), + ) diff --git a/kedro-datasets/tests/huggingface/test_hdf5_dataset.py b/kedro-datasets/tests/huggingface/test_hdf5_dataset.py new file mode 100644 index 000000000..a83e5367a --- /dev/null +++ b/kedro-datasets/tests/huggingface/test_hdf5_dataset.py @@ -0,0 +1,74 @@ +from pathlib import PurePosixPath + +import pytest +from fsspec.implementations.http import HTTPFileSystem +from fsspec.implementations.local import LocalFileSystem +from gcsfs import GCSFileSystem +from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version +from s3fs.core import S3FileSystem + +from kedro_datasets.huggingface.hdf5_dataset import HDF5Dataset + + +@pytest.fixture +def path_hdf5(tmp_path): + return (tmp_path / "test.h5").as_posix() + + +@pytest.fixture +def hdf5_dataset(path_hdf5): + return HDF5Dataset(path=path_hdf5) + + +class TestHDF5Dataset: + def test_save_dataset_raises(self, hdf5_dataset, dataset): + with pytest.raises( + DatasetError, match="Saving in hdf5 format is not supported" + ): + hdf5_dataset.save(dataset) + + def test_save_dataset_dict_raises(self, hdf5_dataset, dataset_dict): + with pytest.raises( + DatasetError, match="Saving in hdf5 format is not supported" + ): + hdf5_dataset.save(dataset_dict) + + def test_save_invalid_type(self, hdf5_dataset): + pattern = r"HDF5Dataset only supports" + with pytest.raises(DatasetError, match=pattern): + hdf5_dataset.save({"not": "a dataset"}) + + @pytest.mark.parametrize( + "path,instance_type", + [ + ("s3://bucket/data.h5", S3FileSystem), + ("file:///tmp/data.h5", LocalFileSystem), + ("/tmp/data.h5", LocalFileSystem), + ("gcs://bucket/data.h5", GCSFileSystem), + ("https://example.com/data.h5", HTTPFileSystem), + ], + ) + def test_protocol_usage(self, path, instance_type): + ds = HDF5Dataset(path=path) + assert isinstance(ds._fs, instance_type) + resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] + assert str(ds._filepath) == resolved + assert isinstance(ds._filepath, PurePosixPath) + + def test_catalog_release(self, mocker): + fs_mock = mocker.patch("fsspec.filesystem").return_value + path = "test.h5" + ds = HDF5Dataset(path=path) + ds.release() + fs_mock.invalidate_cache.assert_called_once_with(path) + + def test_exists_when_missing(self, hdf5_dataset): + assert not hdf5_dataset.exists() + + def test_http_filesystem_no_versioning(self): + pattern = "Versioning is not supported for HTTP protocols." 
+ with pytest.raises(DatasetError, match=pattern): + HDF5Dataset( + path="https://example.com/data.h5", + version=Version(None, None), + ) diff --git a/kedro-datasets/tests/huggingface/test_hugging_face_dataset.py b/kedro-datasets/tests/huggingface/test_hugging_face_dataset.py index 66babe0a9..909362ec2 100644 --- a/kedro-datasets/tests/huggingface/test_hugging_face_dataset.py +++ b/kedro-datasets/tests/huggingface/test_hugging_face_dataset.py @@ -1,16 +1,7 @@ -from pathlib import PurePosixPath - import pytest -from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem from huggingface_hub import HfApi -from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version -from s3fs.core import S3FileSystem from kedro_datasets.huggingface import HFDataset -from kedro_datasets.huggingface.hugging_face_dataset import LocalHFDataset @pytest.fixture @@ -40,332 +31,3 @@ def test_list_datasets(self, mocker): datasets = HFDataset.list_datasets() assert datasets == expected_datasets - - -@pytest.fixture -def path_local_hf(tmp_path): - return (tmp_path / "test_hf_dataset").as_posix() - - -@pytest.fixture -def local_hf_dataset(path_local_hf, save_args, load_args, fs_args): - return LocalHFDataset( - path=path_local_hf, - save_args=save_args, - load_args=load_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def versioned_local_hf_dataset(path_local_hf, load_version, save_version): - return LocalHFDataset( - path=path_local_hf, version=Version(load_version, save_version) - ) - - -@pytest.fixture -def dataset(): - return Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) - - -@pytest.fixture -def dataset_dict(): - return DatasetDict( - { - "train": Dataset.from_dict({"col1": [1, 2], "col2": ["a", "b"]}), - "test": Dataset.from_dict({"col1": [3], "col2": ["c"]}), - } - ) - - -@pytest.fixture -def iterable_dataset(): - return Dataset.from_dict( - {"col1": [1, 2, 3], "col2": ["a", "b", "c"]} - ).to_iterable_dataset() - - -@pytest.fixture -def iterable_dataset_dict(): - return IterableDatasetDict( - { - "train": Dataset.from_dict( - {"col1": [1, 2], "col2": ["a", "b"]} - ).to_iterable_dataset(), - "test": Dataset.from_dict( - {"col1": [3], "col2": ["c"]} - ).to_iterable_dataset(), - } - ) - - -@pytest.fixture -def parquet_local_hf_dataset(path_local_hf): - return LocalHFDataset(path=path_local_hf, file_format="parquet") - - -class TestLocalHFDataset: - def test_save_and_load_dataset(self, local_hf_dataset, dataset): - """Test saving and reloading a Dataset.""" - local_hf_dataset.save(dataset) - reloaded = local_hf_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() - - def test_save_and_load_dataset_dict(self, local_hf_dataset, dataset_dict): - """Test saving and reloading a DatasetDict.""" - local_hf_dataset.save(dataset_dict) - reloaded = local_hf_dataset.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - for split in dataset_dict: - assert reloaded[split].to_dict() == dataset_dict[split].to_dict() - - def test_exists(self, local_hf_dataset, dataset): - """Test `exists` method for both existing and nonexistent dataset.""" - assert not local_hf_dataset.exists() - local_hf_dataset.save(dataset) - assert local_hf_dataset.exists() - - def test_exists_dataset_dict(self, local_hf_dataset, dataset_dict): - """Test `exists` method 
for DatasetDict (checks dataset_dict.json marker).""" - assert not local_hf_dataset.exists() - local_hf_dataset.save(dataset_dict) - assert local_hf_dataset.exists() - - def test_save_and_load_iterable_dataset(self, local_hf_dataset, iterable_dataset): - """Test saving an IterableDataset materializes and round-trips.""" - local_hf_dataset.save(iterable_dataset) - reloaded = local_hf_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == { - "col1": [1, 2, 3], - "col2": ["a", "b", "c"], - } - - def test_save_and_load_iterable_dataset_dict( - self, local_hf_dataset, iterable_dataset_dict - ): - """Test saving an IterableDatasetDict materializes and round-trips.""" - local_hf_dataset.save(iterable_dataset_dict) - reloaded = local_hf_dataset.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - assert reloaded["train"].to_dict() == { - "col1": [1, 2], - "col2": ["a", "b"], - } - - def test_save_and_load_parquet(self, parquet_local_hf_dataset, dataset): - """Test saving and reloading a Dataset as parquet.""" - parquet_local_hf_dataset.save(dataset) - reloaded = parquet_local_hf_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() - - def test_save_and_load_parquet_dataset_dict( - self, parquet_local_hf_dataset, dataset_dict - ): - """Test saving and reloading a DatasetDict as parquet.""" - parquet_local_hf_dataset.save(dataset_dict) - reloaded = parquet_local_hf_dataset.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - for split in dataset_dict: - assert reloaded[split].to_dict() == dataset_dict[split].to_dict() - - def test_exists_parquet(self, parquet_local_hf_dataset, dataset): - """Test `exists` for parquet format.""" - assert not parquet_local_hf_dataset.exists() - parquet_local_hf_dataset.save(dataset) - assert parquet_local_hf_dataset.exists() - - def test_save_and_load_json_dataset_dict(self, tmp_path, dataset_dict): - """Test saving and reloading a DatasetDict as JSON.""" - path = (tmp_path / "test_json_dd").as_posix() - ds = LocalHFDataset(path=path, file_format="json") - ds.save(dataset_dict) - reloaded = ds.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - for split in dataset_dict: - assert reloaded[split].to_dict() == dataset_dict[split].to_dict() - - def test_invalid_file_format(self): - """Test that an unsupported file_format raises DatasetError.""" - pattern = r"Unsupported file_format" - with pytest.raises(DatasetError, match=pattern): - LocalHFDataset(path="test", file_format="xml") - - @pytest.mark.parametrize("save_args", [{"num_shards": 2}], indirect=True) - def test_save_extra_params(self, local_hf_dataset, save_args): - """Test overriding the default save arguments.""" - for key, value in save_args.items(): - assert local_hf_dataset._save_args[key] == value - - @pytest.mark.parametrize("load_args", [{"keep_in_memory": True}], indirect=True) - def test_load_extra_params(self, local_hf_dataset, load_args): - """Test overriding the default load arguments.""" - for key, value in load_args.items(): - assert local_hf_dataset._load_args[key] == value - - def test_load_missing_dataset(self, local_hf_dataset): - """Check the error when trying to load missing dataset.""" - pattern = r"Failed while loading data from dataset kedro_datasets.huggingface.hugging_face_dataset.LocalHFDataset\(.*\)" - with pytest.raises(DatasetError, match=pattern): - 
local_hf_dataset.load() - - def test_save_invalid_type(self, local_hf_dataset): - """Check the error when saving an unsupported type.""" - pattern = r"LocalHFDataset only supports .datasets.Dataset., .datasets.DatasetDict., .datasets.IterableDataset., and .datasets.IterableDatasetDict. instances." - with pytest.raises(DatasetError, match=pattern): - local_hf_dataset.save({"not": "a dataset"}) - - @pytest.mark.parametrize( - "path,instance_type", - [ - ("s3://bucket/hf_data", S3FileSystem), - ("file:///tmp/hf_data", LocalFileSystem), - ("/tmp/hf_data", LocalFileSystem), - ("gcs://bucket/hf_data", GCSFileSystem), - ("https://example.com/hf_data", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, path, instance_type): - dataset = LocalHFDataset(path=path) - assert isinstance(dataset._fs, instance_type) - - resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] - - assert str(dataset._filepath) == resolved - assert isinstance(dataset._filepath, PurePosixPath) - - def test_pathlike_path(self, tmp_path, dataset): - """Test that os.PathLike paths are supported.""" - path = tmp_path / "test_hf_pathlike" - ds = LocalHFDataset(path=path) - ds.save(dataset) - reloaded = ds.load() - assert reloaded.to_dict() == dataset.to_dict() - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - path = "test_hf" - dataset = LocalHFDataset(path=path) - dataset.release() - fs_mock.invalidate_cache.assert_called_once_with(path) - - -class TestLocalHFDatasetVersioned: - def test_version_str_repr(self, load_version, save_version): - """Test that version is in string representation of the class instance - when applicable.""" - path = "test_hf_dataset" - ds = LocalHFDataset(path=path) - ds_versioned = LocalHFDataset( - path=path, version=Version(load_version, save_version) - ) - assert path in str(ds) - assert "version" not in str(ds) - - assert path in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "LocalHFDataset" in str(ds_versioned) - assert "LocalHFDataset" in str(ds) - assert "protocol" in str(ds_versioned) - assert "protocol" in str(ds) - - def test_save_and_load(self, versioned_local_hf_dataset, dataset): - """Test that saved and reloaded data matches the original one for - the versioned dataset.""" - versioned_local_hf_dataset.save(dataset) - reloaded = versioned_local_hf_dataset.load() - assert reloaded.to_dict() == dataset.to_dict() - - def test_save_and_load_dataset_dict(self, versioned_local_hf_dataset, dataset_dict): - """Test versioned save and reload with DatasetDict.""" - versioned_local_hf_dataset.save(dataset_dict) - reloaded = versioned_local_hf_dataset.load() - assert isinstance(reloaded, DatasetDict) - for split in dataset_dict: - assert reloaded[split].to_dict() == dataset_dict[split].to_dict() - - def test_save_and_load_iterable_dataset( - self, versioned_local_hf_dataset, iterable_dataset - ): - """Test versioned save with IterableDataset.""" - versioned_local_hf_dataset.save(iterable_dataset) - reloaded = versioned_local_hf_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == { - "col1": [1, 2, 3], - "col2": ["a", "b", "c"], - } - - def test_save_and_load_iterable_dataset_dict( - self, versioned_local_hf_dataset, iterable_dataset_dict - ): - """Test versioned save with IterableDatasetDict.""" - versioned_local_hf_dataset.save(iterable_dataset_dict) - reloaded = versioned_local_hf_dataset.load() - assert 
isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - - def test_no_versions(self, versioned_local_hf_dataset): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for kedro_datasets.huggingface.hugging_face_dataset.LocalHFDataset\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_local_hf_dataset.load() - - def test_exists(self, versioned_local_hf_dataset, dataset): - """Test `exists` method invocation for versioned dataset.""" - assert not versioned_local_hf_dataset.exists() - versioned_local_hf_dataset.save(dataset) - assert versioned_local_hf_dataset.exists() - - def test_prevent_overwrite(self, versioned_local_hf_dataset, dataset): - """Check the error when attempting to override the dataset if the - corresponding version already exists.""" - versioned_local_hf_dataset.save(dataset) - pattern = ( - r"Save path \'.+\' for kedro_datasets.huggingface.hugging_face_dataset.LocalHFDataset\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_local_hf_dataset.save(dataset) - - @pytest.mark.parametrize( - "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True - ) - @pytest.mark.parametrize( - "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True - ) - def test_save_version_warning( - self, versioned_local_hf_dataset, load_version, save_version, dataset - ): - """Check the warning when saving to the path that differs from - the subsequent load path.""" - pattern = ( - f"Save version '{save_version}' did not match " - f"load version '{load_version}' for " - r"kedro_datasets.huggingface.hugging_face_dataset.LocalHFDataset\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - versioned_local_hf_dataset.save(dataset) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - - with pytest.raises(DatasetError, match=pattern): - LocalHFDataset( - path="https://example.com/hf_data", - version=Version(None, None), - ) - - def test_save_invalid_type_versioned(self, versioned_local_hf_dataset): - """Check the error when saving an unsupported type through versioned dataset.""" - pattern = r"LocalHFDataset only supports .datasets.Dataset., .datasets.DatasetDict., .datasets.IterableDataset., and .datasets.IterableDatasetDict. instances." 
- with pytest.raises(DatasetError, match=pattern): - versioned_local_hf_dataset.save("not a dataset") diff --git a/kedro-datasets/tests/huggingface/test_json_dataset.py b/kedro-datasets/tests/huggingface/test_json_dataset.py new file mode 100644 index 000000000..c44f295c7 --- /dev/null +++ b/kedro-datasets/tests/huggingface/test_json_dataset.py @@ -0,0 +1,182 @@ +from pathlib import PurePosixPath + +import pytest +from datasets import Dataset, DatasetDict +from fsspec.implementations.http import HTTPFileSystem +from fsspec.implementations.local import LocalFileSystem +from gcsfs import GCSFileSystem +from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version +from s3fs.core import S3FileSystem + +from kedro_datasets.huggingface.json_dataset import JSONDataset + + +@pytest.fixture +def path_json(tmp_path): + return (tmp_path / "test.json").as_posix() + + +@pytest.fixture +def path_json_dir(tmp_path): + return (tmp_path / "test_json_dd").as_posix() + + +@pytest.fixture +def json_dataset(path_json, save_args, load_args, fs_args): + return JSONDataset( + path=path_json, + save_args=save_args, + load_args=load_args, + fs_args=fs_args, + ) + + +@pytest.fixture +def json_dataset_dir(path_json_dir): + return JSONDataset(path=path_json_dir) + + +@pytest.fixture +def versioned_json_dataset(path_json, load_version, save_version): + return JSONDataset(path=path_json, version=Version(load_version, save_version)) + + +class TestJSONDataset: + def test_save_and_load_dataset(self, json_dataset, dataset): + """A single-file load returns a Dataset (auto-unwrapped).""" + json_dataset.save(dataset) + reloaded = json_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_dataset_with_split(self, path_json, dataset): + """With split in load_args, the explicit split is respected.""" + ds = JSONDataset(path=path_json, load_args={"split": "train"}) + ds.save(dataset) + reloaded = ds.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_dataset_dict(self, json_dataset_dir, dataset_dict): + json_dataset_dir.save(dataset_dict) + reloaded = json_dataset_dir.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + for split in dataset_dict: + assert reloaded[split].to_dict() == dataset_dict[split].to_dict() + + def test_save_and_load_iterable_dataset(self, json_dataset, iterable_dataset): + json_dataset.save(iterable_dataset) + reloaded = json_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == { + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + } + + def test_save_and_load_iterable_dataset_dict( + self, json_dataset_dir, iterable_dataset_dict + ): + json_dataset_dir.save(iterable_dataset_dict) + reloaded = json_dataset_dir.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + assert reloaded["train"].to_dict() == { + "col1": [1, 2], + "col2": ["a", "b"], + } + + def test_exists(self, json_dataset, dataset): + assert not json_dataset.exists() + json_dataset.save(dataset) + assert json_dataset.exists() + + def test_load_missing_dataset(self, json_dataset): + pattern = r"Failed while loading data from dataset kedro_datasets.huggingface.json_dataset.JSONDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + json_dataset.load() + + def test_save_invalid_type(self, json_dataset): + pattern = r"JSONDataset only supports" + with 
pytest.raises(DatasetError, match=pattern): + json_dataset.save({"not": "a dataset"}) + + @pytest.mark.parametrize( + "path,instance_type", + [ + ("s3://bucket/data.json", S3FileSystem), + ("file:///tmp/data.json", LocalFileSystem), + ("/tmp/data.json", LocalFileSystem), + ("gcs://bucket/data.json", GCSFileSystem), + ("https://example.com/data.json", HTTPFileSystem), + ], + ) + def test_protocol_usage(self, path, instance_type): + ds = JSONDataset(path=path) + assert isinstance(ds._fs, instance_type) + resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] + assert str(ds._filepath) == resolved + assert isinstance(ds._filepath, PurePosixPath) + + def test_pathlike_path(self, tmp_path, dataset): + path = tmp_path / "test_pathlike.json" + ds = JSONDataset(path=path) + ds.save(dataset) + reloaded = ds.load() + assert reloaded.to_dict() == dataset.to_dict() + + def test_catalog_release(self, mocker): + fs_mock = mocker.patch("fsspec.filesystem").return_value + path = "test.json" + ds = JSONDataset(path=path) + ds.release() + fs_mock.invalidate_cache.assert_called_once_with(path) + + +class TestJSONDatasetVersioned: + def test_version_str_repr(self, load_version, save_version): + path = "test.json" + ds = JSONDataset(path=path) + ds_versioned = JSONDataset( + path=path, version=Version(load_version, save_version) + ) + assert path in str(ds) + assert "version" not in str(ds) + + assert path in str(ds_versioned) + ver_str = f"version=Version(load={load_version}, save='{save_version}')" + assert ver_str in str(ds_versioned) + assert "JSONDataset" in str(ds_versioned) + + def test_save_and_load(self, versioned_json_dataset, dataset): + versioned_json_dataset.save(dataset) + reloaded = versioned_json_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_no_versions(self, versioned_json_dataset): + pattern = r"Did not find any versions for kedro_datasets.huggingface.json_dataset.JSONDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_json_dataset.load() + + def test_exists(self, versioned_json_dataset, dataset): + assert not versioned_json_dataset.exists() + versioned_json_dataset.save(dataset) + assert versioned_json_dataset.exists() + + def test_prevent_overwrite(self, versioned_json_dataset, dataset): + versioned_json_dataset.save(dataset) + pattern = ( + r"Save path \'.+\' for kedro_datasets.huggingface.json_dataset.JSONDataset\(.+\) must " + r"not exist if versioning is enabled\." + ) + with pytest.raises(DatasetError, match=pattern): + versioned_json_dataset.save(dataset) + + def test_http_filesystem_no_versioning(self): + pattern = "Versioning is not supported for HTTP protocols." 
+ with pytest.raises(DatasetError, match=pattern): + JSONDataset( + path="https://example.com/data.json", + version=Version(None, None), + ) diff --git a/kedro-datasets/tests/huggingface/test_lance_dataset.py b/kedro-datasets/tests/huggingface/test_lance_dataset.py new file mode 100644 index 000000000..677a6fbe3 --- /dev/null +++ b/kedro-datasets/tests/huggingface/test_lance_dataset.py @@ -0,0 +1,74 @@ +from pathlib import PurePosixPath + +import pytest +from fsspec.implementations.http import HTTPFileSystem +from fsspec.implementations.local import LocalFileSystem +from gcsfs import GCSFileSystem +from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version +from s3fs.core import S3FileSystem + +from kedro_datasets.huggingface.lance_dataset import LanceDataset + + +@pytest.fixture +def path_lance(tmp_path): + return (tmp_path / "test.lance").as_posix() + + +@pytest.fixture +def lance_dataset(path_lance): + return LanceDataset(path=path_lance) + + +class TestLanceDataset: + def test_save_dataset_raises(self, lance_dataset, dataset): + with pytest.raises( + DatasetError, match="Saving in lance format is not supported" + ): + lance_dataset.save(dataset) + + def test_save_dataset_dict_raises(self, lance_dataset, dataset_dict): + with pytest.raises( + DatasetError, match="Saving in lance format is not supported" + ): + lance_dataset.save(dataset_dict) + + def test_save_invalid_type(self, lance_dataset): + pattern = r"LanceDataset only supports" + with pytest.raises(DatasetError, match=pattern): + lance_dataset.save({"not": "a dataset"}) + + @pytest.mark.parametrize( + "path,instance_type", + [ + ("s3://bucket/data.lance", S3FileSystem), + ("file:///tmp/data.lance", LocalFileSystem), + ("/tmp/data.lance", LocalFileSystem), + ("gcs://bucket/data.lance", GCSFileSystem), + ("https://example.com/data.lance", HTTPFileSystem), + ], + ) + def test_protocol_usage(self, path, instance_type): + ds = LanceDataset(path=path) + assert isinstance(ds._fs, instance_type) + resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] + assert str(ds._filepath) == resolved + assert isinstance(ds._filepath, PurePosixPath) + + def test_catalog_release(self, mocker): + fs_mock = mocker.patch("fsspec.filesystem").return_value + path = "test.lance" + ds = LanceDataset(path=path) + ds.release() + fs_mock.invalidate_cache.assert_called_once_with(path) + + def test_exists_when_missing(self, lance_dataset): + assert not lance_dataset.exists() + + def test_http_filesystem_no_versioning(self): + pattern = "Versioning is not supported for HTTP protocols." 
+ with pytest.raises(DatasetError, match=pattern): + LanceDataset( + path="https://example.com/data.lance", + version=Version(None, None), + ) diff --git a/kedro-datasets/tests/huggingface/test_parquet_dataset.py b/kedro-datasets/tests/huggingface/test_parquet_dataset.py new file mode 100644 index 000000000..e867abf74 --- /dev/null +++ b/kedro-datasets/tests/huggingface/test_parquet_dataset.py @@ -0,0 +1,184 @@ +from pathlib import PurePosixPath + +import pytest +from datasets import Dataset, DatasetDict +from fsspec.implementations.http import HTTPFileSystem +from fsspec.implementations.local import LocalFileSystem +from gcsfs import GCSFileSystem +from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version +from s3fs.core import S3FileSystem + +from kedro_datasets.huggingface.parquet_dataset import ParquetDataset + + +@pytest.fixture +def path_parquet(tmp_path): + return (tmp_path / "test.parquet").as_posix() + + +@pytest.fixture +def path_parquet_dir(tmp_path): + return (tmp_path / "test_parquet_dd").as_posix() + + +@pytest.fixture +def parquet_dataset(path_parquet, save_args, load_args, fs_args): + return ParquetDataset( + path=path_parquet, + save_args=save_args, + load_args=load_args, + fs_args=fs_args, + ) + + +@pytest.fixture +def parquet_dataset_dir(path_parquet_dir): + return ParquetDataset(path=path_parquet_dir) + + +@pytest.fixture +def versioned_parquet_dataset(path_parquet, load_version, save_version): + return ParquetDataset( + path=path_parquet, version=Version(load_version, save_version) + ) + + +class TestParquetDataset: + def test_save_and_load_dataset(self, parquet_dataset, dataset): + """A single-file load returns a Dataset (auto-unwrapped).""" + parquet_dataset.save(dataset) + reloaded = parquet_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_dataset_with_split(self, path_parquet, dataset): + """With split in load_args, the explicit split is respected.""" + ds = ParquetDataset(path=path_parquet, load_args={"split": "train"}) + ds.save(dataset) + reloaded = ds.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_dataset_dict(self, parquet_dataset_dir, dataset_dict): + parquet_dataset_dir.save(dataset_dict) + reloaded = parquet_dataset_dir.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + for split in dataset_dict: + assert reloaded[split].to_dict() == dataset_dict[split].to_dict() + + def test_save_and_load_iterable_dataset(self, parquet_dataset, iterable_dataset): + parquet_dataset.save(iterable_dataset) + reloaded = parquet_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == { + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + } + + def test_save_and_load_iterable_dataset_dict( + self, parquet_dataset_dir, iterable_dataset_dict + ): + parquet_dataset_dir.save(iterable_dataset_dict) + reloaded = parquet_dataset_dir.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + assert reloaded["train"].to_dict() == { + "col1": [1, 2], + "col2": ["a", "b"], + } + + def test_exists(self, parquet_dataset, dataset): + assert not parquet_dataset.exists() + parquet_dataset.save(dataset) + assert parquet_dataset.exists() + + def test_load_missing_dataset(self, parquet_dataset): + pattern = r"Failed while loading data from dataset kedro_datasets.huggingface.parquet_dataset.ParquetDataset\(.*\)" + 
with pytest.raises(DatasetError, match=pattern):
+            parquet_dataset.load()
+
+    def test_save_invalid_type(self, parquet_dataset):
+        pattern = r"ParquetDataset only supports"
+        with pytest.raises(DatasetError, match=pattern):
+            parquet_dataset.save({"not": "a dataset"})
+
+    @pytest.mark.parametrize(
+        "path,instance_type",
+        [
+            ("s3://bucket/data.parquet", S3FileSystem),
+            ("file:///tmp/data.parquet", LocalFileSystem),
+            ("/tmp/data.parquet", LocalFileSystem),
+            ("gcs://bucket/data.parquet", GCSFileSystem),
+            ("https://example.com/data.parquet", HTTPFileSystem),
+        ],
+    )
+    def test_protocol_usage(self, path, instance_type):
+        ds = ParquetDataset(path=path)
+        assert isinstance(ds._fs, instance_type)
+        resolved = path.split(PROTOCOL_DELIMITER, 1)[-1]
+        assert str(ds._filepath) == resolved
+        assert isinstance(ds._filepath, PurePosixPath)
+
+    def test_pathlike_path(self, tmp_path, dataset):
+        path = tmp_path / "test_pathlike.parquet"
+        ds = ParquetDataset(path=path)
+        ds.save(dataset)
+        reloaded = ds.load()
+        assert reloaded.to_dict() == dataset.to_dict()
+
+    def test_catalog_release(self, mocker):
+        fs_mock = mocker.patch("fsspec.filesystem").return_value
+        path = "test.parquet"
+        ds = ParquetDataset(path=path)
+        ds.release()
+        fs_mock.invalidate_cache.assert_called_once_with(path)
+
+
+class TestParquetDatasetVersioned:
+    def test_version_str_repr(self, load_version, save_version):
+        path = "test.parquet"
+        ds = ParquetDataset(path=path)
+        ds_versioned = ParquetDataset(
+            path=path, version=Version(load_version, save_version)
+        )
+        assert path in str(ds)
+        assert "version" not in str(ds)
+
+        assert path in str(ds_versioned)
+        ver_str = f"version=Version(load={load_version}, save='{save_version}')"
+        assert ver_str in str(ds_versioned)
+        assert "ParquetDataset" in str(ds_versioned)
+
+    def test_save_and_load(self, versioned_parquet_dataset, dataset):
+        versioned_parquet_dataset.save(dataset)
+        reloaded = versioned_parquet_dataset.load()
+        assert isinstance(reloaded, Dataset)
+        assert reloaded.to_dict() == dataset.to_dict()
+
+    def test_no_versions(self, versioned_parquet_dataset):
+        pattern = r"Did not find any versions for kedro_datasets.huggingface.parquet_dataset.ParquetDataset\(.+\)"
+        with pytest.raises(DatasetError, match=pattern):
+            versioned_parquet_dataset.load()
+
+    def test_exists(self, versioned_parquet_dataset, dataset):
+        assert not versioned_parquet_dataset.exists()
+        versioned_parquet_dataset.save(dataset)
+        assert versioned_parquet_dataset.exists()
+
+    def test_prevent_overwrite(self, versioned_parquet_dataset, dataset):
+        versioned_parquet_dataset.save(dataset)
+        pattern = (
+            r"Save path \'.+\' for kedro_datasets.huggingface.parquet_dataset.ParquetDataset\(.+\) must "
+            r"not exist if versioning is enabled\."
+        )
+        with pytest.raises(DatasetError, match=pattern):
+            versioned_parquet_dataset.save(dataset)
+
+    def test_http_filesystem_no_versioning(self):
+        pattern = "Versioning is not supported for HTTP protocols."
+        with pytest.raises(DatasetError, match=pattern):
+            ParquetDataset(
+                path="https://example.com/data.parquet",
+                version=Version(None, None),
+            )

From aee993dd5042af5a0a2ce04dc2358e11b89deb1f Mon Sep 17 00:00:00 2001
From: iwhalen
Date: Sat, 11 Apr 2026 11:37:45 -0500
Subject: [PATCH 05/21] Update docs.
Signed-off-by: iwhalen --- kedro-datasets/RELEASE.md | 2 +- .../api/kedro_datasets/huggingface.ArrowDataset.md | 8 ++++++++ .../api/kedro_datasets/huggingface.CSVDataset.md | 8 ++++++++ .../api/kedro_datasets/huggingface.HDF5Dataset.md | 8 ++++++++ .../api/kedro_datasets/huggingface.JSONDataset.md | 8 ++++++++ .../api/kedro_datasets/huggingface.LanceDataset.md | 8 ++++++++ .../kedro_datasets/huggingface.LocalHFDataset.md | 9 --------- .../kedro_datasets/huggingface.ParquetDataset.md | 8 ++++++++ kedro-datasets/mkdocs.yml | 14 ++++++++++++-- 9 files changed, 61 insertions(+), 12 deletions(-) create mode 100644 kedro-datasets/docs/api/kedro_datasets/huggingface.ArrowDataset.md create mode 100644 kedro-datasets/docs/api/kedro_datasets/huggingface.CSVDataset.md create mode 100644 kedro-datasets/docs/api/kedro_datasets/huggingface.HDF5Dataset.md create mode 100644 kedro-datasets/docs/api/kedro_datasets/huggingface.JSONDataset.md create mode 100644 kedro-datasets/docs/api/kedro_datasets/huggingface.LanceDataset.md delete mode 100644 kedro-datasets/docs/api/kedro_datasets/huggingface.LocalHFDataset.md create mode 100644 kedro-datasets/docs/api/kedro_datasets/huggingface.ParquetDataset.md diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index 54e3d75f7..547b68925 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -44,7 +44,7 @@ Many thanks to the following Kedroids for contributing PRs to this release: ## Major features and improvements -- Replaced `huggingface.LocalHFDataset` with per-format classes: `ArrowDataset`, `ParquetDataset`, `JSONDataset`, `CSVDataset`, `LanceDataset`, `HDF5Dataset`. +- Add Hugging Face datasets: `ArrowDataset`, `ParquetDataset`, `JSONDataset`, `CSVDataset`, `LanceDataset`, `HDF5Dataset`. ## Bug fixes and other changes ## Community contributions diff --git a/kedro-datasets/docs/api/kedro_datasets/huggingface.ArrowDataset.md b/kedro-datasets/docs/api/kedro_datasets/huggingface.ArrowDataset.md new file mode 100644 index 000000000..e06088a90 --- /dev/null +++ b/kedro-datasets/docs/api/kedro_datasets/huggingface.ArrowDataset.md @@ -0,0 +1,8 @@ +# ArrowDataset + +`ArrowDataset` loads and saves Hugging Face datasets in Arrow format using the `datasets` library. + +::: kedro_datasets.huggingface.ArrowDataset + options: + members: true + show_source: true diff --git a/kedro-datasets/docs/api/kedro_datasets/huggingface.CSVDataset.md b/kedro-datasets/docs/api/kedro_datasets/huggingface.CSVDataset.md new file mode 100644 index 000000000..e8b7d3fe0 --- /dev/null +++ b/kedro-datasets/docs/api/kedro_datasets/huggingface.CSVDataset.md @@ -0,0 +1,8 @@ +# CSVDataset + +`CSVDataset` loads and saves Hugging Face datasets in CSV format using the `datasets` library. + +::: kedro_datasets.huggingface.CSVDataset + options: + members: true + show_source: true diff --git a/kedro-datasets/docs/api/kedro_datasets/huggingface.HDF5Dataset.md b/kedro-datasets/docs/api/kedro_datasets/huggingface.HDF5Dataset.md new file mode 100644 index 000000000..bb68c9aec --- /dev/null +++ b/kedro-datasets/docs/api/kedro_datasets/huggingface.HDF5Dataset.md @@ -0,0 +1,8 @@ +# HDF5Dataset + +`HDF5Dataset` loads and saves Hugging Face datasets in HDF5 format using the `datasets` library. 
+ +::: kedro_datasets.huggingface.HDF5Dataset + options: + members: true + show_source: true diff --git a/kedro-datasets/docs/api/kedro_datasets/huggingface.JSONDataset.md b/kedro-datasets/docs/api/kedro_datasets/huggingface.JSONDataset.md new file mode 100644 index 000000000..7ec307f07 --- /dev/null +++ b/kedro-datasets/docs/api/kedro_datasets/huggingface.JSONDataset.md @@ -0,0 +1,8 @@ +# JSONDataset + +`JSONDataset` loads and saves Hugging Face datasets in JSON format using the `datasets` library. + +::: kedro_datasets.huggingface.JSONDataset + options: + members: true + show_source: true diff --git a/kedro-datasets/docs/api/kedro_datasets/huggingface.LanceDataset.md b/kedro-datasets/docs/api/kedro_datasets/huggingface.LanceDataset.md new file mode 100644 index 000000000..3d695281b --- /dev/null +++ b/kedro-datasets/docs/api/kedro_datasets/huggingface.LanceDataset.md @@ -0,0 +1,8 @@ +# LanceDataset + +`LanceDataset` loads and saves Hugging Face datasets in Lance format using the `datasets` library. + +::: kedro_datasets.huggingface.LanceDataset + options: + members: true + show_source: true diff --git a/kedro-datasets/docs/api/kedro_datasets/huggingface.LocalHFDataset.md b/kedro-datasets/docs/api/kedro_datasets/huggingface.LocalHFDataset.md deleted file mode 100644 index 8e466a5cb..000000000 --- a/kedro-datasets/docs/api/kedro_datasets/huggingface.LocalHFDataset.md +++ /dev/null @@ -1,9 +0,0 @@ -# LocalHFDataset - -`LocalHFDataset` saves and loads Hugging Face datasets from a filesystem - using the `datasets` library. - -::: kedro_datasets.huggingface.LocalHFDataset - options: - members: true - show_source: true diff --git a/kedro-datasets/docs/api/kedro_datasets/huggingface.ParquetDataset.md b/kedro-datasets/docs/api/kedro_datasets/huggingface.ParquetDataset.md new file mode 100644 index 000000000..b45cbc6c1 --- /dev/null +++ b/kedro-datasets/docs/api/kedro_datasets/huggingface.ParquetDataset.md @@ -0,0 +1,8 @@ +# ParquetDataset + +`ParquetDataset` loads and saves Hugging Face datasets in Parquet format using the `datasets` library. 
+ +::: kedro_datasets.huggingface.ParquetDataset + options: + members: true + show_source: true diff --git a/kedro-datasets/mkdocs.yml b/kedro-datasets/mkdocs.yml index 61a8fc3d0..f820fc607 100644 --- a/kedro-datasets/mkdocs.yml +++ b/kedro-datasets/mkdocs.yml @@ -140,9 +140,14 @@ plugins: Machine Learning and AI: - api/kedro_datasets/tensorflow.TensorFlowModelDataset.md: TensorFlow model storage + - api/kedro_datasets/huggingface.ArrowDataset.md: Hugging Face local datasets in Arrow format + - api/kedro_datasets/huggingface.CSVDataset.md: Hugging Face local datasets in CSV format + - api/kedro_datasets/huggingface.HDF5Dataset.md: Hugging Face local datasets in HDF5 format - api/kedro_datasets/huggingface.HFDataset.md: Hugging Face remote datasets integration - api/kedro_datasets/huggingface.HFTransformerPipelineDataset.md: Hugging Face transformer pipelines - - api/kedro_datasets/huggingface.LocalHFDataset.md: Hugging Face local datasets integration + - api/kedro_datasets/huggingface.JSONDataset.md: Hugging Face local datasets in JSON format + - api/kedro_datasets/huggingface.LanceDataset.md: Hugging Face local datasets in Lance format + - api/kedro_datasets/huggingface.ParquetDataset.md: Hugging Face local datasets in Parquet format Visualization and Plotting: - api/kedro_datasets/matplotlib.MatplotlibDataset.md: Matplotlib figure storage @@ -265,9 +270,14 @@ nav: - Holoviews: - holoviews.HoloviewsWriter: api/kedro_datasets/holoviews.HoloviewsWriter.md - Huggingface: + - huggingface.ArrowDataset: api/kedro_datasets/huggingface.ArrowDataset.md + - huggingface.CSVDataset: api/kedro_datasets/huggingface.CSVDataset.md + - huggingface.HDF5Dataset: api/kedro_datasets/huggingface.HDF5Dataset.md - huggingface.HFDataset: api/kedro_datasets/huggingface.HFDataset.md - huggingface.HFTransformerPipelineDataset: api/kedro_datasets/huggingface.HFTransformerPipelineDataset.md - - huggingface.LocalHFDataset: api/kedro_datasets/huggingface.LocalHFDataset.md + - huggingface.JSONDataset: api/kedro_datasets/huggingface.JSONDataset.md + - huggingface.LanceDataset: api/kedro_datasets/huggingface.LanceDataset.md + - huggingface.ParquetDataset: api/kedro_datasets/huggingface.ParquetDataset.md - Ibis: - ibis.FileDataset: api/kedro_datasets/ibis.FileDataset.md - ibis.TableDataset: api/kedro_datasets/ibis.TableDataset.md From 168a7476a3435c2c220008293554d3cc474bcc61 Mon Sep 17 00:00:00 2001 From: iwhalen Date: Tue, 21 Apr 2026 19:33:33 -0500 Subject: [PATCH 06/21] Remove HDF5 and Lance hf datasets. 
Signed-off-by: iwhalen --- kedro-datasets/RELEASE.md | 2 +- .../kedro_datasets/huggingface.HDF5Dataset.md | 8 -- .../huggingface.LanceDataset.md | 8 -- .../kedro_datasets/huggingface/__init__.py | 12 --- .../huggingface/hdf5_dataset.py | 39 ---------- .../huggingface/lance_dataset.py | 39 ---------- kedro-datasets/mkdocs.yml | 4 - .../jsonschema/kedro-catalog-1.0.0.json | 6 +- .../tests/huggingface/test_hdf5_dataset.py | 74 ------------------- .../tests/huggingface/test_lance_dataset.py | 74 ------------------- 10 files changed, 2 insertions(+), 264 deletions(-) delete mode 100644 kedro-datasets/docs/api/kedro_datasets/huggingface.HDF5Dataset.md delete mode 100644 kedro-datasets/docs/api/kedro_datasets/huggingface.LanceDataset.md delete mode 100644 kedro-datasets/kedro_datasets/huggingface/hdf5_dataset.py delete mode 100644 kedro-datasets/kedro_datasets/huggingface/lance_dataset.py delete mode 100644 kedro-datasets/tests/huggingface/test_hdf5_dataset.py delete mode 100644 kedro-datasets/tests/huggingface/test_lance_dataset.py diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index 547b68925..378a402f3 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -44,7 +44,7 @@ Many thanks to the following Kedroids for contributing PRs to this release: ## Major features and improvements -- Add Hugging Face datasets: `ArrowDataset`, `ParquetDataset`, `JSONDataset`, `CSVDataset`, `LanceDataset`, `HDF5Dataset`. +- Add Hugging Face datasets: `ArrowDataset`, `ParquetDataset`, `JSONDataset`, `CSVDataset`. ## Bug fixes and other changes ## Community contributions diff --git a/kedro-datasets/docs/api/kedro_datasets/huggingface.HDF5Dataset.md b/kedro-datasets/docs/api/kedro_datasets/huggingface.HDF5Dataset.md deleted file mode 100644 index bb68c9aec..000000000 --- a/kedro-datasets/docs/api/kedro_datasets/huggingface.HDF5Dataset.md +++ /dev/null @@ -1,8 +0,0 @@ -# HDF5Dataset - -`HDF5Dataset` loads and saves Hugging Face datasets in HDF5 format using the `datasets` library. - -::: kedro_datasets.huggingface.HDF5Dataset - options: - members: true - show_source: true diff --git a/kedro-datasets/docs/api/kedro_datasets/huggingface.LanceDataset.md b/kedro-datasets/docs/api/kedro_datasets/huggingface.LanceDataset.md deleted file mode 100644 index 3d695281b..000000000 --- a/kedro-datasets/docs/api/kedro_datasets/huggingface.LanceDataset.md +++ /dev/null @@ -1,8 +0,0 @@ -# LanceDataset - -`LanceDataset` loads and saves Hugging Face datasets in Lance format using the `datasets` library. 
- -::: kedro_datasets.huggingface.LanceDataset - options: - members: true - show_source: true diff --git a/kedro-datasets/kedro_datasets/huggingface/__init__.py b/kedro-datasets/kedro_datasets/huggingface/__init__.py index f2327b485..bb0b94b44 100644 --- a/kedro-datasets/kedro_datasets/huggingface/__init__.py +++ b/kedro-datasets/kedro_datasets/huggingface/__init__.py @@ -31,16 +31,6 @@ except (ImportError, RuntimeError): CSVDataset: Any -try: - from .lance_dataset import LanceDataset -except (ImportError, RuntimeError): - LanceDataset: Any - -try: - from .hdf5_dataset import HDF5Dataset -except (ImportError, RuntimeError): - HDF5Dataset: Any - try: from .transformer_pipeline_dataset import HFTransformerPipelineDataset except (ImportError, RuntimeError): @@ -56,8 +46,6 @@ "parquet_dataset": ["ParquetDataset"], "json_dataset": ["JSONDataset"], "csv_dataset": ["CSVDataset"], - "lance_dataset": ["LanceDataset"], - "hdf5_dataset": ["HDF5Dataset"], "transformer_pipeline_dataset": ["HFTransformerPipelineDataset"], }, ) diff --git a/kedro-datasets/kedro_datasets/huggingface/hdf5_dataset.py b/kedro-datasets/kedro_datasets/huggingface/hdf5_dataset.py deleted file mode 100644 index fb5b3efad..000000000 --- a/kedro-datasets/kedro_datasets/huggingface/hdf5_dataset.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -from typing import ClassVar - -from datasets import Dataset, DatasetDict -from kedro.io.core import DatasetError - -from ._base import FilesystemDataset - - -class HDF5Dataset(FilesystemDataset): - """``HDF5Dataset`` loads Hugging Face ``Dataset`` and - ``DatasetDict`` objects from - `HDF5 `_ files. - - Saving is **not** supported because the ``datasets`` library does - not provide a ``to_hdf5`` method. - - Examples: - Using the - [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): - - ```yaml - reviews: - type: huggingface.HDF5Dataset - path: data/01_raw/reviews.h5 - ``` - """ - - BUILDER: ClassVar[str] = "hdf5" - EXTENSION: ClassVar[str] = ".h5" - - def _save_dataset(self, data: Dataset, save_path: str) -> None: - msg = "Saving in hdf5 format is not supported by the Hugging Face datasets library." - raise DatasetError(msg) - - def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None: - msg = "Saving in hdf5 format is not supported by the Hugging Face datasets library." - raise DatasetError(msg) diff --git a/kedro-datasets/kedro_datasets/huggingface/lance_dataset.py b/kedro-datasets/kedro_datasets/huggingface/lance_dataset.py deleted file mode 100644 index c8cb6c139..000000000 --- a/kedro-datasets/kedro_datasets/huggingface/lance_dataset.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -from typing import ClassVar - -from datasets import Dataset, DatasetDict -from kedro.io.core import DatasetError - -from ._base import FilesystemDataset - - -class LanceDataset(FilesystemDataset): - """``LanceDataset`` loads Hugging Face ``Dataset`` and - ``DatasetDict`` objects from `Lance `_ - files. - - Saving is **not** supported because the ``datasets`` library does - not provide a ``to_lance`` method. 
- - Examples: - Using the - [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): - - ```yaml - reviews: - type: huggingface.LanceDataset - path: data/01_raw/reviews.lance - ``` - """ - - BUILDER: ClassVar[str] = "lance" - EXTENSION: ClassVar[str] = ".lance" - - def _save_dataset(self, data: Dataset, save_path: str) -> None: - msg = "Saving in lance format is not supported by the Hugging Face datasets library." - raise DatasetError(msg) - - def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None: - msg = "Saving in lance format is not supported by the Hugging Face datasets library." - raise DatasetError(msg) diff --git a/kedro-datasets/mkdocs.yml b/kedro-datasets/mkdocs.yml index f820fc607..7dc066dbd 100644 --- a/kedro-datasets/mkdocs.yml +++ b/kedro-datasets/mkdocs.yml @@ -142,11 +142,9 @@ plugins: - api/kedro_datasets/tensorflow.TensorFlowModelDataset.md: TensorFlow model storage - api/kedro_datasets/huggingface.ArrowDataset.md: Hugging Face local datasets in Arrow format - api/kedro_datasets/huggingface.CSVDataset.md: Hugging Face local datasets in CSV format - - api/kedro_datasets/huggingface.HDF5Dataset.md: Hugging Face local datasets in HDF5 format - api/kedro_datasets/huggingface.HFDataset.md: Hugging Face remote datasets integration - api/kedro_datasets/huggingface.HFTransformerPipelineDataset.md: Hugging Face transformer pipelines - api/kedro_datasets/huggingface.JSONDataset.md: Hugging Face local datasets in JSON format - - api/kedro_datasets/huggingface.LanceDataset.md: Hugging Face local datasets in Lance format - api/kedro_datasets/huggingface.ParquetDataset.md: Hugging Face local datasets in Parquet format Visualization and Plotting: @@ -272,11 +270,9 @@ nav: - Huggingface: - huggingface.ArrowDataset: api/kedro_datasets/huggingface.ArrowDataset.md - huggingface.CSVDataset: api/kedro_datasets/huggingface.CSVDataset.md - - huggingface.HDF5Dataset: api/kedro_datasets/huggingface.HDF5Dataset.md - huggingface.HFDataset: api/kedro_datasets/huggingface.HFDataset.md - huggingface.HFTransformerPipelineDataset: api/kedro_datasets/huggingface.HFTransformerPipelineDataset.md - huggingface.JSONDataset: api/kedro_datasets/huggingface.JSONDataset.md - - huggingface.LanceDataset: api/kedro_datasets/huggingface.LanceDataset.md - huggingface.ParquetDataset: api/kedro_datasets/huggingface.ParquetDataset.md - Ibis: - ibis.FileDataset: api/kedro_datasets/ibis.FileDataset.md diff --git a/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json b/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json index 519692e5c..bee9ca925 100644 --- a/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json +++ b/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json @@ -19,11 +19,9 @@ "holoviews.HoloviewsWriter", "huggingface.ArrowDataset", "huggingface.CSVDataset", - "huggingface.HDF5Dataset", "huggingface.HFDataset", "huggingface.HFTransformerPipelineDataset", "huggingface.JSONDataset", - "huggingface.LanceDataset", "huggingface.ParquetDataset", "ibis.FileDataset", "ibis.TableDataset", @@ -465,9 +463,7 @@ "huggingface.ArrowDataset", "huggingface.ParquetDataset", "huggingface.JSONDataset", - "huggingface.CSVDataset", - "huggingface.LanceDataset", - "huggingface.HDF5Dataset" + "huggingface.CSVDataset" ] } } diff --git a/kedro-datasets/tests/huggingface/test_hdf5_dataset.py b/kedro-datasets/tests/huggingface/test_hdf5_dataset.py deleted file mode 100644 index a83e5367a..000000000 --- a/kedro-datasets/tests/huggingface/test_hdf5_dataset.py +++ /dev/null 
@@ -1,74 +0,0 @@ -from pathlib import PurePosixPath - -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version -from s3fs.core import S3FileSystem - -from kedro_datasets.huggingface.hdf5_dataset import HDF5Dataset - - -@pytest.fixture -def path_hdf5(tmp_path): - return (tmp_path / "test.h5").as_posix() - - -@pytest.fixture -def hdf5_dataset(path_hdf5): - return HDF5Dataset(path=path_hdf5) - - -class TestHDF5Dataset: - def test_save_dataset_raises(self, hdf5_dataset, dataset): - with pytest.raises( - DatasetError, match="Saving in hdf5 format is not supported" - ): - hdf5_dataset.save(dataset) - - def test_save_dataset_dict_raises(self, hdf5_dataset, dataset_dict): - with pytest.raises( - DatasetError, match="Saving in hdf5 format is not supported" - ): - hdf5_dataset.save(dataset_dict) - - def test_save_invalid_type(self, hdf5_dataset): - pattern = r"HDF5Dataset only supports" - with pytest.raises(DatasetError, match=pattern): - hdf5_dataset.save({"not": "a dataset"}) - - @pytest.mark.parametrize( - "path,instance_type", - [ - ("s3://bucket/data.h5", S3FileSystem), - ("file:///tmp/data.h5", LocalFileSystem), - ("/tmp/data.h5", LocalFileSystem), - ("gcs://bucket/data.h5", GCSFileSystem), - ("https://example.com/data.h5", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, path, instance_type): - ds = HDF5Dataset(path=path) - assert isinstance(ds._fs, instance_type) - resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(ds._filepath) == resolved - assert isinstance(ds._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - path = "test.h5" - ds = HDF5Dataset(path=path) - ds.release() - fs_mock.invalidate_cache.assert_called_once_with(path) - - def test_exists_when_missing(self, hdf5_dataset): - assert not hdf5_dataset.exists() - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." 
- with pytest.raises(DatasetError, match=pattern): - HDF5Dataset( - path="https://example.com/data.h5", - version=Version(None, None), - ) diff --git a/kedro-datasets/tests/huggingface/test_lance_dataset.py b/kedro-datasets/tests/huggingface/test_lance_dataset.py deleted file mode 100644 index 677a6fbe3..000000000 --- a/kedro-datasets/tests/huggingface/test_lance_dataset.py +++ /dev/null @@ -1,74 +0,0 @@ -from pathlib import PurePosixPath - -import pytest -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version -from s3fs.core import S3FileSystem - -from kedro_datasets.huggingface.lance_dataset import LanceDataset - - -@pytest.fixture -def path_lance(tmp_path): - return (tmp_path / "test.lance").as_posix() - - -@pytest.fixture -def lance_dataset(path_lance): - return LanceDataset(path=path_lance) - - -class TestLanceDataset: - def test_save_dataset_raises(self, lance_dataset, dataset): - with pytest.raises( - DatasetError, match="Saving in lance format is not supported" - ): - lance_dataset.save(dataset) - - def test_save_dataset_dict_raises(self, lance_dataset, dataset_dict): - with pytest.raises( - DatasetError, match="Saving in lance format is not supported" - ): - lance_dataset.save(dataset_dict) - - def test_save_invalid_type(self, lance_dataset): - pattern = r"LanceDataset only supports" - with pytest.raises(DatasetError, match=pattern): - lance_dataset.save({"not": "a dataset"}) - - @pytest.mark.parametrize( - "path,instance_type", - [ - ("s3://bucket/data.lance", S3FileSystem), - ("file:///tmp/data.lance", LocalFileSystem), - ("/tmp/data.lance", LocalFileSystem), - ("gcs://bucket/data.lance", GCSFileSystem), - ("https://example.com/data.lance", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, path, instance_type): - ds = LanceDataset(path=path) - assert isinstance(ds._fs, instance_type) - resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(ds._filepath) == resolved - assert isinstance(ds._filepath, PurePosixPath) - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - path = "test.lance" - ds = LanceDataset(path=path) - ds.release() - fs_mock.invalidate_cache.assert_called_once_with(path) - - def test_exists_when_missing(self, lance_dataset): - assert not lance_dataset.exists() - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DatasetError, match=pattern): - LanceDataset( - path="https://example.com/data.lance", - version=Version(None, None), - ) From de65093517b49f2a7b57abba3e7102081e398aa3 Mon Sep 17 00:00:00 2001 From: iwhalen Date: Tue, 21 Apr 2026 20:33:01 -0500 Subject: [PATCH 07/21] Format tests. 
Signed-off-by: iwhalen --- .../tests/huggingface/test_csv_dataset.py | 182 --------------- .../huggingface/test_filesystem_datasets.py | 212 ++++++++++++++++++ .../tests/huggingface/test_json_dataset.py | 182 --------------- .../tests/huggingface/test_parquet_dataset.py | 184 --------------- 4 files changed, 212 insertions(+), 548 deletions(-) delete mode 100644 kedro-datasets/tests/huggingface/test_csv_dataset.py create mode 100644 kedro-datasets/tests/huggingface/test_filesystem_datasets.py delete mode 100644 kedro-datasets/tests/huggingface/test_json_dataset.py delete mode 100644 kedro-datasets/tests/huggingface/test_parquet_dataset.py diff --git a/kedro-datasets/tests/huggingface/test_csv_dataset.py b/kedro-datasets/tests/huggingface/test_csv_dataset.py deleted file mode 100644 index 700b7307e..000000000 --- a/kedro-datasets/tests/huggingface/test_csv_dataset.py +++ /dev/null @@ -1,182 +0,0 @@ -from pathlib import PurePosixPath - -import pytest -from datasets import Dataset, DatasetDict -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version -from s3fs.core import S3FileSystem - -from kedro_datasets.huggingface.csv_dataset import CSVDataset - - -@pytest.fixture -def path_csv(tmp_path): - return (tmp_path / "test.csv").as_posix() - - -@pytest.fixture -def path_csv_dir(tmp_path): - return (tmp_path / "test_csv_dd").as_posix() - - -@pytest.fixture -def csv_dataset(path_csv, save_args, load_args, fs_args): - return CSVDataset( - path=path_csv, - save_args=save_args, - load_args=load_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def csv_dataset_dir(path_csv_dir): - return CSVDataset(path=path_csv_dir) - - -@pytest.fixture -def versioned_csv_dataset(path_csv, load_version, save_version): - return CSVDataset(path=path_csv, version=Version(load_version, save_version)) - - -class TestCSVDataset: - def test_save_and_load_dataset(self, csv_dataset, dataset): - """A single-file load returns a Dataset (auto-unwrapped).""" - csv_dataset.save(dataset) - reloaded = csv_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() - - def test_save_and_load_dataset_with_split(self, path_csv, dataset): - """With split in load_args, the explicit split is respected.""" - ds = CSVDataset(path=path_csv, load_args={"split": "train"}) - ds.save(dataset) - reloaded = ds.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() - - def test_save_and_load_dataset_dict(self, csv_dataset_dir, dataset_dict): - csv_dataset_dir.save(dataset_dict) - reloaded = csv_dataset_dir.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - for split in dataset_dict: - assert reloaded[split].to_dict() == dataset_dict[split].to_dict() - - def test_save_and_load_iterable_dataset(self, csv_dataset, iterable_dataset): - csv_dataset.save(iterable_dataset) - reloaded = csv_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == { - "col1": [1, 2, 3], - "col2": ["a", "b", "c"], - } - - def test_save_and_load_iterable_dataset_dict( - self, csv_dataset_dir, iterable_dataset_dict - ): - csv_dataset_dir.save(iterable_dataset_dict) - reloaded = csv_dataset_dir.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - assert reloaded["train"].to_dict() == { - "col1": [1, 2], - 
"col2": ["a", "b"], - } - - def test_exists(self, csv_dataset, dataset): - assert not csv_dataset.exists() - csv_dataset.save(dataset) - assert csv_dataset.exists() - - def test_load_missing_dataset(self, csv_dataset): - pattern = r"Failed while loading data from dataset kedro_datasets.huggingface.csv_dataset.CSVDataset\(.*\)" - with pytest.raises(DatasetError, match=pattern): - csv_dataset.load() - - def test_save_invalid_type(self, csv_dataset): - pattern = r"CSVDataset only supports" - with pytest.raises(DatasetError, match=pattern): - csv_dataset.save({"not": "a dataset"}) - - @pytest.mark.parametrize( - "path,instance_type", - [ - ("s3://bucket/data.csv", S3FileSystem), - ("file:///tmp/data.csv", LocalFileSystem), - ("/tmp/data.csv", LocalFileSystem), - ("gcs://bucket/data.csv", GCSFileSystem), - ("https://example.com/data.csv", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, path, instance_type): - ds = CSVDataset(path=path) - assert isinstance(ds._fs, instance_type) - resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(ds._filepath) == resolved - assert isinstance(ds._filepath, PurePosixPath) - - def test_pathlike_path(self, tmp_path, dataset): - path = tmp_path / "test_pathlike.csv" - ds = CSVDataset(path=path) - ds.save(dataset) - reloaded = ds.load() - assert reloaded.to_dict() == dataset.to_dict() - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - path = "test.csv" - ds = CSVDataset(path=path) - ds.release() - fs_mock.invalidate_cache.assert_called_once_with(path) - - -class TestCSVDatasetVersioned: - def test_version_str_repr(self, load_version, save_version): - path = "test.csv" - ds = CSVDataset(path=path) - ds_versioned = CSVDataset( - path=path, version=Version(load_version, save_version) - ) - assert path in str(ds) - assert "version" not in str(ds) - - assert path in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "CSVDataset" in str(ds_versioned) - - def test_save_and_load(self, versioned_csv_dataset, dataset): - versioned_csv_dataset.save(dataset) - reloaded = versioned_csv_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() - - def test_no_versions(self, versioned_csv_dataset): - pattern = r"Did not find any versions for kedro_datasets.huggingface.csv_dataset.CSVDataset\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_csv_dataset.load() - - def test_exists(self, versioned_csv_dataset, dataset): - assert not versioned_csv_dataset.exists() - versioned_csv_dataset.save(dataset) - assert versioned_csv_dataset.exists() - - def test_prevent_overwrite(self, versioned_csv_dataset, dataset): - versioned_csv_dataset.save(dataset) - pattern = ( - r"Save path \'.+\' for kedro_datasets.huggingface.csv_dataset.CSVDataset\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_csv_dataset.save(dataset) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." 
- with pytest.raises(DatasetError, match=pattern): - CSVDataset( - path="https://example.com/data.csv", - version=Version(None, None), - ) diff --git a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py new file mode 100644 index 000000000..159feec06 --- /dev/null +++ b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py @@ -0,0 +1,212 @@ +import re +from pathlib import PurePosixPath + +import pytest +from datasets import Dataset, DatasetDict +from fsspec.implementations.http import HTTPFileSystem +from fsspec.implementations.local import LocalFileSystem +from gcsfs import GCSFileSystem +from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version +from s3fs.core import S3FileSystem + +from kedro_datasets.huggingface.csv_dataset import CSVDataset +from kedro_datasets.huggingface.json_dataset import JSONDataset +from kedro_datasets.huggingface.parquet_dataset import ParquetDataset + +FORMATS = [ + pytest.param((CSVDataset, ".csv"), id="csv"), + pytest.param((JSONDataset, ".json"), id="json"), + pytest.param((ParquetDataset, ".parquet"), id="parquet"), +] + +PROTOCOLS = [ + ("s3://bucket/data", S3FileSystem), + ("file:///tmp/data", LocalFileSystem), + ("/tmp/data", LocalFileSystem), + ("gcs://bucket/data", GCSFileSystem), + ("https://example.com/data", HTTPFileSystem), +] + + +def _qualname(cls) -> str: + return re.escape(f"{cls.__module__}.{cls.__name__}") + + +@pytest.fixture(params=FORMATS) +def fmt(request): + return request.param + + +@pytest.fixture +def dataset_cls(fmt): + return fmt[0] + + +@pytest.fixture +def extension(fmt): + return fmt[1] + + +@pytest.fixture +def path_file(tmp_path, extension): + return (tmp_path / f"test{extension}").as_posix() + + +@pytest.fixture +def path_dir(tmp_path): + return (tmp_path / "test_dd").as_posix() + + +@pytest.fixture +def fs_dataset(dataset_cls, path_file, save_args, load_args, fs_args): + return dataset_cls( + path=path_file, + save_args=save_args, + load_args=load_args, + fs_args=fs_args, + ) + + +@pytest.fixture +def fs_dataset_dir(dataset_cls, path_dir): + return dataset_cls(path=path_dir) + + +@pytest.fixture +def versioned_fs_dataset(dataset_cls, path_file, load_version, save_version): + return dataset_cls(path=path_file, version=Version(load_version, save_version)) + + +class TestFilesystemDataset: + def test_save_and_load_dataset(self, fs_dataset, dataset): + """A single-file load returns a Dataset (auto-unwrapped).""" + fs_dataset.save(dataset) + reloaded = fs_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_dataset_with_split(self, dataset_cls, path_file, dataset): + """With split in load_args, the explicit split is respected.""" + ds = dataset_cls(path=path_file, load_args={"split": "train"}) + ds.save(dataset) + reloaded = ds.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_save_and_load_dataset_dict(self, fs_dataset_dir, dataset_dict): + fs_dataset_dir.save(dataset_dict) + reloaded = fs_dataset_dir.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + for split in dataset_dict: + assert reloaded[split].to_dict() == dataset_dict[split].to_dict() + + def test_save_and_load_iterable_dataset(self, fs_dataset, iterable_dataset): + fs_dataset.save(iterable_dataset) + reloaded = fs_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() 
== { + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + } + + def test_save_and_load_iterable_dataset_dict( + self, fs_dataset_dir, iterable_dataset_dict + ): + fs_dataset_dir.save(iterable_dataset_dict) + reloaded = fs_dataset_dir.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == {"train", "test"} + assert reloaded["train"].to_dict() == { + "col1": [1, 2], + "col2": ["a", "b"], + } + + def test_exists(self, fs_dataset, dataset): + assert not fs_dataset.exists() + fs_dataset.save(dataset) + assert fs_dataset.exists() + + def test_load_missing_dataset(self, fs_dataset, dataset_cls): + pattern = ( + rf"Failed while loading data from dataset {_qualname(dataset_cls)}\(.*\)" + ) + with pytest.raises(DatasetError, match=pattern): + fs_dataset.load() + + def test_save_invalid_type(self, fs_dataset, dataset_cls): + pattern = rf"{dataset_cls.__name__} only supports" + with pytest.raises(DatasetError, match=pattern): + fs_dataset.save({"not": "a dataset"}) + + @pytest.mark.parametrize("base_path,instance_type", PROTOCOLS) + def test_protocol_usage(self, dataset_cls, extension, base_path, instance_type): + path = f"{base_path}{extension}" + ds = dataset_cls(path=path) + assert isinstance(ds._fs, instance_type) + resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] + assert str(ds._filepath) == resolved + assert isinstance(ds._filepath, PurePosixPath) + + def test_pathlike_path(self, dataset_cls, tmp_path, extension, dataset): + path = tmp_path / f"test_pathlike{extension}" + ds = dataset_cls(path=path) + ds.save(dataset) + reloaded = ds.load() + assert reloaded.to_dict() == dataset.to_dict() + + def test_catalog_release(self, dataset_cls, extension, mocker): + fs_mock = mocker.patch("fsspec.filesystem").return_value + path = f"test{extension}" + ds = dataset_cls(path=path) + ds.release() + fs_mock.invalidate_cache.assert_called_once_with(path) + + +class TestFilesystemDatasetVersioned: + def test_version_str_repr(self, dataset_cls, extension, load_version, save_version): + path = f"test{extension}" + ds = dataset_cls(path=path) + ds_versioned = dataset_cls( + path=path, version=Version(load_version, save_version) + ) + assert path in str(ds) + assert "version" not in str(ds) + + assert path in str(ds_versioned) + ver_str = f"version=Version(load={load_version}, save='{save_version}')" + assert ver_str in str(ds_versioned) + assert dataset_cls.__name__ in str(ds_versioned) + + def test_save_and_load(self, versioned_fs_dataset, dataset): + versioned_fs_dataset.save(dataset) + reloaded = versioned_fs_dataset.load() + assert isinstance(reloaded, Dataset) + assert reloaded.to_dict() == dataset.to_dict() + + def test_no_versions(self, versioned_fs_dataset, dataset_cls): + pattern = rf"Did not find any versions for {_qualname(dataset_cls)}\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_fs_dataset.load() + + def test_exists(self, versioned_fs_dataset, dataset): + assert not versioned_fs_dataset.exists() + versioned_fs_dataset.save(dataset) + assert versioned_fs_dataset.exists() + + def test_prevent_overwrite(self, versioned_fs_dataset, dataset_cls, dataset): + versioned_fs_dataset.save(dataset) + pattern = ( + rf"Save path \'.+\' for {_qualname(dataset_cls)}\(.+\) must " + r"not exist if versioning is enabled\." + ) + with pytest.raises(DatasetError, match=pattern): + versioned_fs_dataset.save(dataset) + + def test_http_filesystem_no_versioning(self, dataset_cls, extension): + pattern = "Versioning is not supported for HTTP protocols." 
+ with pytest.raises(DatasetError, match=pattern): + dataset_cls( + path=f"https://example.com/data{extension}", + version=Version(None, None), + ) diff --git a/kedro-datasets/tests/huggingface/test_json_dataset.py b/kedro-datasets/tests/huggingface/test_json_dataset.py deleted file mode 100644 index c44f295c7..000000000 --- a/kedro-datasets/tests/huggingface/test_json_dataset.py +++ /dev/null @@ -1,182 +0,0 @@ -from pathlib import PurePosixPath - -import pytest -from datasets import Dataset, DatasetDict -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version -from s3fs.core import S3FileSystem - -from kedro_datasets.huggingface.json_dataset import JSONDataset - - -@pytest.fixture -def path_json(tmp_path): - return (tmp_path / "test.json").as_posix() - - -@pytest.fixture -def path_json_dir(tmp_path): - return (tmp_path / "test_json_dd").as_posix() - - -@pytest.fixture -def json_dataset(path_json, save_args, load_args, fs_args): - return JSONDataset( - path=path_json, - save_args=save_args, - load_args=load_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def json_dataset_dir(path_json_dir): - return JSONDataset(path=path_json_dir) - - -@pytest.fixture -def versioned_json_dataset(path_json, load_version, save_version): - return JSONDataset(path=path_json, version=Version(load_version, save_version)) - - -class TestJSONDataset: - def test_save_and_load_dataset(self, json_dataset, dataset): - """A single-file load returns a Dataset (auto-unwrapped).""" - json_dataset.save(dataset) - reloaded = json_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() - - def test_save_and_load_dataset_with_split(self, path_json, dataset): - """With split in load_args, the explicit split is respected.""" - ds = JSONDataset(path=path_json, load_args={"split": "train"}) - ds.save(dataset) - reloaded = ds.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() - - def test_save_and_load_dataset_dict(self, json_dataset_dir, dataset_dict): - json_dataset_dir.save(dataset_dict) - reloaded = json_dataset_dir.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - for split in dataset_dict: - assert reloaded[split].to_dict() == dataset_dict[split].to_dict() - - def test_save_and_load_iterable_dataset(self, json_dataset, iterable_dataset): - json_dataset.save(iterable_dataset) - reloaded = json_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == { - "col1": [1, 2, 3], - "col2": ["a", "b", "c"], - } - - def test_save_and_load_iterable_dataset_dict( - self, json_dataset_dir, iterable_dataset_dict - ): - json_dataset_dir.save(iterable_dataset_dict) - reloaded = json_dataset_dir.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - assert reloaded["train"].to_dict() == { - "col1": [1, 2], - "col2": ["a", "b"], - } - - def test_exists(self, json_dataset, dataset): - assert not json_dataset.exists() - json_dataset.save(dataset) - assert json_dataset.exists() - - def test_load_missing_dataset(self, json_dataset): - pattern = r"Failed while loading data from dataset kedro_datasets.huggingface.json_dataset.JSONDataset\(.*\)" - with pytest.raises(DatasetError, match=pattern): - json_dataset.load() - - def test_save_invalid_type(self, json_dataset): - 
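The consolidated module above leans on fixtures (`dataset`, `dataset_dict`, `iterable_dataset`, `iterable_dataset_dict`, `save_args`, `load_args`, `fs_args`, `load_version`, `save_version`) that are not defined in this file, presumably living in a shared `tests/huggingface/conftest.py` that this series does not show. A minimal sketch of a few of those fixtures, inferred from the assertions in the tests and labelled as assumptions:

```python
# Hypothetical conftest.py fixtures; the real definitions are not part of this hunk.
import pytest
from datasets import Dataset, DatasetDict


@pytest.fixture
def dataset():
    # Values match the round-trip assertions in the tests above.
    return Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})


@pytest.fixture
def dataset_dict():
    return DatasetDict(
        {
            "train": Dataset.from_dict({"col1": [1, 2], "col2": ["a", "b"]}),
            "test": Dataset.from_dict({"col1": [3], "col2": ["c"]}),
        }
    )


@pytest.fixture
def iterable_dataset(dataset):
    # Streaming view over the in-memory dataset above.
    return dataset.to_iterable_dataset()


@pytest.fixture
def save_args(request):
    # Plain use returns None; `indirect=True` parametrization injects a dict.
    return getattr(request, "param", None)
```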
pattern = r"JSONDataset only supports" - with pytest.raises(DatasetError, match=pattern): - json_dataset.save({"not": "a dataset"}) - - @pytest.mark.parametrize( - "path,instance_type", - [ - ("s3://bucket/data.json", S3FileSystem), - ("file:///tmp/data.json", LocalFileSystem), - ("/tmp/data.json", LocalFileSystem), - ("gcs://bucket/data.json", GCSFileSystem), - ("https://example.com/data.json", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, path, instance_type): - ds = JSONDataset(path=path) - assert isinstance(ds._fs, instance_type) - resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(ds._filepath) == resolved - assert isinstance(ds._filepath, PurePosixPath) - - def test_pathlike_path(self, tmp_path, dataset): - path = tmp_path / "test_pathlike.json" - ds = JSONDataset(path=path) - ds.save(dataset) - reloaded = ds.load() - assert reloaded.to_dict() == dataset.to_dict() - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - path = "test.json" - ds = JSONDataset(path=path) - ds.release() - fs_mock.invalidate_cache.assert_called_once_with(path) - - -class TestJSONDatasetVersioned: - def test_version_str_repr(self, load_version, save_version): - path = "test.json" - ds = JSONDataset(path=path) - ds_versioned = JSONDataset( - path=path, version=Version(load_version, save_version) - ) - assert path in str(ds) - assert "version" not in str(ds) - - assert path in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "JSONDataset" in str(ds_versioned) - - def test_save_and_load(self, versioned_json_dataset, dataset): - versioned_json_dataset.save(dataset) - reloaded = versioned_json_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() - - def test_no_versions(self, versioned_json_dataset): - pattern = r"Did not find any versions for kedro_datasets.huggingface.json_dataset.JSONDataset\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_json_dataset.load() - - def test_exists(self, versioned_json_dataset, dataset): - assert not versioned_json_dataset.exists() - versioned_json_dataset.save(dataset) - assert versioned_json_dataset.exists() - - def test_prevent_overwrite(self, versioned_json_dataset, dataset): - versioned_json_dataset.save(dataset) - pattern = ( - r"Save path \'.+\' for kedro_datasets.huggingface.json_dataset.JSONDataset\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_json_dataset.save(dataset) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." 
- with pytest.raises(DatasetError, match=pattern): - JSONDataset( - path="https://example.com/data.json", - version=Version(None, None), - ) diff --git a/kedro-datasets/tests/huggingface/test_parquet_dataset.py b/kedro-datasets/tests/huggingface/test_parquet_dataset.py deleted file mode 100644 index e867abf74..000000000 --- a/kedro-datasets/tests/huggingface/test_parquet_dataset.py +++ /dev/null @@ -1,184 +0,0 @@ -from pathlib import PurePosixPath - -import pytest -from datasets import Dataset, DatasetDict -from fsspec.implementations.http import HTTPFileSystem -from fsspec.implementations.local import LocalFileSystem -from gcsfs import GCSFileSystem -from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version -from s3fs.core import S3FileSystem - -from kedro_datasets.huggingface.parquet_dataset import ParquetDataset - - -@pytest.fixture -def path_parquet(tmp_path): - return (tmp_path / "test.parquet").as_posix() - - -@pytest.fixture -def path_parquet_dir(tmp_path): - return (tmp_path / "test_parquet_dd").as_posix() - - -@pytest.fixture -def parquet_dataset(path_parquet, save_args, load_args, fs_args): - return ParquetDataset( - path=path_parquet, - save_args=save_args, - load_args=load_args, - fs_args=fs_args, - ) - - -@pytest.fixture -def parquet_dataset_dir(path_parquet_dir): - return ParquetDataset(path=path_parquet_dir) - - -@pytest.fixture -def versioned_parquet_dataset(path_parquet, load_version, save_version): - return ParquetDataset( - path=path_parquet, version=Version(load_version, save_version) - ) - - -class TestParquetDataset: - def test_save_and_load_dataset(self, parquet_dataset, dataset): - """A single-file load returns a Dataset (auto-unwrapped).""" - parquet_dataset.save(dataset) - reloaded = parquet_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() - - def test_save_and_load_dataset_with_split(self, path_parquet, dataset): - """With split in load_args, the explicit split is respected.""" - ds = ParquetDataset(path=path_parquet, load_args={"split": "train"}) - ds.save(dataset) - reloaded = ds.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() - - def test_save_and_load_dataset_dict(self, parquet_dataset_dir, dataset_dict): - parquet_dataset_dir.save(dataset_dict) - reloaded = parquet_dataset_dir.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - for split in dataset_dict: - assert reloaded[split].to_dict() == dataset_dict[split].to_dict() - - def test_save_and_load_iterable_dataset(self, parquet_dataset, iterable_dataset): - parquet_dataset.save(iterable_dataset) - reloaded = parquet_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == { - "col1": [1, 2, 3], - "col2": ["a", "b", "c"], - } - - def test_save_and_load_iterable_dataset_dict( - self, parquet_dataset_dir, iterable_dataset_dict - ): - parquet_dataset_dir.save(iterable_dataset_dict) - reloaded = parquet_dataset_dir.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - assert reloaded["train"].to_dict() == { - "col1": [1, 2], - "col2": ["a", "b"], - } - - def test_exists(self, parquet_dataset, dataset): - assert not parquet_dataset.exists() - parquet_dataset.save(dataset) - assert parquet_dataset.exists() - - def test_load_missing_dataset(self, parquet_dataset): - pattern = r"Failed while loading data from dataset kedro_datasets.huggingface.parquet_dataset.ParquetDataset\(.*\)" - 
with pytest.raises(DatasetError, match=pattern): - parquet_dataset.load() - - def test_save_invalid_type(self, parquet_dataset): - pattern = r"ParquetDataset only supports" - with pytest.raises(DatasetError, match=pattern): - parquet_dataset.save({"not": "a dataset"}) - - @pytest.mark.parametrize( - "path,instance_type", - [ - ("s3://bucket/data.parquet", S3FileSystem), - ("file:///tmp/data.parquet", LocalFileSystem), - ("/tmp/data.parquet", LocalFileSystem), - ("gcs://bucket/data.parquet", GCSFileSystem), - ("https://example.com/data.parquet", HTTPFileSystem), - ], - ) - def test_protocol_usage(self, path, instance_type): - ds = ParquetDataset(path=path) - assert isinstance(ds._fs, instance_type) - resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(ds._filepath) == resolved - assert isinstance(ds._filepath, PurePosixPath) - - def test_pathlike_path(self, tmp_path, dataset): - path = tmp_path / "test_pathlike.parquet" - ds = ParquetDataset(path=path) - ds.save(dataset) - reloaded = ds.load() - assert reloaded.to_dict() == dataset.to_dict() - - def test_catalog_release(self, mocker): - fs_mock = mocker.patch("fsspec.filesystem").return_value - path = "test.parquet" - ds = ParquetDataset(path=path) - ds.release() - fs_mock.invalidate_cache.assert_called_once_with(path) - - -class TestParquetDatasetVersioned: - def test_version_str_repr(self, load_version, save_version): - path = "test.parquet" - ds = ParquetDataset(path=path) - ds_versioned = ParquetDataset( - path=path, version=Version(load_version, save_version) - ) - assert path in str(ds) - assert "version" not in str(ds) - - assert path in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" - assert ver_str in str(ds_versioned) - assert "ParquetDataset" in str(ds_versioned) - - def test_save_and_load(self, versioned_parquet_dataset, dataset): - versioned_parquet_dataset.save(dataset) - reloaded = versioned_parquet_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() - - def test_no_versions(self, versioned_parquet_dataset): - pattern = r"Did not find any versions for kedro_datasets.huggingface.parquet_dataset.ParquetDataset\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_parquet_dataset.load() - - def test_exists(self, versioned_parquet_dataset, dataset): - assert not versioned_parquet_dataset.exists() - versioned_parquet_dataset.save(dataset) - assert versioned_parquet_dataset.exists() - - def test_prevent_overwrite(self, versioned_parquet_dataset, dataset): - versioned_parquet_dataset.save(dataset) - pattern = ( - r"Save path \'.+\' for kedro_datasets.huggingface.parquet_dataset.ParquetDataset\(.+\) must " - r"not exist if versioning is enabled\." - ) - with pytest.raises(DatasetError, match=pattern): - versioned_parquet_dataset.save(dataset) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DatasetError, match=pattern): - ParquetDataset( - path="https://example.com/data.parquet", - version=Version(None, None), - ) From 011c9c52328228fc5b455b5cc08deb45a5610bb3 Mon Sep 17 00:00:00 2001 From: iwhalen Date: Tue, 21 Apr 2026 21:17:39 -0500 Subject: [PATCH 08/21] Add non-functioning FilesystemDataset changes. 
Signed-off-by: iwhalen --- .../kedro_datasets/huggingface/_base.py | 51 +++++------------- .../tests/huggingface/test_arrow_dataset.py | 54 ++++++++----------- .../huggingface/test_filesystem_datasets.py | 19 ++----- 3 files changed, 40 insertions(+), 84 deletions(-) diff --git a/kedro-datasets/kedro_datasets/huggingface/_base.py b/kedro-datasets/kedro_datasets/huggingface/_base.py index 59877084b..104ec9fc0 100644 --- a/kedro-datasets/kedro_datasets/huggingface/_base.py +++ b/kedro-datasets/kedro_datasets/huggingface/_base.py @@ -96,24 +96,22 @@ def load(self) -> DatasetLike: return self._load_dataset(load_path) def save(self, data: DatasetLike) -> None: - if not isinstance( - data, - Dataset | DatasetDict | IterableDataset | IterableDatasetDict, - ): + if isinstance(data, IterableDataset | IterableDatasetDict): + msg = ( + f"{type(self).__name__} got iterable dataset. " + "Before saving an iterable dataset " + "you must materialize it into a `Dataset` or `DatasetDict`." + ) + raise RuntimeError(msg) + + if not isinstance(data, Dataset | DatasetDict): msg = ( f"{type(self).__name__} only supports `datasets.Dataset`, " "`datasets.DatasetDict`, " - "`datasets.IterableDataset`, and " - "`datasets.IterableDatasetDict` instances. " f"Got {type(data)}" ) raise DatasetError(msg) - if isinstance(data, IterableDatasetDict): - data = DatasetDict({k: Dataset.from_list(list(v)) for k, v in data.items()}) - elif isinstance(data, IterableDataset): - data = Dataset.from_list(list(data)) - save_path = get_filepath_str(self._get_save_path(), self._protocol) if isinstance(data, DatasetDict): @@ -124,38 +122,13 @@ def save(self, data: DatasetLike) -> None: self._invalidate_cache() def _load_dataset(self, load_path: str) -> DatasetLike: - if self._fs.isdir(load_path): - ext = self.EXTENSION - data_files = { - PurePosixPath(p).stem: p for p in self._fs.glob(f"{load_path}/*{ext}") - } - # Note: nosec is fine here since we're always loading from a filesystem. - # Bandit throws an exception because it wants a revision number, - # which is only relevatn when loading from the Hub. - return load_dataset( - self.BUILDER, data_files=data_files, **self._load_args - ) # nosec - - result = load_dataset( # nosec + return load_dataset( # nosec self.BUILDER, data_files=load_path, storage_options=self._storage_options, **self._load_args, ) - # load_dataset wraps a single file in a DatasetDict with one - # split (typically "train"). When the caller didn't ask for a - # specific split, unwrap it so a single file round-trips as a - # Dataset, not a DatasetDict. - if ( - "split" not in self._load_args - and isinstance(result, DatasetDict) - and len(result) == 1 - ): - return next(iter(result.values())) - - return result - def _save_dataset(self, data: Dataset, save_path: str) -> None: saver = f"to_{self.BUILDER}" getattr(data, saver)( @@ -165,6 +138,10 @@ def _save_dataset(self, data: Dataset, save_path: str) -> None: ) def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None: + """Hugging Face only provides ``DatasetDict.save_to_disk`` for Arrow format. + + As a result, we have to call ``to_`` per split. 
+ """ self._fs.mkdirs(save_path, exist_ok=True) ext = self.EXTENSION saver = f"to_{self.BUILDER}" diff --git a/kedro-datasets/tests/huggingface/test_arrow_dataset.py b/kedro-datasets/tests/huggingface/test_arrow_dataset.py index 4198e86b8..b993f481a 100644 --- a/kedro-datasets/tests/huggingface/test_arrow_dataset.py +++ b/kedro-datasets/tests/huggingface/test_arrow_dataset.py @@ -61,27 +61,18 @@ def test_exists_dataset_dict(self, arrow_dataset, dataset_dict): assert arrow_dataset.exists() def test_save_and_load_iterable_dataset(self, arrow_dataset, iterable_dataset): - """Test saving an IterableDataset materializes and round-trips.""" - arrow_dataset.save(iterable_dataset) - reloaded = arrow_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == { - "col1": [1, 2, 3], - "col2": ["a", "b", "c"], - } + """Test that saving an IterableDataset raises an error with a helpful message.""" + pattern = r"got iterable dataset" + with pytest.raises(DatasetError, match=pattern): + arrow_dataset.save(iterable_dataset) def test_save_and_load_iterable_dataset_dict( self, arrow_dataset, iterable_dataset_dict ): - """Test saving an IterableDatasetDict materializes and round-trips.""" - arrow_dataset.save(iterable_dataset_dict) - reloaded = arrow_dataset.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - assert reloaded["train"].to_dict() == { - "col1": [1, 2], - "col2": ["a", "b"], - } + """Test that saving an IterableDatasetDict raises an error with a helpful message.""" + pattern = r"got iterable dataset" + with pytest.raises(DatasetError, match=pattern): + arrow_dataset.save(iterable_dataset_dict) @pytest.mark.parametrize("save_args", [{"num_shards": 2}], indirect=True) def test_save_extra_params(self, arrow_dataset, save_args): @@ -103,7 +94,9 @@ def test_load_missing_dataset(self, arrow_dataset): def test_save_invalid_type(self, arrow_dataset): """Check the error when saving an unsupported type.""" - pattern = r"ArrowDataset only supports .datasets.Dataset., .datasets.DatasetDict., .datasets.IterableDataset., and .datasets.IterableDatasetDict. instances." + pattern = ( + r"ArrowDataset only supports .datasets.Dataset., .datasets.DatasetDict." 
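The `_save_dataset_dict` docstring above notes that Hugging Face only ships `DatasetDict.save_to_disk` for the Arrow path, so every other format falls back to one `to_<builder>` call per split. A rough illustration of what that fallback amounts to for a CSV-backed `DatasetDict` (paths and split names are illustrative, not taken from the implementation):

```python
# Illustrative only: one file per split, since DatasetDict has no to_csv equivalent.
from datasets import Dataset, DatasetDict

data = DatasetDict(
    {
        "train": Dataset.from_dict({"col1": [1, 2], "col2": ["a", "b"]}),
        "test": Dataset.from_dict({"col1": [3], "col2": ["c"]}),
    }
)

for split, ds in data.items():
    # The real code builds the method name from the builder (to_csv / to_json / to_parquet).
    ds.to_csv(f"data/01_raw/review_dict/{split}.csv")
```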
+ ) with pytest.raises(DatasetError, match=pattern): arrow_dataset.save({"not": "a dataset"}) @@ -180,23 +173,18 @@ def test_save_and_load_dataset_dict(self, versioned_arrow_dataset, dataset_dict) def test_save_and_load_iterable_dataset( self, versioned_arrow_dataset, iterable_dataset ): - """Test versioned save with IterableDataset.""" - versioned_arrow_dataset.save(iterable_dataset) - reloaded = versioned_arrow_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == { - "col1": [1, 2, 3], - "col2": ["a", "b", "c"], - } + """Test that versioned save of IterableDataset raises an error with a helpful message.""" + pattern = r"got iterable dataset" + with pytest.raises(DatasetError, match=pattern): + versioned_arrow_dataset.save(iterable_dataset) def test_save_and_load_iterable_dataset_dict( self, versioned_arrow_dataset, iterable_dataset_dict ): - """Test versioned save with IterableDatasetDict.""" - versioned_arrow_dataset.save(iterable_dataset_dict) - reloaded = versioned_arrow_dataset.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} + """Test that versioned save of IterableDatasetDict raises an error with a helpful message.""" + pattern = r"got iterable dataset" + with pytest.raises(DatasetError, match=pattern): + versioned_arrow_dataset.save(iterable_dataset_dict) def test_no_versions(self, versioned_arrow_dataset): """Check the error if no versions are available for load.""" @@ -251,6 +239,8 @@ def test_http_filesystem_no_versioning(self): def test_save_invalid_type_versioned(self, versioned_arrow_dataset): """Check the error when saving an unsupported type through versioned dataset.""" - pattern = r"ArrowDataset only supports .datasets.Dataset., .datasets.DatasetDict., .datasets.IterableDataset., and .datasets.IterableDatasetDict. instances." + pattern = ( + r"ArrowDataset only supports .datasets.Dataset., .datasets.DatasetDict." + ) with pytest.raises(DatasetError, match=pattern): versioned_arrow_dataset.save("not a dataset") diff --git a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py index 159feec06..2a499965a 100644 --- a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py +++ b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py @@ -102,25 +102,14 @@ def test_save_and_load_dataset_dict(self, fs_dataset_dir, dataset_dict): assert reloaded[split].to_dict() == dataset_dict[split].to_dict() def test_save_and_load_iterable_dataset(self, fs_dataset, iterable_dataset): - fs_dataset.save(iterable_dataset) - reloaded = fs_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == { - "col1": [1, 2, 3], - "col2": ["a", "b", "c"], - } + with pytest.raises(DatasetError, match=r"got iterable dataset"): + fs_dataset.save(iterable_dataset) def test_save_and_load_iterable_dataset_dict( self, fs_dataset_dir, iterable_dataset_dict ): - fs_dataset_dir.save(iterable_dataset_dict) - reloaded = fs_dataset_dir.load() - assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - assert reloaded["train"].to_dict() == { - "col1": [1, 2], - "col2": ["a", "b"], - } + with pytest.raises(DatasetError, match=r"got iterable dataset"): + fs_dataset_dir.save(iterable_dataset_dict) def test_exists(self, fs_dataset, dataset): assert not fs_dataset.exists() From 3fdc4f79e5268e3298cd8797ebe483990cc4cbe9 Mon Sep 17 00:00:00 2001 From: "L. R. 
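With this change, `save()` refuses `IterableDataset` and `IterableDatasetDict` inputs instead of silently materialising them, so the materialisation step moves to the caller. A sketch of that step, mirroring the `Dataset.from_list(list(...))` conversion the previous implementation performed internally (the helper name is made up):

```python
# Sketch of what callers now do themselves before save().
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict


def materialize(data):
    """Convert streaming datasets into in-memory ones; pass anything else through."""
    if isinstance(data, IterableDatasetDict):
        return DatasetDict({k: Dataset.from_list(list(v)) for k, v in data.items()})
    if isinstance(data, IterableDataset):
        return Dataset.from_list(list(data))
    return data


# e.g. catalog_dataset.save(materialize(streaming_dataset))
```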
Couto" <57910428+lrcouto@users.noreply.github.com> Date: Fri, 17 Apr 2026 11:24:14 -0300 Subject: [PATCH 09/21] feat(datasets): Add `OpikEvaluationDataset` to experimental datasets (#1364) * Add OpikEvaluationDataset Signed-off-by: Laura Couto * Add unit tests Signed-off-by: Laura Couto * Lint Signed-off-by: Laura Couto * Lint Signed-off-by: Laura Couto * Docstring Signed-off-by: Laura Couto * Add OpikEvaluationDataset stuff to the readme Signed-off-by: Laura Couto * Add OpikEvaluationDataset Signed-off-by: Laura Couto * Add unit tests Signed-off-by: Laura Couto * Lint Signed-off-by: Laura Couto * Lint Signed-off-by: Laura Couto * Docstring Signed-off-by: Laura Couto * Add OpikEvaluationDataset stuff to the readme Signed-off-by: Laura Couto * Docs and release note Signed-off-by: Laura Couto * Typo Signed-off-by: Laura Couto * Update kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py Co-authored-by: Ravi Kumar Pilla Signed-off-by: L. R. Couto <57910428+lrcouto@users.noreply.github.com> * Add more explicit errors in case of connection failure Signed-off-by: Laura Couto * Opik client flush to prevent async issues Signed-off-by: Laura Couto * Explicitly explain remote sync behavior Signed-off-by: Laura Couto * Fix release notes Signed-off-by: Laura Couto * Rephrase docstrings Signed-off-by: Laura Couto * Handle UUID more carefully Signed-off-by: Laura Couto * Wrap _client.flush() in try/except for DatasetError Signed-off-by: Laura Couto * Update README, add more explicit exception handling Signed-off-by: Laura Couto * Lint Signed-off-by: Laura Couto * Enforce UUIDv7 Signed-off-by: Laura Couto * Lint Signed-off-by: Laura Couto * Clarify interactions with UUIDv7 on docs Signed-off-by: Laura Couto * Extract auxiliary functions Signed-off-by: Laura Couto * Make it so 'non UUIDv7 creates a new row' is very explicit Signed-off-by: Laura Couto * Update kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py Co-authored-by: ElenaKhaustova <157851531+ElenaKhaustova@users.noreply.github.com> Signed-off-by: L. R. Couto <57910428+lrcouto@users.noreply.github.com> * Update kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py Co-authored-by: ElenaKhaustova <157851531+ElenaKhaustova@users.noreply.github.com> Signed-off-by: L. R. Couto <57910428+lrcouto@users.noreply.github.com> * Update kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py Co-authored-by: ElenaKhaustova <157851531+ElenaKhaustova@users.noreply.github.com> Signed-off-by: L. R. Couto <57910428+lrcouto@users.noreply.github.com> * Update kedro-datasets/kedro_datasets_experimental/opik/README.md Co-authored-by: ElenaKhaustova <157851531+ElenaKhaustova@users.noreply.github.com> Signed-off-by: L. R. Couto <57910428+lrcouto@users.noreply.github.com> * Update kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py Co-authored-by: ElenaKhaustova <157851531+ElenaKhaustova@users.noreply.github.com> Signed-off-by: L. R. Couto <57910428+lrcouto@users.noreply.github.com> * Update kedro-datasets/kedro_datasets_experimental/opik/README.md Co-authored-by: ElenaKhaustova <157851531+ElenaKhaustova@users.noreply.github.com> Signed-off-by: L. R. 
Couto <57910428+lrcouto@users.noreply.github.com> * Fix docstrings Signed-off-by: Laura Couto * Minor fix on docstring Signed-off-by: Laura Couto * Minor fix on docstring Signed-off-by: Laura Couto * Doc indent Signed-off-by: Laura Couto * Indent Signed-off-by: Laura Couto * Lint Signed-off-by: Laura Couto --------- Signed-off-by: Laura Couto Signed-off-by: L. R. Couto <57910428+lrcouto@users.noreply.github.com> Co-authored-by: Ravi Kumar Pilla Co-authored-by: ElenaKhaustova <157851531+ElenaKhaustova@users.noreply.github.com> Signed-off-by: iwhalen --- .../opik.OpikEvaluationDataset.md | 4 + .../opik/opik_evaluation_dataset.py | 604 +++++++++++++ .../opik/test_opik_evaluation_dataset.py | 791 ++++++++++++++++++ 3 files changed, 1399 insertions(+) create mode 100644 kedro-datasets/docs/api/kedro_datasets_experimental/opik.OpikEvaluationDataset.md create mode 100644 kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py create mode 100644 kedro-datasets/kedro_datasets_experimental/tests/opik/test_opik_evaluation_dataset.py diff --git a/kedro-datasets/docs/api/kedro_datasets_experimental/opik.OpikEvaluationDataset.md b/kedro-datasets/docs/api/kedro_datasets_experimental/opik.OpikEvaluationDataset.md new file mode 100644 index 000000000..7fd57c76b --- /dev/null +++ b/kedro-datasets/docs/api/kedro_datasets_experimental/opik.OpikEvaluationDataset.md @@ -0,0 +1,4 @@ +::: kedro_datasets_experimental.opik.OpikEvaluationDataset + options: + members: true + show_source: true diff --git a/kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py b/kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py new file mode 100644 index 000000000..f57c869b9 --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py @@ -0,0 +1,604 @@ +import json +import logging +import uuid +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal + +from kedro.io import AbstractDataset, DatasetError +from opik import Opik +from opik.api_objects.dataset.dataset import Dataset +from opik.rest_api.core.api_error import ApiError + +from kedro_datasets._typing import JSONPreview + +if TYPE_CHECKING: + from kedro_datasets.json import JSONDataset + from kedro_datasets.yaml import YAMLDataset + +logger = logging.getLogger(__name__) + +SUPPORTED_FILE_EXTENSIONS = {".json", ".yaml", ".yml"} +REQUIRED_OPIK_CREDENTIALS = {"api_key"} +OPTIONAL_OPIK_CREDENTIALS = {"workspace", "host", "project_name"} +VALID_SYNC_POLICIES = {"local", "remote"} +HTTP_NOT_FOUND = 404 +REQUIRED_UUID_VERSION = 7 + + +class OpikEvaluationDataset(AbstractDataset): + """Kedro dataset for Opik evaluation datasets. + + Connects to an Opik evaluation dataset and returns an ``opik.Dataset`` + on ``load()``, which can be passed to ``opik.evaluation.evaluate()`` to + run experiments. Supports an optional local JSON/YAML file as the + authoring surface for evaluation items. + + **On load / save behaviour:** + + - **On load:** Creates the remote dataset if it does not exist, + synchronises based on ``sync_policy``, and returns an ``opik.Dataset``. + - **On save:** Inserts all items to the remote dataset via Opik's + upsert-by-ID API. Items with a UUID v7 ``id`` update the existing + remote row in-place; items without a UUID v7 ``id`` create a new + remote row on every call. In ``local`` mode, items are also merged + into the local file (new items take precedence). In ``remote`` mode, + only the remote insert occurs. 
+ + **Item format:** + + The local file and ``save()`` data must be a list of dicts. Each item + accepts the following keys: + + - ``input`` (**required**) — the evaluation input payload. + - ``id`` — identifier used for local deduplication. The upload + behaviour depends on whether ``id`` is a valid UUID v7: + + - **Valid UUID v7**: forwarded to Opik. Opik's API upserts by item + ID — the first sync creates the remote row; subsequent syncs + update that same row in-place if the content has changed. + The remote row keeps the same UUID across all syncs. Whenever + content changes, the existing remote row is updated in-place, + while no new row is created. + - **All other values** (human-readable strings, UUIDs of other + versions, ``None``, empty string, or no ``id`` key): stripped + before upload. Opik auto-generates a new UUID v7. Unchanged + content is deduplicated by content hash (no-op), but changed + content creates a **new remote row** while the previous one + remains, leading to row accumulation over time. + + - ``expected_output`` — ground-truth value for scoring. + - ``metadata`` — arbitrary metadata dict attached to the item. + + ```json + [ + { + "id": "q1", + "input": {"text": "cancel my order"}, + "expected_output": "cancel_order", + "metadata": {"source": "production"} + } + ] + ``` + ("q1" is used for local deduplication only, as it is not a UUID v7 and will be stripped on upload) + + **Sync policies:** + + - **local** (default): The local file is the source of truth. On + ``load()``, all local items are re-inserted to remote on every sync. + Opik's API upserts by item ID, so the outcome depends on whether + each item carries a UUID v7 ``id``: + + - Items with a UUID v7 ``id`` are updated in-place on the remote — + content changes replace the existing row; unchanged items are + a no-op. + - Items without a UUID v7 ``id`` (non-UUID values are stripped) + are deduplicated by content hash — unchanged content is a no-op, + but changed content creates a **new remote row** (the previous + row remains), leading to row accumulation over time. + ``save()`` inserts to remote and merges into the local file (new + data takes precedence). + + - **remote**: The remote Opik dataset is the sole source of truth. + ``load()`` fetches the remote dataset as-is with no local file + interaction. ``save()`` inserts all items to remote without writing + to any local file. If the remote dataset does not exist yet, it is + created empty — **no items are pushed from the local file**. To seed + a new remote dataset, run with ``sync_policy="local"`` at least once, + or create and populate the dataset directly via the Opik UI. 
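Given the ``id`` semantics above, items that should be updated in-place across syncs need to carry a pre-generated UUID v7. A sketch of such an item for the local JSON file (the UUID below is a made-up illustrative value; any UUID v7 generator will do):

```json
[
    {
        "id": "018f2e9b-6f3a-7c41-9a5e-2d7c8b1f4e06",
        "input": {"text": "where is my refund"},
        "expected_output": "refund_status",
        "metadata": {"source": "handbook"}
    }
]
```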
+ + Examples: + Using catalog YAML configuration: + + ```yaml + # Local sync policy — local file seeds and syncs to remote + evaluation_dataset: + type: kedro_datasets_experimental.opik.OpikEvaluationDataset + dataset_name: intent-detection-eval + filepath: data/evaluation/intent_items.json + sync_policy: local + credentials: opik_credentials + metadata: + project: intent-detection + + # Remote sync policy — Opik is the source of truth + production_eval: + type: kedro_datasets_experimental.opik.OpikEvaluationDataset + dataset_name: intent-detection-eval + sync_policy: remote + credentials: opik_credentials + ``` + + Using Python API: + + ```python + from kedro_datasets_experimental.opik import OpikEvaluationDataset + + dataset = OpikEvaluationDataset( + dataset_name="intent-detection-eval", + credentials={"api_key": "..."}, # pragma: allowlist secret + filepath="data/evaluation/intent_items.json", + ) + + # Load returns an opik.Dataset for running experiments + from opik.evaluation import evaluate + + eval_dataset = dataset.load() + evaluate( + dataset=eval_dataset, + task=my_task, + scoring_functions=[my_scorer], + experiment_name="my-experiment", + ) + + # Save new evaluation items + dataset.save( + [ + {"id": "q1", "input": {"text": "cancel order"}, "expected_output": "cancel"}, + ] + ) + + # Same as in the other example, "q1" is not a UUID v7 and will be stripped on upload + ``` + """ + + def __init__( + self, + dataset_name: str, + credentials: dict[str, str], + filepath: str | None = None, + sync_policy: Literal["local", "remote"] = "local", + metadata: dict[str, Any] | None = None, + ): + """Initialise ``OpikEvaluationDataset``. + + Args: + dataset_name: Name of the evaluation dataset in Opik. + credentials: Opik authentication credentials. + Required: ``api_key``. + Optional: ``workspace``, ``host``, ``project_name``. + filepath: Path to a local JSON/YAML file for authoring evaluation + items. Supports ``.json``, ``.yaml``, and ``.yml`` extensions. + When ``None``, no local file interaction occurs. + sync_policy: Controls the source of truth for reads and whether + a local file is involved: + ``"local"`` (default) — all local items are re-inserted to + remote on ``load()``; ``save()`` inserts to remote and + merges into the local file (new data takes precedence). + ``"remote"`` — ``load()`` fetches remote as-is; ``save()`` + inserts to remote without local file interaction. + metadata: Optional metadata dict stored locally and returned by + ``_describe()``. Note: Opik's ``create_dataset()`` does not + accept a metadata argument, so this value is not propagated + to the remote dataset. 
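For reference, the ``credentials: opik_credentials`` entry used in the YAML example above would typically live in ``conf/local/credentials.yml``. A hypothetical entry using only the key names accepted by this constructor; the ``host`` value and the ``oc.env`` resolver are assumptions, not requirements:

```yaml
opik_credentials:
  api_key: ${oc.env:OPIK_API_KEY}  # pragma: allowlist secret
  workspace: my-workspace
  host: https://www.comet.com/opik/api
```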
+ """ + self._validate_init_params(credentials, filepath, sync_policy) + + self._dataset_name = dataset_name + self._filepath = Path(filepath) if filepath else None + self._sync_policy = sync_policy + self._metadata = metadata + self._file_dataset = None + + try: + self._client = Opik(**credentials) + except Exception as e: + raise DatasetError(f"Failed to initialise Opik client: {e}") from e + + @staticmethod + def _validate_init_params( + credentials: dict[str, str], + filepath: str | None, + sync_policy: str, + ) -> None: + OpikEvaluationDataset._validate_credentials(credentials) + OpikEvaluationDataset._validate_sync_policy(sync_policy) + OpikEvaluationDataset._validate_filepath(filepath) + + @staticmethod + def _validate_credentials(credentials: dict[str, str]) -> None: + for key in REQUIRED_OPIK_CREDENTIALS: + if key not in credentials: + raise DatasetError( + f"Missing required Opik credential: '{key}'." + ) + if not credentials[key] or not str(credentials[key]).strip(): + raise DatasetError( + f"Opik credential '{key}' cannot be empty." + ) + for key in OPTIONAL_OPIK_CREDENTIALS: + if key in credentials and ( + not credentials[key] or not str(credentials[key]).strip() + ): + raise DatasetError( + f"Opik credential '{key}' cannot be empty if provided." + ) + + @staticmethod + def _validate_sync_policy(sync_policy: str) -> None: + if sync_policy not in VALID_SYNC_POLICIES: + raise DatasetError( + f"Invalid sync_policy '{sync_policy}'. " + f"Must be one of: {', '.join(sorted(VALID_SYNC_POLICIES))}." + ) + + @staticmethod + def _validate_filepath(filepath: str | None) -> None: + if filepath is None: + return + suffix = Path(filepath).suffix.lower() + if suffix not in SUPPORTED_FILE_EXTENSIONS: + raise DatasetError( + f"Unsupported file extension '{suffix}'. " + f"Supported formats: {', '.join(sorted(SUPPORTED_FILE_EXTENSIONS))}." + ) + + @property + def file_dataset(self) -> "JSONDataset | YAMLDataset": + """Return a JSON or YAML file dataset based on the filepath extension.""" + if not self._filepath: + raise DatasetError("filepath must be provided for file dataset operations.") + if self._file_dataset is None: + suffix = self._filepath.suffix.lower() + if suffix in (".yaml", ".yml"): + from kedro_datasets.yaml import YAMLDataset # noqa: PLC0415 + self._file_dataset = YAMLDataset(filepath=str(self._filepath)) + else: + from kedro_datasets.json import JSONDataset # noqa: PLC0415 + self._file_dataset = JSONDataset(filepath=str(self._filepath)) + return self._file_dataset + + def _get_or_create_remote_dataset(self) -> Dataset: + """Ensure the remote Opik dataset exists, creating it if not found. + + Returns the latest ``Dataset`` object. + + Raises: + DatasetError: If the Opik API returns an unexpected error or is + unreachable. 
+ """ + try: + return self._client.get_dataset(name=self._dataset_name) + except ApiError as e: + if e.status_code != HTTP_NOT_FOUND: + raise DatasetError( + f"Opik API error while fetching dataset '{self._dataset_name}': {e}" + ) from e + except Exception as e: + raise DatasetError( + f"Failed to connect to Opik while fetching dataset " + f"'{self._dataset_name}': {e}" + ) from e + + try: + logger.info( + "Dataset '%s' not found on Opik, creating it.", + self._dataset_name, + ) + return self._client.create_dataset( + name=self._dataset_name, + description=f"Created by Kedro (sync_policy={self._sync_policy})", + ) + except ApiError as e: + raise DatasetError( + f"Opik API error while creating dataset '{self._dataset_name}': {e}" + ) from e + except Exception as e: + raise DatasetError( + f"Failed to connect to Opik while creating dataset " + f"'{self._dataset_name}': {e}" + ) from e + + @staticmethod + def _strip_id(item: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in item.items() if k != "id"} + + @staticmethod + def _validate_items(items: list[dict[str, Any]]) -> None: + """Validate that all items contain the required ``input`` key. + + Raises: + DatasetError: If any item is missing the ``input`` key. + """ + for i, item in enumerate(items): + if "input" not in item: + raise DatasetError( + f"Dataset item at index {i} is missing required 'input' key." + ) + + def _upload_items(self, dataset: Dataset, items: list[dict[str, Any]]) -> None: + """Insert items into the remote Opik dataset. + + Upload behaviour depends on whether an item carries a UUID v7 ``id``: + + - **Valid UUID v7**: forwarded to Opik. Opik's REST API calls + ``create_or_update`` by item ID — the first call creates the + remote row; subsequent calls update that same row in-place if + the content has changed. Whenever content changes, the existing + remote row is updated in-place, while no new row is created. + - **All other values** (human-readable strings, UUIDs of other + versions, ``None``, empty string, or no ``id`` key): stripped + before upload. Opik auto-generates a new UUID v7. Unchanged + content is deduplicated by content hash (no-op), but changed + content creates a **new remote row** while the previous one + remains. + + Callers are responsible for validating items before calling this method. + + Raises: + DatasetError: If the Opik API returns an error or the server is + unreachable during insert. + """ + items_to_insert = [] + for item in items: + if "id" not in item: + items_to_insert.append(item) + elif not item["id"]: + items_to_insert.append(self._strip_id(item)) + else: + try: + parsed = uuid.UUID(str(item["id"])) + if parsed.version == REQUIRED_UUID_VERSION: + items_to_insert.append(item) # valid UUID v7 — preserve id + else: + items_to_insert.append(self._strip_id(item)) + except ValueError: + items_to_insert.append(self._strip_id(item)) + try: + dataset.insert(items_to_insert) + except ApiError as e: + raise DatasetError( + f"Opik API error while inserting items into dataset " + f"'{self._dataset_name}': {e}" + ) from e + except Exception as e: + raise DatasetError( + f"Failed to insert items into Opik dataset '{self._dataset_name}': {e}" + ) from e + + def _sync_local_to_remote(self, dataset: Dataset) -> Dataset: + """Insert all local items into the remote dataset. + + Reads the local file and inserts all items into the remote dataset. + The Opik SDK deduplicates by content hash, so re-inserting unchanged + items is a no-op. Returns a refreshed ``Dataset`` object. 
If the dataset's + id is a valid UUID v7, the same remote row is updated in-place on every sync. + Otherwise, a new remote row will be created. + """ + if not self._filepath or not self._filepath.exists(): + return dataset + + local_items = self.file_dataset.load() + self._validate_items(local_items) + + if not local_items: + return dataset + + items_without_stable_id = [ + item for item in local_items + if "id" not in item or not item.get("id") + ] + if items_without_stable_id: + logger.warning( + "Found %d item(s) with a missing, None, or empty 'id' field in '%s'. " + "These cannot be tracked across syncs and will create new remote " + "rows on every load.", + len(items_without_stable_id), + self._filepath, + ) + + items_with_non_uuid_v7_id = [] + for item in local_items: + if item.get("id"): # present and non-empty/non-None + try: + parsed = uuid.UUID(str(item["id"])) + if parsed.version != REQUIRED_UUID_VERSION: + items_with_non_uuid_v7_id.append(item) + except ValueError: + items_with_non_uuid_v7_id.append(item) + if items_with_non_uuid_v7_id: + logger.warning( + "Found %d item(s) with non-UUID-v7 'id' values in '%s' " + "(e.g. '%s'). Opik requires UUID v7 for item IDs — these " + "will be stripped before upload and Opik will auto-generate " + "UUID v7 values. Remote rows will not have stable identities.", + len(items_with_non_uuid_v7_id), + self._filepath, + items_with_non_uuid_v7_id[0]["id"], + ) + + logger.info( + "Syncing %d item(s) from '%s' to remote dataset '%s'.", + len(local_items), + self._filepath, + self._dataset_name, + ) + self._upload_items(dataset, local_items) + try: + self._client.flush() + except Exception as e: + raise DatasetError( + f"Failed to flush items to Opik dataset '{self._dataset_name}': {e}" + ) from e + + try: + return self._client.get_dataset(name=self._dataset_name) + except ApiError as e: + raise DatasetError( + f"Opik API error while refreshing dataset '{self._dataset_name}' after sync: {e}" + ) from e + except Exception as e: + raise DatasetError( + f"Failed to refresh dataset '{self._dataset_name}' after sync: {e}" + ) from e + + @staticmethod + def _merge_items( + existing: list[dict[str, Any]], + new: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + """Merge new items into an existing list, deduplicating by ``id``. + + Items without an ``id`` key are always appended. For items with an + ``id``, new items take precedence — existing entries with the same + ``id`` are replaced in place. + """ + new_by_id: dict[str, dict[str, Any]] = { + item["id"]: item for item in new if "id" in item + } + + seen_ids: set[str] = set() + merged: list[dict[str, Any]] = [] + + for item in existing: + item_id = item.get("id") + if item_id is not None and item_id in new_by_id: + merged.append(new_by_id[item_id]) + seen_ids.add(item_id) + else: + merged.append(item) + if item_id is not None: + seen_ids.add(item_id) + + for item in new: + item_id = item.get("id") + if item_id is not None and item_id in seen_ids: + continue + if item_id is not None: + seen_ids.add(item_id) + merged.append(item) + + return merged + + def load(self) -> Dataset: + """Load the Opik dataset, syncing local items to remote if sync_policy is ``local``. + + Creates the remote dataset if it does not exist. In ``local`` mode, all + local items are re-inserted to remote on every load via Opik's + ``create_or_update`` API (upsert by item ID). On items with a valid UUID v7 + ``id``, update the existing remote row in-place, and no new row is created. 
+ On items where the ``id`` is not a valid UUID v7 (including missing, ``None``, or empty), + the ``id`` is stripped before upload and Opik auto-generates a new UUID v7. + Unchanged content is deduplicated (no-op), but changed content creates a + new remote row while the previous one remains. + + Returns: + Dataset: The Opik dataset ready for use in experiments. + + Raises: + DatasetError: If the Opik API returns an unexpected error or the + server is unreachable. + """ + dataset = self._get_or_create_remote_dataset() + + if self._sync_policy == "local": + dataset = self._sync_local_to_remote(dataset) + + logger.info( + "Loaded dataset '%s' (sync_policy='%s').", + self._dataset_name, + self._sync_policy, + ) + return dataset + + def save(self, data: list[dict[str, Any]]) -> None: + """Insert items into the Opik dataset and optionally update the local file. + + In ``remote`` mode, only the remote upload occurs. In ``local`` mode, + items are also merged into the local file. + + Args: + data: List of dicts, each containing at least an ``input`` key. + + Raises: + DatasetError: If the Opik API call fails or any item is missing ``input``. + """ + if self._sync_policy == "remote": + logger.warning( + "sync_policy='remote': save() uploads to remote only, " + "local file '%s' will not be updated.", + self._filepath, + ) + + self._validate_items(data) + + dataset = self._get_or_create_remote_dataset() + self._upload_items(dataset, data) + try: + self._client.flush() + except Exception as e: + raise DatasetError( + f"Failed to flush items to Opik dataset '{self._dataset_name}': {e}" + ) from e + + if self._sync_policy == "local" and self._filepath: + existing: list[dict] = [] + if self._filepath.exists(): + existing = self.file_dataset.load() + self.file_dataset.save(self._merge_items(existing, data)) + + def _exists(self) -> bool: + try: + self._client.get_dataset(name=self._dataset_name) + return True + except ApiError as e: + if e.status_code == HTTP_NOT_FOUND: + return False + raise DatasetError( + f"Opik API error while checking dataset '{self._dataset_name}': {e}" + ) from e + except Exception as e: + raise DatasetError( + f"Failed to connect to Opik while checking dataset " + f"'{self._dataset_name}': {e}" + ) from e + + def _describe(self) -> dict[str, Any]: + return { + "dataset_name": self._dataset_name, + "filepath": str(self._filepath) if self._filepath else None, + "sync_policy": self._sync_policy, + "metadata": self._metadata, + } + + def preview(self) -> JSONPreview: + """Generate a JSON-compatible preview of the local evaluation data for Kedro-Viz. + + Returns: + JSONPreview: A Kedro-Viz-compatible object containing a serialized JSON string. + Returns a descriptive message if filepath is not configured or does not exist. 
+ """ + if not self._filepath: + return JSONPreview("No filepath configured.") + + if not self._filepath.exists(): + return JSONPreview("Local evaluation dataset does not exist.") + + local_data = self.file_dataset.load() + + if isinstance(local_data, str): + local_data = {"content": local_data} + + try: + return JSONPreview(json.dumps(local_data)) + except (TypeError, ValueError) as e: + return JSONPreview(f"Could not serialise local data to JSON: {e}") diff --git a/kedro-datasets/kedro_datasets_experimental/tests/opik/test_opik_evaluation_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/opik/test_opik_evaluation_dataset.py new file mode 100644 index 000000000..a2de5639a --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/tests/opik/test_opik_evaluation_dataset.py @@ -0,0 +1,791 @@ +import datetime +import json +from unittest.mock import Mock, patch + +import pytest +import yaml +from kedro.io import DatasetError +from opik.rest_api.core.api_error import ApiError + +from kedro_datasets_experimental.opik.opik_evaluation_dataset import ( + OpikEvaluationDataset, +) + + +def make_api_error(status_code: int) -> ApiError: + """Return an ApiError with the given status code.""" + return ApiError(status_code=status_code, headers={}, body={}) + + +@pytest.fixture +def mock_opik(): + """Mock Opik client instance.""" + with patch("kedro_datasets_experimental.opik.opik_evaluation_dataset.Opik") as mock_class: + instance = Mock() + mock_class.return_value = instance + yield instance + + +@pytest.fixture +def mock_credentials(): + """Valid Opik credentials for testing.""" + return { + "api_key": "opik_test_key", # pragma: allowlist secret + "workspace": "test-workspace", + } + + +@pytest.fixture +def eval_items(): + """Sample evaluation dataset items with human-readable (non-UUID) IDs.""" + return [ + { + "id": "item_001", + "input": {"question": "What is AI?"}, + "expected_output": {"answer": "Artificial Intelligence"}, + }, + { + "id": "item_002", + "input": {"question": "What is ML?"}, + "expected_output": {"answer": "Machine Learning"}, + }, + ] + + +@pytest.fixture +def eval_items_uuid(): + """Sample evaluation dataset items with valid UUID v7 IDs.""" + return [ + { + "id": "018e2f1a-dead-7abc-8def-000000000001", + "input": {"question": "What is AI?"}, + "expected_output": {"answer": "Artificial Intelligence"}, + }, + { + "id": "018e2f1a-dead-7abc-8def-000000000002", + "input": {"question": "What is ML?"}, + "expected_output": {"answer": "Machine Learning"}, + }, + ] + + +@pytest.fixture +def eval_items_mixed(): + """Items with a mix of UUID v7 and human-readable IDs.""" + return [ + { + "id": "018e2f1a-dead-7abc-8def-000000000001", + "input": {"question": "What is AI?"}, + "expected_output": {"answer": "Artificial Intelligence"}, + }, + { + "id": "human_readable_id", + "input": {"question": "What is ML?"}, + "expected_output": {"answer": "Machine Learning"}, + }, + ] + + +@pytest.fixture +def eval_items_no_id(): + """Evaluation items without IDs.""" + return [ + {"input": {"question": "What is AI?"}, "expected_output": {"answer": "AI"}}, + {"input": {"question": "What is ML?"}, "expected_output": {"answer": "ML"}}, + ] + + +@pytest.fixture +def filepath_json(tmp_path, eval_items): + """Temporary JSON file with evaluation items.""" + filepath = tmp_path / "eval.json" + filepath.write_text(json.dumps(eval_items)) + return str(filepath) + + +@pytest.fixture +def filepath_yaml(tmp_path, eval_items): + """Temporary YAML file with evaluation items.""" + filepath = tmp_path / 
"eval.yaml" + filepath.write_text(yaml.dump(eval_items)) + return str(filepath) + + +@pytest.fixture +def mock_remote_dataset(): + """Mock Opik Dataset object.""" + ds = Mock() + ds.name = "test-dataset" + return ds + + +@pytest.fixture +def dataset_local(filepath_json, mock_credentials, mock_opik, mock_remote_dataset): + """OpikEvaluationDataset with local sync policy.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + return OpikEvaluationDataset( + dataset_name="test-dataset", + credentials=mock_credentials, + filepath=filepath_json, + sync_policy="local", + ) + + +@pytest.fixture +def dataset_remote(mock_credentials, mock_opik, mock_remote_dataset): + """OpikEvaluationDataset with remote sync policy and no filepath.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + return OpikEvaluationDataset( + dataset_name="test-dataset", + credentials=mock_credentials, + sync_policy="remote", + ) + + +class TestOpikEvaluationDatasetInit: + """Test OpikEvaluationDataset initialisation.""" + + def test_init_minimal_params(self, mock_credentials, mock_opik): + """Minimal required params store expected defaults.""" + ds = OpikEvaluationDataset( + dataset_name="my-dataset", + credentials=mock_credentials, + ) + assert ds._dataset_name == "my-dataset" + assert ds._filepath is None + assert ds._sync_policy == "local" + assert ds._metadata is None + + def test_init_all_params(self, filepath_json, mock_credentials, mock_opik): + """All params are stored correctly.""" + meta = {"project": "test"} + ds = OpikEvaluationDataset( + dataset_name="my-dataset", + credentials=mock_credentials, + filepath=filepath_json, + sync_policy="remote", + metadata=meta, + ) + assert ds._sync_policy == "remote" + assert ds._metadata == meta + assert ds._filepath is not None + + def test_init_missing_api_key(self, mock_opik): + """Missing api_key raises DatasetError.""" + with pytest.raises(DatasetError, match="Missing required Opik credential: 'api_key'"): + OpikEvaluationDataset( + dataset_name="ds", + credentials={"workspace": "w"}, + ) + + @pytest.mark.parametrize("empty_value", ["", " "]) + def test_init_empty_api_key(self, mock_opik, empty_value): + """Empty api_key raises DatasetError.""" + with pytest.raises(DatasetError, match="Opik credential 'api_key' cannot be empty"): + OpikEvaluationDataset( + dataset_name="ds", + credentials={"api_key": empty_value}, + ) + + def test_init_empty_optional_credential(self, mock_opik): + """Empty optional credential (workspace) raises DatasetError.""" + with pytest.raises(DatasetError, match="Opik credential 'workspace' cannot be empty if provided"): + OpikEvaluationDataset( + dataset_name="ds", + credentials={"api_key": "key", "workspace": ""}, # pragma: allowlist secret + ) + + def test_init_invalid_sync_policy(self, mock_credentials, mock_opik): + """Invalid sync_policy raises DatasetError.""" + with pytest.raises(DatasetError, match="Invalid sync_policy 'invalid'"): + OpikEvaluationDataset( + dataset_name="ds", + credentials=mock_credentials, + sync_policy="invalid", + ) + + def test_init_unsupported_filepath_extension(self, tmp_path, mock_credentials, mock_opik): + """Unsupported file extension raises DatasetError.""" + bad_file = tmp_path / "data.txt" + bad_file.write_text("content") + with pytest.raises(DatasetError, match="Unsupported file extension '.txt'"): + OpikEvaluationDataset( + dataset_name="ds", + credentials=mock_credentials, + filepath=str(bad_file), + ) + + def test_init_client_failure_raises_dataset_error(self, mock_credentials): + 
"""Opik client construction failure is wrapped in DatasetError.""" + with patch("kedro_datasets_experimental.opik.opik_evaluation_dataset.Opik") as mock_class: + mock_class.side_effect = Exception("Connection refused") + with pytest.raises(DatasetError, match="Failed to initialise Opik client"): + OpikEvaluationDataset( + dataset_name="ds", + credentials=mock_credentials, + ) + + +class TestFiledatasetProperty: + """Test the file_dataset lazy property.""" + + def test_json_returns_json_dataset(self, dataset_local): + """JSON filepath resolves to JSONDataset.""" + assert dataset_local.file_dataset.__class__.__name__ == "JSONDataset" + + def test_yaml_returns_yaml_dataset(self, filepath_yaml, mock_credentials, mock_opik, mock_remote_dataset): + """YAML filepath resolves to YAMLDataset.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + ds = OpikEvaluationDataset( + dataset_name="ds", + credentials=mock_credentials, + filepath=filepath_yaml, + ) + assert ds.file_dataset.__class__.__name__ == "YAMLDataset" + + def test_is_cached(self, dataset_local): + """Repeated access returns the same object.""" + assert dataset_local.file_dataset is dataset_local.file_dataset + + def test_no_filepath_raises(self, dataset_remote): + """Accessing file_dataset without a filepath raises DatasetError.""" + with pytest.raises(DatasetError, match="filepath must be provided"): + _ = dataset_remote.file_dataset + + +class TestGetOrCreateRemoteDataset: + """Test the _get_or_create_remote_dataset helper.""" + + def test_returns_existing_dataset(self, dataset_local, mock_opik, mock_remote_dataset): + """Returns the dataset when it already exists.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + result = dataset_local._get_or_create_remote_dataset() + assert result is mock_remote_dataset + mock_opik.create_dataset.assert_not_called() + + def test_creates_dataset_on_404(self, dataset_local, mock_opik, mock_remote_dataset): + """Creates a new dataset when get_dataset raises 404.""" + mock_opik.get_dataset.side_effect = make_api_error(404) + mock_opik.create_dataset.return_value = mock_remote_dataset + + result = dataset_local._get_or_create_remote_dataset() + + mock_opik.create_dataset.assert_called_once() + assert result is mock_remote_dataset + + def test_non_404_api_error_raises_dataset_error(self, dataset_local, mock_opik): + """Non-404 API error from get_dataset is wrapped in DatasetError.""" + mock_opik.get_dataset.side_effect = make_api_error(500) + with pytest.raises(DatasetError, match="Opik API error while fetching dataset"): + dataset_local._get_or_create_remote_dataset() + + def test_create_dataset_api_error_raises_dataset_error(self, dataset_local, mock_opik): + """API error during create_dataset is wrapped in DatasetError.""" + mock_opik.get_dataset.side_effect = make_api_error(404) + mock_opik.create_dataset.side_effect = make_api_error(400) + with pytest.raises(DatasetError, match="Opik API error while creating dataset"): + dataset_local._get_or_create_remote_dataset() + + def test_connection_error_on_get_raises_dataset_error(self, dataset_local, mock_opik): + """Non-ApiError on get_dataset (e.g. connection refused) is wrapped in DatasetError.""" + mock_opik.get_dataset.side_effect = ConnectionRefusedError("Connection refused") + with pytest.raises(DatasetError, match="Failed to connect to Opik"): + dataset_local._get_or_create_remote_dataset() + + def test_connection_error_on_create_raises_dataset_error(self, dataset_local, mock_opik): + """Non-ApiError on create_dataset (e.g. 
connection refused) is wrapped in DatasetError.""" + mock_opik.get_dataset.side_effect = make_api_error(404) + mock_opik.create_dataset.side_effect = ConnectionRefusedError("Connection refused") + with pytest.raises(DatasetError, match="Failed to connect to Opik"): + dataset_local._get_or_create_remote_dataset() + + + +class TestValidateItems: + """Test the _validate_items static method.""" + + def test_valid_items_pass(self, eval_items): + """Items with 'input' keys pass validation without error.""" + OpikEvaluationDataset._validate_items(eval_items) # no exception + + def test_empty_list_passes(self): + """Empty item list is valid.""" + OpikEvaluationDataset._validate_items([]) + + def test_missing_input_raises_dataset_error(self): + """Item missing 'input' raises DatasetError with index.""" + items = [{"input": {"q": "ok"}}, {"expected_output": "missing input"}] + with pytest.raises(DatasetError, match="index 1.*missing required 'input'"): + OpikEvaluationDataset._validate_items(items) + + +class TestUploadItems: + """Test the _upload_items method.""" + + @pytest.mark.parametrize("bad_id,label", [ + ("item_001", "human-readable"), + ("550e8400-e29b-41d4-a716-446655440000", "UUID v4"), + ]) + def test_non_uuid_v7_id_is_stripped(self, dataset_local, mock_remote_dataset, bad_id, label): + """Non-UUID-v7 IDs (human-readable or other UUID versions) are stripped before upload.""" + items = [{"id": bad_id, "input": {"question": "What is AI?"}}] + dataset_local._upload_items(mock_remote_dataset, items) + + inserted = mock_remote_dataset.insert.call_args[0][0] + assert "id" not in inserted[0] + + def test_non_id_fields_are_preserved(self, dataset_local, mock_remote_dataset, eval_items): + """input and expected_output fields are passed through unchanged.""" + dataset_local._upload_items(mock_remote_dataset, eval_items) + + inserted = mock_remote_dataset.insert.call_args[0][0] + assert inserted[0]["input"] == eval_items[0]["input"] + assert inserted[0]["expected_output"] == eval_items[0]["expected_output"] + + def test_valid_uuidv7_ids_are_preserved(self, dataset_local, mock_remote_dataset, eval_items_uuid): + """Valid UUID IDs are forwarded to Opik unchanged.""" + dataset_local._upload_items(mock_remote_dataset, eval_items_uuid) + + inserted = mock_remote_dataset.insert.call_args[0][0] + assert inserted[0]["id"] == eval_items_uuid[0]["id"] + assert inserted[1]["id"] == eval_items_uuid[1]["id"] + + def test_mixed_ids_uuid_preserved_non_uuid_stripped( + self, dataset_local, mock_remote_dataset, eval_items_mixed + ): + """UUID IDs are preserved; human-readable IDs are stripped in the same batch.""" + dataset_local._upload_items(mock_remote_dataset, eval_items_mixed) + + inserted = mock_remote_dataset.insert.call_args[0][0] + assert inserted[0]["id"] == eval_items_mixed[0]["id"] # UUID preserved + assert "id" not in inserted[1] # non-UUID stripped + + def test_uuid_v7_id_preserved_when_content_changes(self, dataset_local, mock_remote_dataset): + """A UUID v7 id is forwarded on both uploads even when item content changes.""" + uuid_v7_id = "018e2f1a-dead-7abc-8def-000000000001" + first_version = [{"id": uuid_v7_id, "input": {"question": "What is AI?"}}] + second_version = [{"id": uuid_v7_id, "input": {"question": "What is Artificial Intelligence?"}}] + + dataset_local._upload_items(mock_remote_dataset, first_version) + assert mock_remote_dataset.insert.call_args[0][0][0]["id"] == uuid_v7_id + + dataset_local._upload_items(mock_remote_dataset, second_version) + assert 
mock_remote_dataset.insert.call_args[0][0][0]["id"] == uuid_v7_id + + def test_items_without_id_are_passed_unchanged(self, dataset_local, mock_remote_dataset, eval_items_no_id): + """Items that have no 'id' key are inserted as-is.""" + dataset_local._upload_items(mock_remote_dataset, eval_items_no_id) + + inserted = mock_remote_dataset.insert.call_args[0][0] + assert inserted == eval_items_no_id + + @pytest.mark.parametrize("bad_id", [None, ""]) + def test_none_or_empty_id_is_stripped(self, dataset_local, mock_remote_dataset, bad_id): + """Items with id=None or id='' have the id key stripped before upload.""" + items = [{"id": bad_id, "input": {"question": "What is AI?"}}] + dataset_local._upload_items(mock_remote_dataset, items) + + inserted = mock_remote_dataset.insert.call_args[0][0] + assert "id" not in inserted[0] + + @pytest.mark.parametrize("error,match", [ + (make_api_error(500), "Opik API error while inserting items"), + (ConnectionRefusedError("Connection refused"), "Failed to insert items into Opik dataset"), + ]) + def test_insert_error_raises_dataset_error(self, dataset_local, mock_remote_dataset, eval_items, error, match): + """SDK errors from dataset.insert() are wrapped in DatasetError.""" + mock_remote_dataset.insert.side_effect = error + with pytest.raises(DatasetError, match=match): + dataset_local._upload_items(mock_remote_dataset, eval_items) + + +class TestSyncLocalToRemote: + """Test the _sync_local_to_remote helper.""" + + def test_returns_dataset_unchanged_when_no_filepath(self, dataset_remote, mock_remote_dataset): + """No-op when filepath is not configured.""" + result = dataset_remote._sync_local_to_remote(mock_remote_dataset) + assert result is mock_remote_dataset + + def test_returns_dataset_unchanged_when_file_missing(self, tmp_path, mock_credentials, mock_opik, mock_remote_dataset): + """No-op when local file does not exist.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + ds = OpikEvaluationDataset( + dataset_name="ds", + credentials=mock_credentials, + filepath=str(tmp_path / "nonexistent.json"), + ) + result = ds._sync_local_to_remote(mock_remote_dataset) + assert result is mock_remote_dataset + + def test_returns_dataset_unchanged_for_empty_file(self, tmp_path, mock_credentials, mock_opik, mock_remote_dataset): + """No-op when local file contains an empty list.""" + empty_file = tmp_path / "empty.json" + empty_file.write_text("[]") + mock_opik.get_dataset.return_value = mock_remote_dataset + + ds = OpikEvaluationDataset( + dataset_name="ds", + credentials=mock_credentials, + filepath=str(empty_file), + ) + result = ds._sync_local_to_remote(mock_remote_dataset) + assert result is mock_remote_dataset + mock_remote_dataset.insert.assert_not_called() + + def test_calls_upload_items(self, dataset_local, mock_opik, mock_remote_dataset, eval_items): + """Loads local items and passes them to _upload_items.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + + with patch.object(dataset_local, "_upload_items") as mock_upload: + dataset_local._sync_local_to_remote(mock_remote_dataset) + mock_upload.assert_called_once_with(mock_remote_dataset, eval_items) + + def test_returns_refreshed_dataset(self, dataset_local, mock_opik, mock_remote_dataset): + """Returns the result of a fresh get_dataset call after upload.""" + refreshed = Mock() + mock_opik.get_dataset.return_value = refreshed + + with patch.object(dataset_local, "_upload_items"): + result = dataset_local._sync_local_to_remote(mock_remote_dataset) + + assert result is refreshed + + 
def test_flushes_client_after_upload(self, dataset_local, mock_opik, mock_remote_dataset): + """Calls client.flush() after insert to ensure items are committed before evaluate().""" + mock_opik.get_dataset.return_value = mock_remote_dataset + + with patch.object(dataset_local, "_upload_items"): + dataset_local._sync_local_to_remote(mock_remote_dataset) + + mock_opik.flush.assert_called_once() + + def test_flush_error_raises_dataset_error(self, dataset_local, mock_opik, mock_remote_dataset): + """Errors from client.flush() during sync are wrapped in DatasetError.""" + mock_opik.flush.side_effect = Exception("flush failed") + mock_opik.get_dataset.return_value = mock_remote_dataset + + with patch.object(dataset_local, "_upload_items"): + with pytest.raises(DatasetError, match="Failed to flush items"): + dataset_local._sync_local_to_remote(mock_remote_dataset) + + def test_refresh_api_error_raises_dataset_error(self, dataset_local, mock_opik, mock_remote_dataset): + """ApiError from get_dataset() after sync is wrapped in DatasetError.""" + mock_opik.flush.return_value = None + mock_opik.get_dataset.side_effect = make_api_error(500) + + with patch.object(dataset_local, "_upload_items"): + with pytest.raises(DatasetError, match="Opik API error while refreshing dataset"): + dataset_local._sync_local_to_remote(mock_remote_dataset) + + def test_refresh_connection_error_raises_dataset_error(self, dataset_local, mock_opik, mock_remote_dataset): + """Connection error from get_dataset() after sync is wrapped in DatasetError.""" + mock_opik.flush.return_value = None + mock_opik.get_dataset.side_effect = ConnectionRefusedError("Connection refused") + + with patch.object(dataset_local, "_upload_items"): + with pytest.raises(DatasetError, match="Failed to refresh dataset"): + dataset_local._sync_local_to_remote(mock_remote_dataset) + + @pytest.mark.parametrize("items", [ + [{"input": {"question": "What is AI?"}}], + [{"id": None, "input": {"question": "What is AI?"}}], + [{"id": "", "input": {"question": "What is AI?"}}], + ]) + def test_warns_when_id_missing_or_empty( + self, tmp_path, mock_credentials, mock_opik, mock_remote_dataset, items + ): + """Logs a warning when items have no id, id=None, or id=''.""" + filepath = tmp_path / "eval.json" + filepath.write_text(json.dumps(items)) + mock_opik.get_dataset.return_value = mock_remote_dataset + + ds = OpikEvaluationDataset( + dataset_name="ds", + credentials=mock_credentials, + filepath=str(filepath), + ) + + with patch("kedro_datasets_experimental.opik.opik_evaluation_dataset.logger") as mock_logger: + with patch.object(ds, "_upload_items"): + ds._sync_local_to_remote(mock_remote_dataset) + warning_messages = [c[0][0] for c in mock_logger.warning.call_args_list] + assert any("missing, None, or empty" in msg for msg in warning_messages) + + +class TestMergeItems: + """Test the _merge_items static method.""" + + def test_new_item_replaces_existing_by_id(self): + """New item with existing ID replaces the old entry in place.""" + existing = [{"id": "a", "input": {"v": 1}}, {"id": "b", "input": {"v": 2}}] + new = [{"id": "a", "input": {"v": 99}}] + result = OpikEvaluationDataset._merge_items(existing, new) + assert result[0]["input"]["v"] == 99 + assert len(result) == 2 + + def test_new_item_without_id_is_appended(self): + """New item without ID is always appended, never deduped.""" + existing = [{"id": "a", "input": {"v": 1}}] + new = [{"input": {"v": 2}}] + result = OpikEvaluationDataset._merge_items(existing, new) + assert len(result) == 2 + assert 
result[1]["input"]["v"] == 2 + + def test_new_item_with_new_id_is_appended(self): + """New item with a novel ID is appended after existing items.""" + existing = [{"id": "a", "input": {"v": 1}}] + new = [{"id": "b", "input": {"v": 2}}] + result = OpikEvaluationDataset._merge_items(existing, new) + assert len(result) == 2 + assert result[1]["id"] == "b" + + def test_empty_existing_returns_new(self): + """Merging into empty list returns a copy of new items.""" + new = [{"id": "a", "input": {"v": 1}}] + result = OpikEvaluationDataset._merge_items([], new) + assert result == new + + def test_empty_new_returns_existing(self): + """Merging empty new list returns existing unchanged.""" + existing = [{"id": "a", "input": {"v": 1}}] + result = OpikEvaluationDataset._merge_items(existing, []) + assert result == existing + + def test_order_preserved_with_replacement(self): + """Replacement keeps the item at its original position.""" + existing = [{"id": "a", "input": {"v": 1}}, {"id": "b", "input": {"v": 2}}] + new = [{"id": "b", "input": {"v": 99}}] + result = OpikEvaluationDataset._merge_items(existing, new) + assert result[0]["id"] == "a" + assert result[1]["input"]["v"] == 99 + + def test_duplicate_no_id_items_both_appended(self): + """Two new items without ID are both appended (no dedup possible).""" + existing = [] + new = [{"input": {"v": 1}}, {"input": {"v": 1}}] + result = OpikEvaluationDataset._merge_items(existing, new) + assert len(result) == 2 + + +class TestLoad: + """Test the load() method.""" + + def test_load_remote_mode_returns_dataset(self, dataset_remote, mock_opik, mock_remote_dataset): + """Remote mode fetches and returns the dataset without syncing.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + result = dataset_remote.load() + assert result is mock_remote_dataset + + def test_load_remote_mode_does_not_sync(self, dataset_remote, mock_opik, mock_remote_dataset): + """Remote mode does not call _sync_local_to_remote.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + with patch.object(dataset_remote, "_sync_local_to_remote") as mock_sync: + dataset_remote.load() + mock_sync.assert_not_called() + + def test_load_local_mode_calls_sync(self, dataset_local, mock_opik, mock_remote_dataset): + """Local mode calls _sync_local_to_remote.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + with patch.object(dataset_local, "_sync_local_to_remote", return_value=mock_remote_dataset) as mock_sync: + dataset_local.load() + mock_sync.assert_called_once_with(mock_remote_dataset) + + def test_load_creates_dataset_if_missing(self, dataset_local, mock_opik, mock_remote_dataset): + """Creates the remote dataset if it does not exist.""" + mock_opik.get_dataset.side_effect = [make_api_error(404), mock_remote_dataset] + mock_opik.create_dataset.return_value = mock_remote_dataset + + with patch.object(dataset_local, "_sync_local_to_remote", return_value=mock_remote_dataset): + result = dataset_local.load() + + mock_opik.create_dataset.assert_called_once() + assert result is mock_remote_dataset + + def test_load_api_error_raises_dataset_error(self, dataset_local, mock_opik): + """Non-404 API error from load is wrapped in DatasetError.""" + mock_opik.get_dataset.side_effect = make_api_error(503) + with pytest.raises(DatasetError, match="Opik API error while fetching dataset"): + dataset_local.load() + + +class TestSave: + """Test the save() method.""" + + def test_save_local_mode_uploads_to_remote(self, dataset_local, mock_opik, mock_remote_dataset, eval_items): + 
"""Local mode uploads items to the remote dataset.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + dataset_local.save(eval_items) + mock_remote_dataset.insert.assert_called_once() + + def test_save_flushes_client_after_upload(self, dataset_local, mock_opik, mock_remote_dataset, eval_items): + """Calls client.flush() after insert to ensure items are committed before evaluate().""" + mock_opik.get_dataset.return_value = mock_remote_dataset + dataset_local.save(eval_items) + mock_opik.flush.assert_called_once() + + def test_save_flush_error_raises_dataset_error(self, dataset_local, mock_opik, mock_remote_dataset, eval_items): + """Errors from client.flush() during save are wrapped in DatasetError.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + mock_opik.flush.side_effect = Exception("flush failed") + + with pytest.raises(DatasetError, match="Failed to flush items"): + dataset_local.save(eval_items) + + def test_save_local_mode_merges_into_file(self, dataset_local, mock_opik, mock_remote_dataset, eval_items): + """Local mode merges new items into the local file.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + new_item = [{"id": "item_003", "input": {"question": "What is DL?"}}] + dataset_local.save(new_item) + + written = json.loads(dataset_local._filepath.read_text()) + ids = [i.get("id") for i in written] + assert "item_003" in ids + + def test_save_local_mode_replaces_existing_id(self, dataset_local, mock_opik, mock_remote_dataset, eval_items): + """Local mode replaces an existing item when IDs match.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + updated = [{"id": "item_001", "input": {"question": "Updated?"}}] + dataset_local.save(updated) + + written = json.loads(dataset_local._filepath.read_text()) + item_001 = next(i for i in written if i.get("id") == "item_001") + assert item_001["input"]["question"] == "Updated?" 
+ + def test_save_local_mode_creates_file_if_missing(self, tmp_path, mock_credentials, mock_opik, mock_remote_dataset): + """Creates the local file if it does not exist yet.""" + missing = tmp_path / "new.json" + mock_opik.get_dataset.return_value = mock_remote_dataset + + ds = OpikEvaluationDataset( + dataset_name="ds", + credentials=mock_credentials, + filepath=str(missing), + sync_policy="local", + ) + ds.save([{"id": "x", "input": {"q": "hello"}}]) + assert missing.exists() + + def test_save_remote_mode_uploads_to_remote(self, dataset_remote, mock_opik, mock_remote_dataset, eval_items): + """Remote mode uploads items to the remote dataset.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + dataset_remote.save(eval_items) + mock_remote_dataset.insert.assert_called_once() + + def test_save_remote_mode_does_not_write_local_file(self, dataset_remote, mock_opik, mock_remote_dataset, eval_items): + """Remote mode does not create or modify a local file.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + dataset_remote.save(eval_items) + assert dataset_remote._filepath is None + + def test_save_remote_mode_logs_warning(self, dataset_remote, mock_opik, mock_remote_dataset, eval_items): + """Remote mode logs a warning that the local file won't be updated.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + with patch("kedro_datasets_experimental.opik.opik_evaluation_dataset.logger") as mock_logger: + dataset_remote.save(eval_items) + warning_messages = [c[0][0] for c in mock_logger.warning.call_args_list] + assert any("uploads to remote only" in msg for msg in warning_messages) + + def test_save_missing_input_raises_dataset_error(self, dataset_local, mock_opik, mock_remote_dataset): + """Item missing 'input' key raises DatasetError before any upload.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + bad_items = [{"expected_output": "no input here"}] + with pytest.raises(DatasetError, match="missing required 'input'"): + dataset_local.save(bad_items) + mock_remote_dataset.insert.assert_not_called() + + +class TestExists: + """Test the _exists() method.""" + + def test_returns_true_when_dataset_exists(self, dataset_local, mock_opik, mock_remote_dataset): + """Returns True when get_dataset succeeds.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + assert dataset_local._exists() is True + + def test_returns_false_on_404(self, dataset_local, mock_opik): + """Returns False when get_dataset raises a 404 ApiError.""" + mock_opik.get_dataset.side_effect = make_api_error(404) + assert dataset_local._exists() is False + + def test_non_404_api_error_raises_dataset_error(self, dataset_local, mock_opik): + """Non-404 ApiError is wrapped in DatasetError.""" + mock_opik.get_dataset.side_effect = make_api_error(500) + with pytest.raises(DatasetError, match="Opik API error while checking dataset"): + dataset_local._exists() + + def test_connection_error_raises_dataset_error(self, dataset_local, mock_opik): + """Connection-level errors are wrapped in DatasetError.""" + mock_opik.get_dataset.side_effect = ConnectionRefusedError("Connection refused") + with pytest.raises(DatasetError, match="Failed to connect to Opik while checking dataset"): + dataset_local._exists() + + +class TestDescribe: + """Test the _describe() method.""" + + def test_describe_returns_all_fields(self, dataset_local): + """_describe returns the expected keys.""" + desc = dataset_local._describe() + assert desc["dataset_name"] == "test-dataset" + assert desc["sync_policy"] == 
"local" + assert "filepath" in desc + assert "metadata" in desc + + def test_describe_filepath_none_when_not_set(self, dataset_remote): + """filepath is None in _describe when not configured.""" + assert dataset_remote._describe()["filepath"] is None + + def test_describe_metadata_included(self, mock_credentials, mock_opik, mock_remote_dataset): + """metadata dict is returned in _describe.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + ds = OpikEvaluationDataset( + dataset_name="ds", + credentials=mock_credentials, + metadata={"project": "evaluation"}, + ) + assert ds._describe()["metadata"] == {"project": "evaluation"} + + +class TestPreview: + """Test the preview() method.""" + + def test_preview_existing_json_file(self, dataset_local, eval_items): + """Returns a JSON-parseable preview for an existing file.""" + preview = dataset_local.preview() + parsed = json.loads(str(preview)) + assert isinstance(parsed, list) + assert len(parsed) == len(eval_items) + + def test_preview_nonexistent_file(self, tmp_path, mock_credentials, mock_opik, mock_remote_dataset): + """Returns a descriptive message when the local file does not exist.""" + mock_opik.get_dataset.return_value = mock_remote_dataset + ds = OpikEvaluationDataset( + dataset_name="ds", + credentials=mock_credentials, + filepath=str(tmp_path / "missing.json"), + ) + assert "does not exist" in str(ds.preview()) + + def test_preview_no_filepath(self, dataset_remote): + """Returns a descriptive message when no filepath is configured.""" + assert "No filepath configured" in str(dataset_remote.preview()) + + def test_preview_non_serialisable_data_returns_message( + self, tmp_path, mock_credentials, mock_opik + ): + """Non-JSON-serialisable local data returns a graceful error message instead of raising.""" + + filepath = tmp_path / "eval.json" + filepath.write_text(json.dumps([{"input": "x"}])) + + ds = OpikEvaluationDataset( + dataset_name="ds", + credentials=mock_credentials, + filepath=str(filepath), + ) + + with patch.object(ds.file_dataset, "load", return_value=[{"input": datetime.date(2024, 1, 1)}]): + result = str(ds.preview()) + + assert "Could not serialise" in result From 117706a3c5875c56e0ddec8176e0b290bf8fa781 Mon Sep 17 00:00:00 2001 From: iwhalen Date: Thu, 23 Apr 2026 20:29:12 -0500 Subject: [PATCH 10/21] Simplify saving / loading, address PR comments. Signed-off-by: iwhalen --- .../kedro_datasets/huggingface/_base.py | 49 +++- .../huggingface/arrow_dataset.py | 4 + .../kedro_datasets/huggingface/csv_dataset.py | 78 ++++++- .../huggingface/json_dataset.py | 78 ++++++- .../huggingface/parquet_dataset.py | 78 ++++++- kedro-datasets/tests/huggingface/conftest.py | 10 +- .../tests/huggingface/test_arrow_dataset.py | 65 +++--- .../huggingface/test_filesystem_datasets.py | 210 +++++++++++------- 8 files changed, 410 insertions(+), 162 deletions(-) diff --git a/kedro-datasets/kedro_datasets/huggingface/_base.py b/kedro-datasets/kedro_datasets/huggingface/_base.py index 104ec9fc0..0538f567c 100644 --- a/kedro-datasets/kedro_datasets/huggingface/_base.py +++ b/kedro-datasets/kedro_datasets/huggingface/_base.py @@ -91,6 +91,26 @@ def __init__( # noqa: PLR0913 glob_function=self._fs.glob, ) + # For non-Arrow datasets, we have to validate that, if we were given + # a directory, the user also provided ``data_files`` in the load_args. 
+ filepath_str = get_filepath_str(self._get_load_path(), self._protocol) + self._path_is_dir = not PurePosixPath(filepath_str).suffix or self._fs.isdir( + filepath_str + ) + + self._validate_load_paths() + + def _validate_load_paths(self): + """If we're loading from a directory, we have to assume this is a DatasetDict. + Non-Arrow datasets cannot do a ``datasets.load_from_disk`` without ``data_files`` + specified in the arguments. + """ + if self._path_is_dir and "data_files" not in self._load_args: + raise DatasetError( + f"{type(self).__name__} cannot load from a directory " + "without specifying ``data_files`` in ``load_args``." + ) + def load(self) -> DatasetLike: load_path = get_filepath_str(self._get_load_path(), self._protocol) return self._load_dataset(load_path) @@ -121,12 +141,30 @@ def save(self, data: DatasetLike) -> None: self._invalidate_cache() + def _build_data_files(self) -> str | dict[str, str]: + load_path = get_filepath_str(self._get_load_path(), self._protocol) + # If this is a directory, we're expecting to load a DatasetDict. + if self._path_is_dir: + data_files = self._load_args["data_files"] + return { + split: os.path.join(load_path, filename) + for split, filename in data_files.items() + } + + # Otherwise, just return the path to the Dataset to be loaded. + return load_path + def _load_dataset(self, load_path: str) -> DatasetLike: + data_files: str | dict[str, str] = self._build_data_files() + + load_args = deepcopy(self._load_args) + load_args.pop("data_files", None) + return load_dataset( # nosec self.BUILDER, - data_files=load_path, + data_files=data_files, storage_options=self._storage_options, - **self._load_args, + **load_args, ) def _save_dataset(self, data: Dataset, save_path: str) -> None: @@ -144,14 +182,9 @@ def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None: """ self._fs.mkdirs(save_path, exist_ok=True) ext = self.EXTENSION - saver = f"to_{self.BUILDER}" for split, split_ds in data.items(): split_path = f"{save_path}/{split}{ext}" - getattr(split_ds, saver)( - split_path, - storage_options=self._storage_options, - **self._save_args, - ) + self._save_dataset(split_ds, split_path) def _describe(self) -> dict[str, Any]: return { diff --git a/kedro-datasets/kedro_datasets/huggingface/arrow_dataset.py b/kedro-datasets/kedro_datasets/huggingface/arrow_dataset.py index c040b38e0..64c21d4c0 100644 --- a/kedro-datasets/kedro_datasets/huggingface/arrow_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/arrow_dataset.py @@ -50,6 +50,10 @@ class ArrowDataset(FilesystemDataset): BUILDER: ClassVar[str] = "arrow" EXTENSION: ClassVar[str] = ".arrow" + def _validate_load_paths(self): + """Override to do nothing. Path validation handled by ``load_from_disk``.""" + pass + def _load_dataset(self, load_path: str) -> DatasetLike: return load_from_disk( load_path, diff --git a/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py b/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py index fc77c2725..ea2522cfd 100644 --- a/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py @@ -9,12 +9,15 @@ class CSVDataset(FilesystemDataset): """``CSVDataset`` loads/saves Hugging Face ``Dataset`` and ``DatasetDict`` objects to/from CSV files. - Iterable variants (``IterableDataset``, ``IterableDatasetDict``) - are materialised before saving. + Note that ``datasets`` loads a single file as a ``datasets.DatasetDict`` + with a single key called ``"train"``. 
You can get around this by specifying + ``split`` in the ``load_args``. See examples for more info. Examples: Using the - [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) + to load a single file. Will be loaded as a ``datasets.DatasetDict`` with a single key + ``"train"``: ```yaml reviews: @@ -23,23 +26,76 @@ class CSVDataset(FilesystemDataset): ``` Using the - [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/): + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) + to load a ``datasets.DatasetDict`` from a single file: - >>> from datasets import Dataset >>> from kedro_datasets.huggingface.csv_dataset import ( ... CSVDataset, ... ) >>> - >>> data = Dataset.from_dict( - ... {"col1": [1, 2, 3], "col2": ["a", "b", "c"]} + >>> dataset = CSVDataset(path=tmp_path / "data.csv") + >>> loaded = dataset.load() + >>> assert "train" in loaded + + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) + to load a ``datasets.Dataset`` from a single file: + + ```yaml + reviews: + type: huggingface.CSVDataset + path: data/01_raw/reviews.csv + load_args: + split: train + ``` + + Using the + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) + to load a ``datasets.Dataset`` from a single file: + + >>> from kedro_datasets.huggingface.csv_dataset import ( + ... CSVDataset, + ... ) + >>> + >>> dataset = CSVDataset( + ... path=tmp_path / "data.csv", + ... load_args={"split": "train"}, + ... ) + >>> loaded = dataset.load() + >>> assert type(loaded.shape) is tuple # No "train" key. + + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) + to load a ``datasets.DatasetDict`` from a directory of files: + + ```yaml + reviews: + type: huggingface.CSVDataset + path: data/01_raw/reviews + load_args: + data_files: + labels: labels.csv + data: data.csv + ``` + + Using the + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) + to load a ``datasets.DatasetDict`` from a directory of files: + + >>> from kedro_datasets.huggingface.csv_dataset import ( + ... CSVDataset, ... ) >>> >>> dataset = CSVDataset( - ... path=tmp_path / "test_hf_dataset.csv" + ... path=tmp_path, + ... load_args={ + ... "data_files": { + ... "labels": "labels.csv", + ... "data": "data.csv", + ... } + ... }, ... ) - >>> dataset.save(data) - >>> reloaded = dataset.load() - >>> assert reloaded.to_dict() == data.to_dict() + >>> loaded = dataset.load() """ BUILDER: ClassVar[str] = "csv" diff --git a/kedro-datasets/kedro_datasets/huggingface/json_dataset.py b/kedro-datasets/kedro_datasets/huggingface/json_dataset.py index c3782d58a..10674d3af 100644 --- a/kedro-datasets/kedro_datasets/huggingface/json_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/json_dataset.py @@ -9,12 +9,15 @@ class JSONDataset(FilesystemDataset): """``JSONDataset`` loads/saves Hugging Face ``Dataset`` and ``DatasetDict`` objects to/from JSON files. - Iterable variants (``IterableDataset``, ``IterableDatasetDict``) - are materialised before saving. + Note that ``datasets`` loads a single file as a ``datasets.DatasetDict`` + with a single key called ``"train"``. You can get around this by specifying + ``split`` in the ``load_args``. See examples for more info. 
Examples: Using the - [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) + to load a single file. Will be loaded as a ``datasets.DatasetDict`` with a single key + ``"train"``: ```yaml reviews: @@ -23,23 +26,76 @@ class JSONDataset(FilesystemDataset): ``` Using the - [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/): + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) + to load a ``datasets.DatasetDict`` from a single file: - >>> from datasets import Dataset >>> from kedro_datasets.huggingface.json_dataset import ( ... JSONDataset, ... ) >>> - >>> data = Dataset.from_dict( - ... {"col1": [1, 2, 3], "col2": ["a", "b", "c"]} + >>> dataset = JSONDataset(path=tmp_path / "data.json") + >>> loaded = dataset.load() + >>> assert "train" in loaded + + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) + to load a ``datasets.Dataset`` from a single file: + + ```yaml + reviews: + type: huggingface.JSONDataset + path: data/01_raw/reviews.json + load_args: + split: train + ``` + + Using the + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) + to load a ``datasets.Dataset`` from a single file: + + >>> from kedro_datasets.huggingface.json_dataset import ( + ... JSONDataset, + ... ) + >>> + >>> dataset = JSONDataset( + ... path=tmp_path / "data.json", + ... load_args={"split": "train"}, + ... ) + >>> loaded = dataset.load() + >>> assert type(loaded.shape) is tuple # No "train" key. + + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) + to load a ``datasets.DatasetDict`` from a directory of files: + + ```yaml + reviews: + type: huggingface.JSONDataset + path: data/01_raw/reviews + load_args: + data_files: + labels: labels.json + data: data.json + ``` + + Using the + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) + to load a ``datasets.DatasetDict`` from a directory of files: + + >>> from kedro_datasets.huggingface.json_dataset import ( + ... JSONDataset, ... ) >>> >>> dataset = JSONDataset( - ... path=tmp_path / "test_hf_dataset.json" + ... path=tmp_path, + ... load_args={ + ... "data_files": { + ... "labels": "labels.json", + ... "data": "data.json", + ... } + ... }, ... ) - >>> dataset.save(data) - >>> reloaded = dataset.load() - >>> assert reloaded.to_dict() == data.to_dict() + >>> loaded = dataset.load() """ BUILDER: ClassVar[str] = "json" diff --git a/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py b/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py index 9a7cf8921..6f20eb839 100644 --- a/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py @@ -10,12 +10,15 @@ class ParquetDataset(FilesystemDataset): ``DatasetDict`` objects to/from `Parquet `_ files. - Iterable variants (``IterableDataset``, ``IterableDatasetDict``) - are materialised before saving. + Note that ``datasets`` loads a single file as a ``datasets.DatasetDict`` + with a single key called ``"train"``. You can get around this by specifying + ``split`` in the ``load_args``. See examples for more info. 
Examples: Using the - [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) + to load a single file. Will be loaded as a ``datasets.DatasetDict`` with a single key + ``"train"``: ```yaml reviews: @@ -24,23 +27,76 @@ class ParquetDataset(FilesystemDataset): ``` Using the - [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/): + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) + to load a ``datasets.DatasetDict`` from a single file: - >>> from datasets import Dataset >>> from kedro_datasets.huggingface.parquet_dataset import ( ... ParquetDataset, ... ) >>> - >>> data = Dataset.from_dict( - ... {"col1": [1, 2, 3], "col2": ["a", "b", "c"]} + >>> dataset = ParquetDataset(path=tmp_path / "data.parquet") + >>> loaded = dataset.load() + >>> assert "train" in loaded + + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) + to load a ``datasets.Dataset`` from a single file: + + ```yaml + reviews: + type: huggingface.ParquetDataset + path: data/01_raw/reviews.parquet + load_args: + split: train + ``` + + Using the + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) + to load a ``datasets.Dataset`` from a single file: + + >>> from kedro_datasets.huggingface.parquet_dataset import ( + ... ParquetDataset, + ... ) + >>> + >>> dataset = ParquetDataset( + ... path=tmp_path / "data.parquet", + ... load_args={"split": "train"}, + ... ) + >>> loaded = dataset.load() + >>> assert type(loaded.shape) is tuple # No "train" key. + + Using the + [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/) + to load a ``datasets.DatasetDict`` from a directory of files: + + ```yaml + reviews: + type: huggingface.ParquetDataset + path: data/01_raw/reviews + load_args: + data_files: + labels: labels.parquet + data: data.parquet + ``` + + Using the + [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) + to load a ``datasets.DatasetDict`` from a directory of files: + + >>> from kedro_datasets.huggingface.parquet_dataset import ( + ... ParquetDataset, ... ) >>> >>> dataset = ParquetDataset( - ... path=tmp_path / "test_hf_dataset.parquet" + ... path=tmp_path, + ... load_args={ + ... "data_files": { + ... "labels": "labels.parquet", + ... "data": "data.parquet", + ... } + ... }, ... 
) - >>> dataset.save(data) - >>> reloaded = dataset.load() - >>> assert reloaded.to_dict() == data.to_dict() + >>> loaded = dataset.load() """ BUILDER: ClassVar[str] = "parquet" diff --git a/kedro-datasets/tests/huggingface/conftest.py b/kedro-datasets/tests/huggingface/conftest.py index 694f767b7..7254f9c99 100644 --- a/kedro-datasets/tests/huggingface/conftest.py +++ b/kedro-datasets/tests/huggingface/conftest.py @@ -10,7 +10,7 @@ @pytest.fixture -def dataset(): +def hf_dataset(): return Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) @@ -18,8 +18,8 @@ def dataset(): def dataset_dict(): return DatasetDict( { - "train": Dataset.from_dict({"col1": [1, 2], "col2": ["a", "b"]}), - "test": Dataset.from_dict({"col1": [3], "col2": ["c"]}), + "data": Dataset.from_dict({"col1": [1, 2], "col2": ["a", "b"]}), + "labels": Dataset.from_dict({"col1": [3], "col2": ["c"]}), } ) @@ -35,10 +35,10 @@ def iterable_dataset(): def iterable_dataset_dict(): return IterableDatasetDict( { - "train": Dataset.from_dict( + "data": Dataset.from_dict( {"col1": [1, 2], "col2": ["a", "b"]} ).to_iterable_dataset(), - "test": Dataset.from_dict( + "labels": Dataset.from_dict( {"col1": [3], "col2": ["c"]} ).to_iterable_dataset(), } diff --git a/kedro-datasets/tests/huggingface/test_arrow_dataset.py b/kedro-datasets/tests/huggingface/test_arrow_dataset.py index b993f481a..eee3b9612 100644 --- a/kedro-datasets/tests/huggingface/test_arrow_dataset.py +++ b/kedro-datasets/tests/huggingface/test_arrow_dataset.py @@ -31,27 +31,37 @@ def versioned_arrow_dataset(path_arrow, load_version, save_version): return ArrowDataset(path=path_arrow, version=Version(load_version, save_version)) +@pytest.fixture +def load_version(request): + return getattr(request, "param", "2019-01-01T23.59.59.999Z") + + +@pytest.fixture +def save_version(request): + return getattr(request, "param", "2019-01-01T23.59.59.999Z") + + class TestArrowDataset: - def test_save_and_load_dataset(self, arrow_dataset, dataset): + def test_save_and_load_dataset(self, arrow_dataset, hf_dataset): """Test saving and reloading a Dataset.""" - arrow_dataset.save(dataset) + arrow_dataset.save(hf_dataset) reloaded = arrow_dataset.load() assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() + assert reloaded.to_dict() == hf_dataset.to_dict() def test_save_and_load_dataset_dict(self, arrow_dataset, dataset_dict): """Test saving and reloading a DatasetDict.""" arrow_dataset.save(dataset_dict) reloaded = arrow_dataset.load() assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} + assert set(reloaded.keys()) == {"data", "labels"} for split in dataset_dict: assert reloaded[split].to_dict() == dataset_dict[split].to_dict() - def test_exists(self, arrow_dataset, dataset): + def test_exists(self, arrow_dataset, hf_dataset): """Test `exists` method for both existing and nonexistent dataset.""" assert not arrow_dataset.exists() - arrow_dataset.save(dataset) + arrow_dataset.save(hf_dataset) assert arrow_dataset.exists() def test_exists_dataset_dict(self, arrow_dataset, dataset_dict): @@ -119,13 +129,13 @@ def test_protocol_usage(self, path, instance_type): assert str(dataset._filepath) == resolved assert isinstance(dataset._filepath, PurePosixPath) - def test_pathlike_path(self, tmp_path, dataset): + def test_pathlike_path(self, tmp_path, hf_dataset): """Test that os.PathLike paths are supported.""" path = tmp_path / "test_hf_pathlike" ds = ArrowDataset(path=path) - ds.save(dataset) + ds.save(hf_dataset) 
reloaded = ds.load() - assert reloaded.to_dict() == dataset.to_dict() + assert reloaded.to_dict() == hf_dataset.to_dict() def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value @@ -148,19 +158,19 @@ def test_version_str_repr(self, load_version, save_version): assert "version" not in str(ds) assert path in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" + ver_str = f"version=Version(load='{load_version}', save='{save_version}')" assert ver_str in str(ds_versioned) assert "ArrowDataset" in str(ds_versioned) assert "ArrowDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) - def test_save_and_load(self, versioned_arrow_dataset, dataset): + def test_save_and_load(self, versioned_arrow_dataset, hf_dataset): """Test that saved and reloaded data matches the original one for the versioned dataset.""" - versioned_arrow_dataset.save(dataset) + versioned_arrow_dataset.save(hf_dataset) reloaded = versioned_arrow_dataset.load() - assert reloaded.to_dict() == dataset.to_dict() + assert reloaded.to_dict() == hf_dataset.to_dict() def test_save_and_load_dataset_dict(self, versioned_arrow_dataset, dataset_dict): """Test versioned save and reload with DatasetDict.""" @@ -186,28 +196,22 @@ def test_save_and_load_iterable_dataset_dict( with pytest.raises(DatasetError, match=pattern): versioned_arrow_dataset.save(iterable_dataset_dict) - def test_no_versions(self, versioned_arrow_dataset): - """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for kedro_datasets.huggingface.arrow_dataset.ArrowDataset\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_arrow_dataset.load() - - def test_exists(self, versioned_arrow_dataset, dataset): + def test_exists(self, versioned_arrow_dataset, hf_dataset): """Test `exists` method invocation for versioned dataset.""" assert not versioned_arrow_dataset.exists() - versioned_arrow_dataset.save(dataset) + versioned_arrow_dataset.save(hf_dataset) assert versioned_arrow_dataset.exists() - def test_prevent_overwrite(self, versioned_arrow_dataset, dataset): + def test_prevent_overwrite(self, versioned_arrow_dataset, hf_dataset): """Check the error when attempting to override the dataset if the corresponding version already exists.""" - versioned_arrow_dataset.save(dataset) + versioned_arrow_dataset.save(hf_dataset) pattern = ( r"Save path \'.+\' for kedro_datasets.huggingface.arrow_dataset.ArrowDataset\(.+\) must " r"not exist if versioning is enabled\." ) with pytest.raises(DatasetError, match=pattern): - versioned_arrow_dataset.save(dataset) + versioned_arrow_dataset.save(hf_dataset) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -216,7 +220,7 @@ def test_prevent_overwrite(self, versioned_arrow_dataset, dataset): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_arrow_dataset, load_version, save_version, dataset + self, versioned_arrow_dataset, load_version, save_version, hf_dataset ): """Check the warning when saving to the path that differs from the subsequent load path.""" @@ -226,16 +230,7 @@ def test_save_version_warning( r"kedro_datasets.huggingface.arrow_dataset.ArrowDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_arrow_dataset.save(dataset) - - def test_http_filesystem_no_versioning(self): - pattern = "Versioning is not supported for HTTP protocols." 
- - with pytest.raises(DatasetError, match=pattern): - ArrowDataset( - path="https://example.com/hf_data", - version=Version(None, None), - ) + versioned_arrow_dataset.save(hf_dataset) def test_save_invalid_type_versioned(self, versioned_arrow_dataset): """Check the error when saving an unsupported type through versioned dataset.""" diff --git a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py index 2a499965a..c5d96d163 100644 --- a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py +++ b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py @@ -1,3 +1,4 @@ +import os import re from pathlib import PurePosixPath @@ -38,7 +39,7 @@ def fmt(request): @pytest.fixture -def dataset_cls(fmt): +def kedro_dataset_cls(fmt): return fmt[0] @@ -47,6 +48,14 @@ def extension(fmt): return fmt[1] +@pytest.fixture +def dataset_dict_data_files(extension): + return { + "data": f"data{extension}", + "labels": f"labels{extension}", + } + + @pytest.fixture def path_file(tmp_path, extension): return (tmp_path / f"test{extension}").as_posix() @@ -58,144 +67,183 @@ def path_dir(tmp_path): @pytest.fixture -def fs_dataset(dataset_cls, path_file, save_args, load_args, fs_args): - return dataset_cls( - path=path_file, - save_args=save_args, - load_args=load_args, - fs_args=fs_args, - ) +def load_version(): + return "2019-01-01T23.59.59.999Z" @pytest.fixture -def fs_dataset_dir(dataset_cls, path_dir): - return dataset_cls(path=path_dir) +def save_version(): + return "2019-01-01T23.59.59.999Z" @pytest.fixture -def versioned_fs_dataset(dataset_cls, path_file, load_version, save_version): - return dataset_cls(path=path_file, version=Version(load_version, save_version)) +def versioned_fs_dataset(kedro_dataset_cls, path_file, load_version, save_version): + return kedro_dataset_cls( + path=path_file, version=Version(load_version, save_version) + ) class TestFilesystemDataset: - def test_save_and_load_dataset(self, fs_dataset, dataset): - """A single-file load returns a Dataset (auto-unwrapped).""" - fs_dataset.save(dataset) - reloaded = fs_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() + def test_save_and_load_dataset(self, hf_dataset, kedro_dataset_cls, path_file): + """A single-file load returns a DatasetDict with single key "train".""" + kedro_dataset = kedro_dataset_cls(path=path_file) + kedro_dataset.save(hf_dataset) + reloaded = kedro_dataset.load() + assert isinstance(reloaded, DatasetDict) + assert "train" in reloaded + assert reloaded["train"].to_dict() == hf_dataset.to_dict() - def test_save_and_load_dataset_with_split(self, dataset_cls, path_file, dataset): + def test_save_and_load_dataset_with_split( + self, hf_dataset, kedro_dataset_cls, path_file + ): """With split in load_args, the explicit split is respected.""" - ds = dataset_cls(path=path_file, load_args={"split": "train"}) - ds.save(dataset) - reloaded = ds.load() + kedro_dataset = kedro_dataset_cls(path=path_file, load_args={"split": "train"}) + kedro_dataset.save(hf_dataset) + reloaded = kedro_dataset.load() assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() + assert reloaded.to_dict() == hf_dataset.to_dict() + + def test_directory_no_data_files_error(self, kedro_dataset_cls, path_dir): + with pytest.raises(DatasetError, match=r"from a directory"): + kedro_dataset_cls(path=path_dir) + + def test_build_data_files( + self, kedro_dataset_cls, path_dir, dataset_dict_data_files + 
): + kedro_dataset = kedro_dataset_cls( + path=path_dir, load_args={"data_files": dataset_dict_data_files} + ) + + built_data_files = kedro_dataset._build_data_files() + + for split, filename in dataset_dict_data_files.items(): + assert split in built_data_files + assert built_data_files[split] == os.path.join(path_dir, filename) + + def test_save_and_load_dataset_dict( + self, dataset_dict, kedro_dataset_cls, path_dir, dataset_dict_data_files + ): + kedro_dataset = kedro_dataset_cls( + path=path_dir, load_args={"data_files": dataset_dict_data_files} + ) + kedro_dataset.save(dataset_dict) - def test_save_and_load_dataset_dict(self, fs_dataset_dir, dataset_dict): - fs_dataset_dir.save(dataset_dict) - reloaded = fs_dataset_dir.load() + reloaded = kedro_dataset.load() assert isinstance(reloaded, DatasetDict) - assert set(reloaded.keys()) == {"train", "test"} - for split in dataset_dict: - assert reloaded[split].to_dict() == dataset_dict[split].to_dict() + assert set(reloaded.keys()) == dataset_dict_data_files.keys() + for key in dataset_dict_data_files.keys(): + assert reloaded[key].to_dict() == dataset_dict[key].to_dict() - def test_save_and_load_iterable_dataset(self, fs_dataset, iterable_dataset): + def test_save_and_load_iterable_dataset( + self, iterable_dataset, kedro_dataset_cls, path_file + ): + kedro_dataset = kedro_dataset_cls(path=path_file) with pytest.raises(DatasetError, match=r"got iterable dataset"): - fs_dataset.save(iterable_dataset) + kedro_dataset.save(iterable_dataset) def test_save_and_load_iterable_dataset_dict( - self, fs_dataset_dir, iterable_dataset_dict + self, + iterable_dataset_dict, + kedro_dataset_cls, + path_dir, + dataset_dict_data_files, ): + kedro_dataset = kedro_dataset_cls( + path=path_dir, load_args={"data_files": dataset_dict_data_files} + ) with pytest.raises(DatasetError, match=r"got iterable dataset"): - fs_dataset_dir.save(iterable_dataset_dict) + kedro_dataset.save(iterable_dataset_dict) - def test_exists(self, fs_dataset, dataset): - assert not fs_dataset.exists() - fs_dataset.save(dataset) - assert fs_dataset.exists() + def test_exists(self, hf_dataset, kedro_dataset_cls, path_file): + kedro_dataset = kedro_dataset_cls(path=path_file) + kedro_dataset.save(hf_dataset) + assert kedro_dataset.exists() - def test_load_missing_dataset(self, fs_dataset, dataset_cls): - pattern = ( - rf"Failed while loading data from dataset {_qualname(dataset_cls)}\(.*\)" - ) + def test_load_missing_dataset(self, kedro_dataset_cls, path_file): + kedro_dataset = kedro_dataset_cls(path=path_file) + + pattern = rf"Failed while loading data from dataset {_qualname(kedro_dataset_cls)}\(.*\)" with pytest.raises(DatasetError, match=pattern): - fs_dataset.load() + kedro_dataset.load() + + def test_save_invalid_type(self, kedro_dataset_cls, path_file): + kedro_dataset = kedro_dataset_cls(path=path_file) - def test_save_invalid_type(self, fs_dataset, dataset_cls): - pattern = rf"{dataset_cls.__name__} only supports" + pattern = rf"{kedro_dataset_cls.__name__} only supports" with pytest.raises(DatasetError, match=pattern): - fs_dataset.save({"not": "a dataset"}) + kedro_dataset.save({"not": "a dataset"}) @pytest.mark.parametrize("base_path,instance_type", PROTOCOLS) - def test_protocol_usage(self, dataset_cls, extension, base_path, instance_type): + def test_protocol_usage( + self, kedro_dataset_cls, extension, base_path, instance_type, mocker + ): + # Skip checking directory as it would require permissions for remote filesystems. 
+ mocker.patch.object(instance_type, "isdir", return_value=False) + path = f"{base_path}{extension}" - ds = dataset_cls(path=path) + ds = kedro_dataset_cls(path=path) assert isinstance(ds._fs, instance_type) resolved = path.split(PROTOCOL_DELIMITER, 1)[-1] assert str(ds._filepath) == resolved assert isinstance(ds._filepath, PurePosixPath) - def test_pathlike_path(self, dataset_cls, tmp_path, extension, dataset): + def test_pathlike_path(self, hf_dataset, kedro_dataset_cls, tmp_path, extension): path = tmp_path / f"test_pathlike{extension}" - ds = dataset_cls(path=path) - ds.save(dataset) + ds = kedro_dataset_cls(path=path) + ds.save(hf_dataset) reloaded = ds.load() - assert reloaded.to_dict() == dataset.to_dict() + assert isinstance(reloaded, DatasetDict) + assert reloaded["train"].to_dict() == hf_dataset.to_dict() - def test_catalog_release(self, dataset_cls, extension, mocker): + def test_catalog_release(self, kedro_dataset_cls, path_file, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value - path = f"test{extension}" - ds = dataset_cls(path=path) + fs_mock.isdir.return_value = False + ds = kedro_dataset_cls(path=path_file) ds.release() - fs_mock.invalidate_cache.assert_called_once_with(path) + fs_mock.invalidate_cache.assert_called_once_with(path_file) class TestFilesystemDatasetVersioned: - def test_version_str_repr(self, dataset_cls, extension, load_version, save_version): + def test_version_str_repr( + self, kedro_dataset_cls, extension, load_version, save_version + ): path = f"test{extension}" - ds = dataset_cls(path=path) - ds_versioned = dataset_cls( + ds = kedro_dataset_cls(path=path) + ds_versioned = kedro_dataset_cls( path=path, version=Version(load_version, save_version) ) assert path in str(ds) assert "version" not in str(ds) assert path in str(ds_versioned) - ver_str = f"version=Version(load={load_version}, save='{save_version}')" + ver_str = f"version=Version(load='{load_version}', save='{save_version}')" assert ver_str in str(ds_versioned) - assert dataset_cls.__name__ in str(ds_versioned) + assert kedro_dataset_cls.__name__ in str(ds_versioned) - def test_save_and_load(self, versioned_fs_dataset, dataset): - versioned_fs_dataset.save(dataset) + def test_save_and_load(self, hf_dataset, versioned_fs_dataset): + versioned_fs_dataset.save(hf_dataset) reloaded = versioned_fs_dataset.load() - assert isinstance(reloaded, Dataset) - assert reloaded.to_dict() == dataset.to_dict() + assert isinstance(reloaded, DatasetDict) + assert reloaded["train"].to_dict() == hf_dataset.to_dict() - def test_no_versions(self, versioned_fs_dataset, dataset_cls): - pattern = rf"Did not find any versions for {_qualname(dataset_cls)}\(.+\)" + def test_no_versions(self, kedro_dataset_cls, path_file): + """With Version(None, None), construction fails when no saved versions exist.""" + pattern = rf"Did not find any versions for {_qualname(kedro_dataset_cls)}\(.+\)" with pytest.raises(DatasetError, match=pattern): - versioned_fs_dataset.load() + kedro_dataset_cls(path=path_file, version=Version(None, None)) - def test_exists(self, versioned_fs_dataset, dataset): + def test_exists(self, hf_dataset, versioned_fs_dataset): assert not versioned_fs_dataset.exists() - versioned_fs_dataset.save(dataset) + versioned_fs_dataset.save(hf_dataset) assert versioned_fs_dataset.exists() - def test_prevent_overwrite(self, versioned_fs_dataset, dataset_cls, dataset): - versioned_fs_dataset.save(dataset) + def test_prevent_overwrite( + self, hf_dataset, versioned_fs_dataset, kedro_dataset_cls + ): + 
versioned_fs_dataset.save(hf_dataset) pattern = ( - rf"Save path \'.+\' for {_qualname(dataset_cls)}\(.+\) must " + rf"Save path \'.+\' for {_qualname(kedro_dataset_cls)}\(.+\) must " r"not exist if versioning is enabled\." ) with pytest.raises(DatasetError, match=pattern): - versioned_fs_dataset.save(dataset) - - def test_http_filesystem_no_versioning(self, dataset_cls, extension): - pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DatasetError, match=pattern): - dataset_cls( - path=f"https://example.com/data{extension}", - version=Version(None, None), - ) + versioned_fs_dataset.save(hf_dataset) From b222ccf13ae17e584b75d041d6c99a374c813f87 Mon Sep 17 00:00:00 2001 From: iwhalen Date: Thu, 23 Apr 2026 20:38:00 -0500 Subject: [PATCH 11/21] Fix RELEASE.md. Signed-off-by: iwhalen --- kedro-datasets/RELEASE.md | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index 378a402f3..0100a1172 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -2,6 +2,7 @@ ## Major features and improvements +- Add Hugging Face datasets: `ArrowDataset`, `ParquetDataset`, `JSONDataset`, `CSVDataset`. - Kedro-Datasets is now compatible with Python 3.14, except for `tensorflow.TensorFlowModelDataset` and `geopandas.GenericDataset`. - Added the following new **experimental** datasets: @@ -39,18 +40,6 @@ Many thanks to the following Kedroids for contributing PRs to this release: - [Datascienceio](https://github.com/datascienceio) - [Guillaume Tauzin](https://github.com/gtauzin) - -# Release 9.3.0 - -## Major features and improvements - -- Add Hugging Face datasets: `ArrowDataset`, `ParquetDataset`, `JSONDataset`, `CSVDataset`. - -## Bug fixes and other changes -## Community contributions - -Many thanks to the following Kedroids for contributing PRs to this release: - - [iwhalen](https://github.com/iwhalen) # Release 9.3.0 From a2abdda729c1ed1a85766ba1b82c1d0433fa1d1b Mon Sep 17 00:00:00 2001 From: iwhalen Date: Thu, 23 Apr 2026 20:42:39 -0500 Subject: [PATCH 12/21] Clean up deleted file. Signed-off-by: iwhalen --- .../kedro_datasets_experimental/opik.OpikEvaluationDataset.md | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 kedro-datasets/docs/api/kedro_datasets_experimental/opik.OpikEvaluationDataset.md diff --git a/kedro-datasets/docs/api/kedro_datasets_experimental/opik.OpikEvaluationDataset.md b/kedro-datasets/docs/api/kedro_datasets_experimental/opik.OpikEvaluationDataset.md deleted file mode 100644 index 7fd57c76b..000000000 --- a/kedro-datasets/docs/api/kedro_datasets_experimental/opik.OpikEvaluationDataset.md +++ /dev/null @@ -1,4 +0,0 @@ -::: kedro_datasets_experimental.opik.OpikEvaluationDataset - options: - members: true - show_source: true From 73b7cdec703fe10ed99e688055ad87cb9e1ee05d Mon Sep 17 00:00:00 2001 From: iwhalen Date: Wed, 29 Apr 2026 17:13:25 -0500 Subject: [PATCH 13/21] Fix multiple definitions of DatasetLike. 
Signed-off-by: iwhalen --- .../kedro_datasets/huggingface/hugging_face_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py index f7796830e..dfd0eb412 100644 --- a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py @@ -12,7 +12,7 @@ from huggingface_hub import HfApi from kedro.io import AbstractDataset -DatasetLike: TypeAlias = Dataset | DatasetDict | IterableDataset | IterableDatasetDict +from ._base import DatasetLike class HFDataset(AbstractDataset[None, DatasetLike]): From ff9131c1a42953ba0ad798663c5229ec16480b11 Mon Sep 17 00:00:00 2001 From: iwhalen Date: Wed, 29 Apr 2026 17:41:43 -0500 Subject: [PATCH 14/21] Fix doc references to iterables and data_files loading logic. Signed-off-by: iwhalen --- .../kedro_datasets/huggingface/_base.py | 26 +++---------------- .../huggingface/arrow_dataset.py | 9 +++---- .../kedro_datasets/huggingface/csv_dataset.py | 4 +++ .../huggingface/json_dataset.py | 4 +++ .../huggingface/parquet_dataset.py | 4 +++ .../huggingface/test_filesystem_datasets.py | 9 +++---- 6 files changed, 21 insertions(+), 35 deletions(-) diff --git a/kedro-datasets/kedro_datasets/huggingface/_base.py b/kedro-datasets/kedro_datasets/huggingface/_base.py index 0538f567c..c0458adc2 100644 --- a/kedro-datasets/kedro_datasets/huggingface/_base.py +++ b/kedro-datasets/kedro_datasets/huggingface/_base.py @@ -91,26 +91,6 @@ def __init__( # noqa: PLR0913 glob_function=self._fs.glob, ) - # For non-Arrow datasets, we have to validate that, if we were given - # a directory, the user also provided ``data_files`` in the load_args. - filepath_str = get_filepath_str(self._get_load_path(), self._protocol) - self._path_is_dir = not PurePosixPath(filepath_str).suffix or self._fs.isdir( - filepath_str - ) - - self._validate_load_paths() - - def _validate_load_paths(self): - """If we're loading from a directory, we have to assume this is a DatasetDict. - Non-Arrow datasets cannot do a ``datasets.load_from_disk`` without ``data_files`` - specified in the arguments. - """ - if self._path_is_dir and "data_files" not in self._load_args: - raise DatasetError( - f"{type(self).__name__} cannot load from a directory " - "without specifying ``data_files`` in ``load_args``." - ) - def load(self) -> DatasetLike: load_path = get_filepath_str(self._get_load_path(), self._protocol) return self._load_dataset(load_path) @@ -122,7 +102,7 @@ def save(self, data: DatasetLike) -> None: "Before saving an iterable dataset " "you must materialize it into a `Dataset` or `DatasetDict`." ) - raise RuntimeError(msg) + raise DatasetError(msg) if not isinstance(data, Dataset | DatasetDict): msg = ( @@ -143,8 +123,8 @@ def save(self, data: DatasetLike) -> None: def _build_data_files(self) -> str | dict[str, str]: load_path = get_filepath_str(self._get_load_path(), self._protocol) - # If this is a directory, we're expecting to load a DatasetDict. 
- if self._path_is_dir: + + if "data_files" in self._load_args: data_files = self._load_args["data_files"] return { split: os.path.join(load_path, filename) diff --git a/kedro-datasets/kedro_datasets/huggingface/arrow_dataset.py b/kedro-datasets/kedro_datasets/huggingface/arrow_dataset.py index 64c21d4c0..34b463d91 100644 --- a/kedro-datasets/kedro_datasets/huggingface/arrow_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/arrow_dataset.py @@ -14,8 +14,9 @@ class ArrowDataset(FilesystemDataset): `Arrow `_ format using ``save_to_disk`` / ``load_from_disk``. - Iterable variants (``IterableDataset``, ``IterableDatasetDict``) - are materialised before saving. + Saving ``IterableDataset`` or ``IterableDatasetDict`` objects is not + supported and will raise a ``DatasetError``. Materialize the iterable + dataset into a ``Dataset`` or ``DatasetDict`` before saving. Examples: Using the @@ -50,10 +51,6 @@ class ArrowDataset(FilesystemDataset): BUILDER: ClassVar[str] = "arrow" EXTENSION: ClassVar[str] = ".arrow" - def _validate_load_paths(self): - """Override to do nothing. Path validation handled by ``load_from_disk``.""" - pass - def _load_dataset(self, load_path: str) -> DatasetLike: return load_from_disk( load_path, diff --git a/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py b/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py index ea2522cfd..7de2322b4 100644 --- a/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py @@ -9,6 +9,10 @@ class CSVDataset(FilesystemDataset): """``CSVDataset`` loads/saves Hugging Face ``Dataset`` and ``DatasetDict`` objects to/from CSV files. + Saving ``IterableDataset`` or ``IterableDatasetDict`` objects is not + supported and will raise a ``DatasetError``. Materialize the iterable + dataset into a ``Dataset`` or ``DatasetDict`` before saving. + Note that ``datasets`` loads a single file as a ``datasets.DatasetDict`` with a single key called ``"train"``. You can get around this by specifying ``split`` in the ``load_args``. See examples for more info. diff --git a/kedro-datasets/kedro_datasets/huggingface/json_dataset.py b/kedro-datasets/kedro_datasets/huggingface/json_dataset.py index 10674d3af..0fb245dc9 100644 --- a/kedro-datasets/kedro_datasets/huggingface/json_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/json_dataset.py @@ -9,6 +9,10 @@ class JSONDataset(FilesystemDataset): """``JSONDataset`` loads/saves Hugging Face ``Dataset`` and ``DatasetDict`` objects to/from JSON files. + Saving ``IterableDataset`` or ``IterableDatasetDict`` objects is not + supported and will raise a ``DatasetError``. Materialize the iterable + dataset into a ``Dataset`` or ``DatasetDict`` before saving. + Note that ``datasets`` loads a single file as a ``datasets.DatasetDict`` with a single key called ``"train"``. You can get around this by specifying ``split`` in the ``load_args``. See examples for more info. diff --git a/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py b/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py index 6f20eb839..960dcf67a 100644 --- a/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py @@ -10,6 +10,10 @@ class ParquetDataset(FilesystemDataset): ``DatasetDict`` objects to/from `Parquet `_ files. + Saving ``IterableDataset`` or ``IterableDatasetDict`` objects is not + supported and will raise a ``DatasetError``. 
Materialize the iterable + dataset into a ``Dataset`` or ``DatasetDict`` before saving. + Note that ``datasets`` loads a single file as a ``datasets.DatasetDict`` with a single key called ``"train"``. You can get around this by specifying ``split`` in the ``load_args``. See examples for more info. diff --git a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py index c5d96d163..a56be4037 100644 --- a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py +++ b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py @@ -103,10 +103,6 @@ def test_save_and_load_dataset_with_split( assert isinstance(reloaded, Dataset) assert reloaded.to_dict() == hf_dataset.to_dict() - def test_directory_no_data_files_error(self, kedro_dataset_cls, path_dir): - with pytest.raises(DatasetError, match=r"from a directory"): - kedro_dataset_cls(path=path_dir) - def test_build_data_files( self, kedro_dataset_cls, path_dir, dataset_dict_data_files ): @@ -227,10 +223,11 @@ def test_save_and_load(self, hf_dataset, versioned_fs_dataset): assert reloaded["train"].to_dict() == hf_dataset.to_dict() def test_no_versions(self, kedro_dataset_cls, path_file): - """With Version(None, None), construction fails when no saved versions exist.""" + """With Version(None, None), loading fails when no saved versions exist.""" pattern = rf"Did not find any versions for {_qualname(kedro_dataset_cls)}\(.+\)" + ds = kedro_dataset_cls(path=path_file, version=Version(None, None)) with pytest.raises(DatasetError, match=pattern): - kedro_dataset_cls(path=path_file, version=Version(None, None)) + ds.load() def test_exists(self, hf_dataset, versioned_fs_dataset): assert not versioned_fs_dataset.exists() From 209bf249e0f868034ec248048a2c493412ef03cf Mon Sep 17 00:00:00 2001 From: iwhalen Date: Wed, 29 Apr 2026 17:49:50 -0500 Subject: [PATCH 15/21] Fix doctests. Signed-off-by: iwhalen --- .../kedro_datasets/huggingface/csv_dataset.py | 12 ++++++++++++ .../kedro_datasets/huggingface/json_dataset.py | 12 ++++++++++++ .../kedro_datasets/huggingface/parquet_dataset.py | 12 ++++++++++++ 3 files changed, 36 insertions(+) diff --git a/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py b/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py index 7de2322b4..71503193d 100644 --- a/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py @@ -33,11 +33,14 @@ class CSVDataset(FilesystemDataset): [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) to load a ``datasets.DatasetDict`` from a single file: + >>> from datasets import Dataset >>> from kedro_datasets.huggingface.csv_dataset import ( ... CSVDataset, ... ) >>> + >>> data = Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) >>> dataset = CSVDataset(path=tmp_path / "data.csv") + >>> dataset.save(data) >>> loaded = dataset.load() >>> assert "train" in loaded @@ -57,14 +60,17 @@ class CSVDataset(FilesystemDataset): [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) to load a ``datasets.Dataset`` from a single file: + >>> from datasets import Dataset >>> from kedro_datasets.huggingface.csv_dataset import ( ... CSVDataset, ... ) >>> + >>> data = Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) >>> dataset = CSVDataset( ... path=tmp_path / "data.csv", ... load_args={"split": "train"}, ... 
) + >>> dataset.save(data) >>> loaded = dataset.load() >>> assert type(loaded.shape) is tuple # No "train" key. @@ -86,10 +92,15 @@ class CSVDataset(FilesystemDataset): [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) to load a ``datasets.DatasetDict`` from a directory of files: + >>> from datasets import Dataset, DatasetDict >>> from kedro_datasets.huggingface.csv_dataset import ( ... CSVDataset, ... ) >>> + >>> dataset_dict = DatasetDict({ + ... "labels": Dataset.from_dict({"col1": [1, 2], "col2": ["a", "b"]}), + ... "data": Dataset.from_dict({"col1": [3, 4], "col2": ["c", "d"]}), + ... }) >>> dataset = CSVDataset( ... path=tmp_path, ... load_args={ @@ -99,6 +110,7 @@ class CSVDataset(FilesystemDataset): ... } ... }, ... ) + >>> dataset.save(dataset_dict) >>> loaded = dataset.load() """ diff --git a/kedro-datasets/kedro_datasets/huggingface/json_dataset.py b/kedro-datasets/kedro_datasets/huggingface/json_dataset.py index 0fb245dc9..e6c9d564b 100644 --- a/kedro-datasets/kedro_datasets/huggingface/json_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/json_dataset.py @@ -33,11 +33,14 @@ class JSONDataset(FilesystemDataset): [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) to load a ``datasets.DatasetDict`` from a single file: + >>> from datasets import Dataset >>> from kedro_datasets.huggingface.json_dataset import ( ... JSONDataset, ... ) >>> + >>> data = Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) >>> dataset = JSONDataset(path=tmp_path / "data.json") + >>> dataset.save(data) >>> loaded = dataset.load() >>> assert "train" in loaded @@ -57,14 +60,17 @@ class JSONDataset(FilesystemDataset): [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) to load a ``datasets.Dataset`` from a single file: + >>> from datasets import Dataset >>> from kedro_datasets.huggingface.json_dataset import ( ... JSONDataset, ... ) >>> + >>> data = Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) >>> dataset = JSONDataset( ... path=tmp_path / "data.json", ... load_args={"split": "train"}, ... ) + >>> dataset.save(data) >>> loaded = dataset.load() >>> assert type(loaded.shape) is tuple # No "train" key. @@ -86,10 +92,15 @@ class JSONDataset(FilesystemDataset): [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) to load a ``datasets.DatasetDict`` from a directory of files: + >>> from datasets import Dataset, DatasetDict >>> from kedro_datasets.huggingface.json_dataset import ( ... JSONDataset, ... ) >>> + >>> dataset_dict = DatasetDict({ + ... "labels": Dataset.from_dict({"col1": [1, 2], "col2": ["a", "b"]}), + ... "data": Dataset.from_dict({"col1": [3, 4], "col2": ["c", "d"]}), + ... }) >>> dataset = JSONDataset( ... path=tmp_path, ... load_args={ @@ -99,6 +110,7 @@ class JSONDataset(FilesystemDataset): ... } ... }, ... 
) + >>> dataset.save(dataset_dict) >>> loaded = dataset.load() """ diff --git a/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py b/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py index 960dcf67a..2542a3f34 100644 --- a/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py @@ -34,11 +34,14 @@ class ParquetDataset(FilesystemDataset): [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) to load a ``datasets.DatasetDict`` from a single file: + >>> from datasets import Dataset >>> from kedro_datasets.huggingface.parquet_dataset import ( ... ParquetDataset, ... ) >>> + >>> data = Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) >>> dataset = ParquetDataset(path=tmp_path / "data.parquet") + >>> dataset.save(data) >>> loaded = dataset.load() >>> assert "train" in loaded @@ -58,14 +61,17 @@ class ParquetDataset(FilesystemDataset): [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) to load a ``datasets.Dataset`` from a single file: + >>> from datasets import Dataset >>> from kedro_datasets.huggingface.parquet_dataset import ( ... ParquetDataset, ... ) >>> + >>> data = Dataset.from_dict({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) >>> dataset = ParquetDataset( ... path=tmp_path / "data.parquet", ... load_args={"split": "train"}, ... ) + >>> dataset.save(data) >>> loaded = dataset.load() >>> assert type(loaded.shape) is tuple # No "train" key. @@ -87,10 +93,15 @@ class ParquetDataset(FilesystemDataset): [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/) to load a ``datasets.DatasetDict`` from a directory of files: + >>> from datasets import Dataset, DatasetDict >>> from kedro_datasets.huggingface.parquet_dataset import ( ... ParquetDataset, ... ) >>> + >>> dataset_dict = DatasetDict({ + ... "labels": Dataset.from_dict({"col1": [1, 2], "col2": ["a", "b"]}), + ... "data": Dataset.from_dict({"col1": [3, 4], "col2": ["c", "d"]}), + ... }) >>> dataset = ParquetDataset( ... path=tmp_path, ... load_args={ @@ -100,6 +111,7 @@ class ParquetDataset(FilesystemDataset): ... } ... }, ... ) + >>> dataset.save(dataset_dict) >>> loaded = dataset.load() """ From 013b3a8a8b6832d331d3a38cecb6d7d20b85c3a6 Mon Sep 17 00:00:00 2001 From: iwhalen Date: Wed, 29 Apr 2026 17:59:23 -0500 Subject: [PATCH 16/21] Add documentation and a error check for mispatched data_files and datasetdict key names. Signed-off-by: iwhalen --- .../kedro_datasets/huggingface/_base.py | 19 ++++++++++++++++++- .../huggingface/test_filesystem_datasets.py | 17 +++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/huggingface/_base.py b/kedro-datasets/kedro_datasets/huggingface/_base.py index c0458adc2..dafb815c1 100644 --- a/kedro-datasets/kedro_datasets/huggingface/_base.py +++ b/kedro-datasets/kedro_datasets/huggingface/_base.py @@ -54,7 +54,13 @@ def __init__( # noqa: PLR0913 version: Optional versioning configuration (see :class:`~kedro.io.core.Version`). load_args: Additional keyword arguments passed to the - underlying load function. + underlying load function. When loading a ``DatasetDict`` + from a directory, supply ``data_files`` as a mapping of + split name to filename (e.g. ``{"train": "train.csv"}``). + The keys must match the split names of the ``DatasetDict`` + being saved, and the filenames must use the correct + extension for the format (e.g. 
``.csv`` for + ``CSVDataset``). save_args: Additional keyword arguments passed to the underlying save function. credentials: Credentials for the underlying filesystem @@ -160,6 +166,17 @@ def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None: As a result, we have to call ``to_`` per split. """ + if "data_files" in self._load_args: + data_files_keys = set(self._load_args["data_files"].keys()) + split_names = set(data.keys()) + if data_files_keys != split_names: + msg = ( + f"DatasetDict split names {sorted(split_names)} do not match " + f"``load_args['data_files']`` keys {sorted(data_files_keys)}. " + "The data_files keys must match the DatasetDict split names " + "so the saved files can be found on load." + ) + raise DatasetError(msg) self._fs.mkdirs(save_path, exist_ok=True) ext = self.EXTENSION for split, split_ds in data.items(): diff --git a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py index a56be4037..60030f1cc 100644 --- a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py +++ b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py @@ -130,6 +130,23 @@ def test_save_and_load_dataset_dict( for key in dataset_dict_data_files.keys(): assert reloaded[key].to_dict() == dataset_dict[key].to_dict() + def test_save_dataset_dict_mismatched_data_files( + self, dataset_dict, kedro_dataset_cls, path_dir, extension + ): + """Saving a DatasetDict whose split names don't match data_files keys raises DatasetError.""" + kedro_dataset = kedro_dataset_cls( + path=path_dir, + load_args={ + # In the test fixture, we expect "data" and "labels". Not "train" and "test". + "data_files": { + "train": f"train{extension}", + "test": f"test{extension}", + } + }, + ) + with pytest.raises(DatasetError, match=r"do not match"): + kedro_dataset.save(dataset_dict) + def test_save_and_load_iterable_dataset( self, iterable_dataset, kedro_dataset_cls, path_file ): From 4affa127d1de6678242ecbe3a421cffc7afd2c21 Mon Sep 17 00:00:00 2001 From: iwhalen Date: Mon, 11 May 2026 17:07:06 -0500 Subject: [PATCH 17/21] Address PR comments. 
Signed-off-by: iwhalen --- .../kedro_datasets/huggingface/hugging_face_dataset.py | 10 ++-------- kedro-datasets/tests/huggingface/test_arrow_dataset.py | 5 +++++ .../tests/huggingface/test_filesystem_datasets.py | 5 +++++ 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py index dfd0eb412..7e1b31564 100644 --- a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py @@ -1,14 +1,8 @@ from __future__ import annotations -from typing import Any, TypeAlias +from typing import Any -from datasets import ( - Dataset, - DatasetDict, - IterableDataset, - IterableDatasetDict, - load_dataset, -) +from datasets import load_dataset from huggingface_hub import HfApi from kedro.io import AbstractDataset diff --git a/kedro-datasets/tests/huggingface/test_arrow_dataset.py b/kedro-datasets/tests/huggingface/test_arrow_dataset.py index eee3b9612..00dfe49f6 100644 --- a/kedro-datasets/tests/huggingface/test_arrow_dataset.py +++ b/kedro-datasets/tests/huggingface/test_arrow_dataset.py @@ -239,3 +239,8 @@ def test_save_invalid_type_versioned(self, versioned_arrow_dataset): ) with pytest.raises(DatasetError, match=pattern): versioned_arrow_dataset.save("not a dataset") + + def test_exists_no_versions(self, path_arrow): + """`exists()` returns False (not raises) when no versions are saved yet.""" + ds = ArrowDataset(path=path_arrow, version=Version(None, None)) + assert ds.exists() is False diff --git a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py index 60030f1cc..cb1d1ae2c 100644 --- a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py +++ b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py @@ -261,3 +261,8 @@ def test_prevent_overwrite( ) with pytest.raises(DatasetError, match=pattern): versioned_fs_dataset.save(hf_dataset) + + def test_exists_no_versions(self, kedro_dataset_cls, path_file): + """`exists()` returns False (not raises) when no versions are saved yet.""" + ds = kedro_dataset_cls(path=path_file, version=Version(None, None)) + assert ds.exists() is False From 37e56d9a4c515e4946091c0df4b827ef9d075ddc Mon Sep 17 00:00:00 2001 From: iwhalen Date: Mon, 11 May 2026 18:58:02 -0500 Subject: [PATCH 18/21] Fix issue with chromadb on python 3.14. 
Signed-off-by: iwhalen --- kedro-datasets/pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index 14e12e1c3..ad0e3911c 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -210,7 +210,10 @@ yaml-yamldataset = ["kedro-datasets[pandas-base]", "PyYAML>=4.2, <7.0"] yaml = ["kedro-datasets[yaml-yamldataset]"] # Experimental Datasets -chromadb-chromadbdataset = ["chromadb>=1.0.0"] +chromadb-chromadbdataset = [ + "chromadb>=1.0.0", + "opentelemetry-exporter-otlp-proto-grpc>=1.40.0; python_version >= '3.14'", +] chromadb = ["kedro-datasets[chromadb-chromadbdataset]"] darts-torch-model-dataset = ["u8darts-all"] @@ -273,6 +276,7 @@ test = [ "adlfs~=2023.1", "biopython~=1.73", "chromadb>=1.0.0", + "opentelemetry-exporter-otlp-proto-grpc>=1.40.0; python_version >= '3.14'", "cloudpickle~=2.2.1; python_version < '3.14'", "cloudpickle>=3.1.2; python_version >= '3.14'", "compress-pickle[lz4]~=2.1.0", From 52d4f76e8f9e1026fa2319eaf6f44ad79ef5c58f Mon Sep 17 00:00:00 2001 From: iwhalen Date: Tue, 12 May 2026 07:09:15 -0500 Subject: [PATCH 19/21] Remove opik experimental changes. Signed-off-by: iwhalen --- .../opik/opik_evaluation_dataset.py | 604 ------------- .../opik/test_opik_evaluation_dataset.py | 791 ------------------ 2 files changed, 1395 deletions(-) delete mode 100644 kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py delete mode 100644 kedro-datasets/kedro_datasets_experimental/tests/opik/test_opik_evaluation_dataset.py diff --git a/kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py b/kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py deleted file mode 100644 index f57c869b9..000000000 --- a/kedro-datasets/kedro_datasets_experimental/opik/opik_evaluation_dataset.py +++ /dev/null @@ -1,604 +0,0 @@ -import json -import logging -import uuid -from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal - -from kedro.io import AbstractDataset, DatasetError -from opik import Opik -from opik.api_objects.dataset.dataset import Dataset -from opik.rest_api.core.api_error import ApiError - -from kedro_datasets._typing import JSONPreview - -if TYPE_CHECKING: - from kedro_datasets.json import JSONDataset - from kedro_datasets.yaml import YAMLDataset - -logger = logging.getLogger(__name__) - -SUPPORTED_FILE_EXTENSIONS = {".json", ".yaml", ".yml"} -REQUIRED_OPIK_CREDENTIALS = {"api_key"} -OPTIONAL_OPIK_CREDENTIALS = {"workspace", "host", "project_name"} -VALID_SYNC_POLICIES = {"local", "remote"} -HTTP_NOT_FOUND = 404 -REQUIRED_UUID_VERSION = 7 - - -class OpikEvaluationDataset(AbstractDataset): - """Kedro dataset for Opik evaluation datasets. - - Connects to an Opik evaluation dataset and returns an ``opik.Dataset`` - on ``load()``, which can be passed to ``opik.evaluation.evaluate()`` to - run experiments. Supports an optional local JSON/YAML file as the - authoring surface for evaluation items. - - **On load / save behaviour:** - - - **On load:** Creates the remote dataset if it does not exist, - synchronises based on ``sync_policy``, and returns an ``opik.Dataset``. - - **On save:** Inserts all items to the remote dataset via Opik's - upsert-by-ID API. Items with a UUID v7 ``id`` update the existing - remote row in-place; items without a UUID v7 ``id`` create a new - remote row on every call. In ``local`` mode, items are also merged - into the local file (new items take precedence). 
In ``remote`` mode, - only the remote insert occurs. - - **Item format:** - - The local file and ``save()`` data must be a list of dicts. Each item - accepts the following keys: - - - ``input`` (**required**) — the evaluation input payload. - - ``id`` — identifier used for local deduplication. The upload - behaviour depends on whether ``id`` is a valid UUID v7: - - - **Valid UUID v7**: forwarded to Opik. Opik's API upserts by item - ID — the first sync creates the remote row; subsequent syncs - update that same row in-place if the content has changed. - The remote row keeps the same UUID across all syncs. Whenever - content changes, the existing remote row is updated in-place, - while no new row is created. - - **All other values** (human-readable strings, UUIDs of other - versions, ``None``, empty string, or no ``id`` key): stripped - before upload. Opik auto-generates a new UUID v7. Unchanged - content is deduplicated by content hash (no-op), but changed - content creates a **new remote row** while the previous one - remains, leading to row accumulation over time. - - - ``expected_output`` — ground-truth value for scoring. - - ``metadata`` — arbitrary metadata dict attached to the item. - - ```json - [ - { - "id": "q1", - "input": {"text": "cancel my order"}, - "expected_output": "cancel_order", - "metadata": {"source": "production"} - } - ] - ``` - ("q1" is used for local deduplication only, as it is not a UUID v7 and will be stripped on upload) - - **Sync policies:** - - - **local** (default): The local file is the source of truth. On - ``load()``, all local items are re-inserted to remote on every sync. - Opik's API upserts by item ID, so the outcome depends on whether - each item carries a UUID v7 ``id``: - - - Items with a UUID v7 ``id`` are updated in-place on the remote — - content changes replace the existing row; unchanged items are - a no-op. - - Items without a UUID v7 ``id`` (non-UUID values are stripped) - are deduplicated by content hash — unchanged content is a no-op, - but changed content creates a **new remote row** (the previous - row remains), leading to row accumulation over time. - ``save()`` inserts to remote and merges into the local file (new - data takes precedence). - - - **remote**: The remote Opik dataset is the sole source of truth. - ``load()`` fetches the remote dataset as-is with no local file - interaction. ``save()`` inserts all items to remote without writing - to any local file. If the remote dataset does not exist yet, it is - created empty — **no items are pushed from the local file**. To seed - a new remote dataset, run with ``sync_policy="local"`` at least once, - or create and populate the dataset directly via the Opik UI. 
- - Examples: - Using catalog YAML configuration: - - ```yaml - # Local sync policy — local file seeds and syncs to remote - evaluation_dataset: - type: kedro_datasets_experimental.opik.OpikEvaluationDataset - dataset_name: intent-detection-eval - filepath: data/evaluation/intent_items.json - sync_policy: local - credentials: opik_credentials - metadata: - project: intent-detection - - # Remote sync policy — Opik is the source of truth - production_eval: - type: kedro_datasets_experimental.opik.OpikEvaluationDataset - dataset_name: intent-detection-eval - sync_policy: remote - credentials: opik_credentials - ``` - - Using Python API: - - ```python - from kedro_datasets_experimental.opik import OpikEvaluationDataset - - dataset = OpikEvaluationDataset( - dataset_name="intent-detection-eval", - credentials={"api_key": "..."}, # pragma: allowlist secret - filepath="data/evaluation/intent_items.json", - ) - - # Load returns an opik.Dataset for running experiments - from opik.evaluation import evaluate - - eval_dataset = dataset.load() - evaluate( - dataset=eval_dataset, - task=my_task, - scoring_functions=[my_scorer], - experiment_name="my-experiment", - ) - - # Save new evaluation items - dataset.save( - [ - {"id": "q1", "input": {"text": "cancel order"}, "expected_output": "cancel"}, - ] - ) - - # Same as in the other example, "q1" is not a UUID v7 and will be stripped on upload - ``` - """ - - def __init__( - self, - dataset_name: str, - credentials: dict[str, str], - filepath: str | None = None, - sync_policy: Literal["local", "remote"] = "local", - metadata: dict[str, Any] | None = None, - ): - """Initialise ``OpikEvaluationDataset``. - - Args: - dataset_name: Name of the evaluation dataset in Opik. - credentials: Opik authentication credentials. - Required: ``api_key``. - Optional: ``workspace``, ``host``, ``project_name``. - filepath: Path to a local JSON/YAML file for authoring evaluation - items. Supports ``.json``, ``.yaml``, and ``.yml`` extensions. - When ``None``, no local file interaction occurs. - sync_policy: Controls the source of truth for reads and whether - a local file is involved: - ``"local"`` (default) — all local items are re-inserted to - remote on ``load()``; ``save()`` inserts to remote and - merges into the local file (new data takes precedence). - ``"remote"`` — ``load()`` fetches remote as-is; ``save()`` - inserts to remote without local file interaction. - metadata: Optional metadata dict stored locally and returned by - ``_describe()``. Note: Opik's ``create_dataset()`` does not - accept a metadata argument, so this value is not propagated - to the remote dataset. 
- """ - self._validate_init_params(credentials, filepath, sync_policy) - - self._dataset_name = dataset_name - self._filepath = Path(filepath) if filepath else None - self._sync_policy = sync_policy - self._metadata = metadata - self._file_dataset = None - - try: - self._client = Opik(**credentials) - except Exception as e: - raise DatasetError(f"Failed to initialise Opik client: {e}") from e - - @staticmethod - def _validate_init_params( - credentials: dict[str, str], - filepath: str | None, - sync_policy: str, - ) -> None: - OpikEvaluationDataset._validate_credentials(credentials) - OpikEvaluationDataset._validate_sync_policy(sync_policy) - OpikEvaluationDataset._validate_filepath(filepath) - - @staticmethod - def _validate_credentials(credentials: dict[str, str]) -> None: - for key in REQUIRED_OPIK_CREDENTIALS: - if key not in credentials: - raise DatasetError( - f"Missing required Opik credential: '{key}'." - ) - if not credentials[key] or not str(credentials[key]).strip(): - raise DatasetError( - f"Opik credential '{key}' cannot be empty." - ) - for key in OPTIONAL_OPIK_CREDENTIALS: - if key in credentials and ( - not credentials[key] or not str(credentials[key]).strip() - ): - raise DatasetError( - f"Opik credential '{key}' cannot be empty if provided." - ) - - @staticmethod - def _validate_sync_policy(sync_policy: str) -> None: - if sync_policy not in VALID_SYNC_POLICIES: - raise DatasetError( - f"Invalid sync_policy '{sync_policy}'. " - f"Must be one of: {', '.join(sorted(VALID_SYNC_POLICIES))}." - ) - - @staticmethod - def _validate_filepath(filepath: str | None) -> None: - if filepath is None: - return - suffix = Path(filepath).suffix.lower() - if suffix not in SUPPORTED_FILE_EXTENSIONS: - raise DatasetError( - f"Unsupported file extension '{suffix}'. " - f"Supported formats: {', '.join(sorted(SUPPORTED_FILE_EXTENSIONS))}." - ) - - @property - def file_dataset(self) -> "JSONDataset | YAMLDataset": - """Return a JSON or YAML file dataset based on the filepath extension.""" - if not self._filepath: - raise DatasetError("filepath must be provided for file dataset operations.") - if self._file_dataset is None: - suffix = self._filepath.suffix.lower() - if suffix in (".yaml", ".yml"): - from kedro_datasets.yaml import YAMLDataset # noqa: PLC0415 - self._file_dataset = YAMLDataset(filepath=str(self._filepath)) - else: - from kedro_datasets.json import JSONDataset # noqa: PLC0415 - self._file_dataset = JSONDataset(filepath=str(self._filepath)) - return self._file_dataset - - def _get_or_create_remote_dataset(self) -> Dataset: - """Ensure the remote Opik dataset exists, creating it if not found. - - Returns the latest ``Dataset`` object. - - Raises: - DatasetError: If the Opik API returns an unexpected error or is - unreachable. 
- """ - try: - return self._client.get_dataset(name=self._dataset_name) - except ApiError as e: - if e.status_code != HTTP_NOT_FOUND: - raise DatasetError( - f"Opik API error while fetching dataset '{self._dataset_name}': {e}" - ) from e - except Exception as e: - raise DatasetError( - f"Failed to connect to Opik while fetching dataset " - f"'{self._dataset_name}': {e}" - ) from e - - try: - logger.info( - "Dataset '%s' not found on Opik, creating it.", - self._dataset_name, - ) - return self._client.create_dataset( - name=self._dataset_name, - description=f"Created by Kedro (sync_policy={self._sync_policy})", - ) - except ApiError as e: - raise DatasetError( - f"Opik API error while creating dataset '{self._dataset_name}': {e}" - ) from e - except Exception as e: - raise DatasetError( - f"Failed to connect to Opik while creating dataset " - f"'{self._dataset_name}': {e}" - ) from e - - @staticmethod - def _strip_id(item: dict[str, Any]) -> dict[str, Any]: - return {k: v for k, v in item.items() if k != "id"} - - @staticmethod - def _validate_items(items: list[dict[str, Any]]) -> None: - """Validate that all items contain the required ``input`` key. - - Raises: - DatasetError: If any item is missing the ``input`` key. - """ - for i, item in enumerate(items): - if "input" not in item: - raise DatasetError( - f"Dataset item at index {i} is missing required 'input' key." - ) - - def _upload_items(self, dataset: Dataset, items: list[dict[str, Any]]) -> None: - """Insert items into the remote Opik dataset. - - Upload behaviour depends on whether an item carries a UUID v7 ``id``: - - - **Valid UUID v7**: forwarded to Opik. Opik's REST API calls - ``create_or_update`` by item ID — the first call creates the - remote row; subsequent calls update that same row in-place if - the content has changed. Whenever content changes, the existing - remote row is updated in-place, while no new row is created. - - **All other values** (human-readable strings, UUIDs of other - versions, ``None``, empty string, or no ``id`` key): stripped - before upload. Opik auto-generates a new UUID v7. Unchanged - content is deduplicated by content hash (no-op), but changed - content creates a **new remote row** while the previous one - remains. - - Callers are responsible for validating items before calling this method. - - Raises: - DatasetError: If the Opik API returns an error or the server is - unreachable during insert. - """ - items_to_insert = [] - for item in items: - if "id" not in item: - items_to_insert.append(item) - elif not item["id"]: - items_to_insert.append(self._strip_id(item)) - else: - try: - parsed = uuid.UUID(str(item["id"])) - if parsed.version == REQUIRED_UUID_VERSION: - items_to_insert.append(item) # valid UUID v7 — preserve id - else: - items_to_insert.append(self._strip_id(item)) - except ValueError: - items_to_insert.append(self._strip_id(item)) - try: - dataset.insert(items_to_insert) - except ApiError as e: - raise DatasetError( - f"Opik API error while inserting items into dataset " - f"'{self._dataset_name}': {e}" - ) from e - except Exception as e: - raise DatasetError( - f"Failed to insert items into Opik dataset '{self._dataset_name}': {e}" - ) from e - - def _sync_local_to_remote(self, dataset: Dataset) -> Dataset: - """Insert all local items into the remote dataset. - - Reads the local file and inserts all items into the remote dataset. - The Opik SDK deduplicates by content hash, so re-inserting unchanged - items is a no-op. Returns a refreshed ``Dataset`` object. 
If the dataset's - id is a valid UUID v7, the same remote row is updated in-place on every sync. - Otherwise, a new remote row will be created. - """ - if not self._filepath or not self._filepath.exists(): - return dataset - - local_items = self.file_dataset.load() - self._validate_items(local_items) - - if not local_items: - return dataset - - items_without_stable_id = [ - item for item in local_items - if "id" not in item or not item.get("id") - ] - if items_without_stable_id: - logger.warning( - "Found %d item(s) with a missing, None, or empty 'id' field in '%s'. " - "These cannot be tracked across syncs and will create new remote " - "rows on every load.", - len(items_without_stable_id), - self._filepath, - ) - - items_with_non_uuid_v7_id = [] - for item in local_items: - if item.get("id"): # present and non-empty/non-None - try: - parsed = uuid.UUID(str(item["id"])) - if parsed.version != REQUIRED_UUID_VERSION: - items_with_non_uuid_v7_id.append(item) - except ValueError: - items_with_non_uuid_v7_id.append(item) - if items_with_non_uuid_v7_id: - logger.warning( - "Found %d item(s) with non-UUID-v7 'id' values in '%s' " - "(e.g. '%s'). Opik requires UUID v7 for item IDs — these " - "will be stripped before upload and Opik will auto-generate " - "UUID v7 values. Remote rows will not have stable identities.", - len(items_with_non_uuid_v7_id), - self._filepath, - items_with_non_uuid_v7_id[0]["id"], - ) - - logger.info( - "Syncing %d item(s) from '%s' to remote dataset '%s'.", - len(local_items), - self._filepath, - self._dataset_name, - ) - self._upload_items(dataset, local_items) - try: - self._client.flush() - except Exception as e: - raise DatasetError( - f"Failed to flush items to Opik dataset '{self._dataset_name}': {e}" - ) from e - - try: - return self._client.get_dataset(name=self._dataset_name) - except ApiError as e: - raise DatasetError( - f"Opik API error while refreshing dataset '{self._dataset_name}' after sync: {e}" - ) from e - except Exception as e: - raise DatasetError( - f"Failed to refresh dataset '{self._dataset_name}' after sync: {e}" - ) from e - - @staticmethod - def _merge_items( - existing: list[dict[str, Any]], - new: list[dict[str, Any]], - ) -> list[dict[str, Any]]: - """Merge new items into an existing list, deduplicating by ``id``. - - Items without an ``id`` key are always appended. For items with an - ``id``, new items take precedence — existing entries with the same - ``id`` are replaced in place. - """ - new_by_id: dict[str, dict[str, Any]] = { - item["id"]: item for item in new if "id" in item - } - - seen_ids: set[str] = set() - merged: list[dict[str, Any]] = [] - - for item in existing: - item_id = item.get("id") - if item_id is not None and item_id in new_by_id: - merged.append(new_by_id[item_id]) - seen_ids.add(item_id) - else: - merged.append(item) - if item_id is not None: - seen_ids.add(item_id) - - for item in new: - item_id = item.get("id") - if item_id is not None and item_id in seen_ids: - continue - if item_id is not None: - seen_ids.add(item_id) - merged.append(item) - - return merged - - def load(self) -> Dataset: - """Load the Opik dataset, syncing local items to remote if sync_policy is ``local``. - - Creates the remote dataset if it does not exist. In ``local`` mode, all - local items are re-inserted to remote on every load via Opik's - ``create_or_update`` API (upsert by item ID). On items with a valid UUID v7 - ``id``, update the existing remote row in-place, and no new row is created. 
- On items where the ``id`` is not a valid UUID v7 (including missing, ``None``, or empty), - the ``id`` is stripped before upload and Opik auto-generates a new UUID v7. - Unchanged content is deduplicated (no-op), but changed content creates a - new remote row while the previous one remains. - - Returns: - Dataset: The Opik dataset ready for use in experiments. - - Raises: - DatasetError: If the Opik API returns an unexpected error or the - server is unreachable. - """ - dataset = self._get_or_create_remote_dataset() - - if self._sync_policy == "local": - dataset = self._sync_local_to_remote(dataset) - - logger.info( - "Loaded dataset '%s' (sync_policy='%s').", - self._dataset_name, - self._sync_policy, - ) - return dataset - - def save(self, data: list[dict[str, Any]]) -> None: - """Insert items into the Opik dataset and optionally update the local file. - - In ``remote`` mode, only the remote upload occurs. In ``local`` mode, - items are also merged into the local file. - - Args: - data: List of dicts, each containing at least an ``input`` key. - - Raises: - DatasetError: If the Opik API call fails or any item is missing ``input``. - """ - if self._sync_policy == "remote": - logger.warning( - "sync_policy='remote': save() uploads to remote only, " - "local file '%s' will not be updated.", - self._filepath, - ) - - self._validate_items(data) - - dataset = self._get_or_create_remote_dataset() - self._upload_items(dataset, data) - try: - self._client.flush() - except Exception as e: - raise DatasetError( - f"Failed to flush items to Opik dataset '{self._dataset_name}': {e}" - ) from e - - if self._sync_policy == "local" and self._filepath: - existing: list[dict] = [] - if self._filepath.exists(): - existing = self.file_dataset.load() - self.file_dataset.save(self._merge_items(existing, data)) - - def _exists(self) -> bool: - try: - self._client.get_dataset(name=self._dataset_name) - return True - except ApiError as e: - if e.status_code == HTTP_NOT_FOUND: - return False - raise DatasetError( - f"Opik API error while checking dataset '{self._dataset_name}': {e}" - ) from e - except Exception as e: - raise DatasetError( - f"Failed to connect to Opik while checking dataset " - f"'{self._dataset_name}': {e}" - ) from e - - def _describe(self) -> dict[str, Any]: - return { - "dataset_name": self._dataset_name, - "filepath": str(self._filepath) if self._filepath else None, - "sync_policy": self._sync_policy, - "metadata": self._metadata, - } - - def preview(self) -> JSONPreview: - """Generate a JSON-compatible preview of the local evaluation data for Kedro-Viz. - - Returns: - JSONPreview: A Kedro-Viz-compatible object containing a serialized JSON string. - Returns a descriptive message if filepath is not configured or does not exist. 
- """ - if not self._filepath: - return JSONPreview("No filepath configured.") - - if not self._filepath.exists(): - return JSONPreview("Local evaluation dataset does not exist.") - - local_data = self.file_dataset.load() - - if isinstance(local_data, str): - local_data = {"content": local_data} - - try: - return JSONPreview(json.dumps(local_data)) - except (TypeError, ValueError) as e: - return JSONPreview(f"Could not serialise local data to JSON: {e}") diff --git a/kedro-datasets/kedro_datasets_experimental/tests/opik/test_opik_evaluation_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/opik/test_opik_evaluation_dataset.py deleted file mode 100644 index a2de5639a..000000000 --- a/kedro-datasets/kedro_datasets_experimental/tests/opik/test_opik_evaluation_dataset.py +++ /dev/null @@ -1,791 +0,0 @@ -import datetime -import json -from unittest.mock import Mock, patch - -import pytest -import yaml -from kedro.io import DatasetError -from opik.rest_api.core.api_error import ApiError - -from kedro_datasets_experimental.opik.opik_evaluation_dataset import ( - OpikEvaluationDataset, -) - - -def make_api_error(status_code: int) -> ApiError: - """Return an ApiError with the given status code.""" - return ApiError(status_code=status_code, headers={}, body={}) - - -@pytest.fixture -def mock_opik(): - """Mock Opik client instance.""" - with patch("kedro_datasets_experimental.opik.opik_evaluation_dataset.Opik") as mock_class: - instance = Mock() - mock_class.return_value = instance - yield instance - - -@pytest.fixture -def mock_credentials(): - """Valid Opik credentials for testing.""" - return { - "api_key": "opik_test_key", # pragma: allowlist secret - "workspace": "test-workspace", - } - - -@pytest.fixture -def eval_items(): - """Sample evaluation dataset items with human-readable (non-UUID) IDs.""" - return [ - { - "id": "item_001", - "input": {"question": "What is AI?"}, - "expected_output": {"answer": "Artificial Intelligence"}, - }, - { - "id": "item_002", - "input": {"question": "What is ML?"}, - "expected_output": {"answer": "Machine Learning"}, - }, - ] - - -@pytest.fixture -def eval_items_uuid(): - """Sample evaluation dataset items with valid UUID v7 IDs.""" - return [ - { - "id": "018e2f1a-dead-7abc-8def-000000000001", - "input": {"question": "What is AI?"}, - "expected_output": {"answer": "Artificial Intelligence"}, - }, - { - "id": "018e2f1a-dead-7abc-8def-000000000002", - "input": {"question": "What is ML?"}, - "expected_output": {"answer": "Machine Learning"}, - }, - ] - - -@pytest.fixture -def eval_items_mixed(): - """Items with a mix of UUID v7 and human-readable IDs.""" - return [ - { - "id": "018e2f1a-dead-7abc-8def-000000000001", - "input": {"question": "What is AI?"}, - "expected_output": {"answer": "Artificial Intelligence"}, - }, - { - "id": "human_readable_id", - "input": {"question": "What is ML?"}, - "expected_output": {"answer": "Machine Learning"}, - }, - ] - - -@pytest.fixture -def eval_items_no_id(): - """Evaluation items without IDs.""" - return [ - {"input": {"question": "What is AI?"}, "expected_output": {"answer": "AI"}}, - {"input": {"question": "What is ML?"}, "expected_output": {"answer": "ML"}}, - ] - - -@pytest.fixture -def filepath_json(tmp_path, eval_items): - """Temporary JSON file with evaluation items.""" - filepath = tmp_path / "eval.json" - filepath.write_text(json.dumps(eval_items)) - return str(filepath) - - -@pytest.fixture -def filepath_yaml(tmp_path, eval_items): - """Temporary YAML file with evaluation items.""" - filepath = tmp_path / 
"eval.yaml" - filepath.write_text(yaml.dump(eval_items)) - return str(filepath) - - -@pytest.fixture -def mock_remote_dataset(): - """Mock Opik Dataset object.""" - ds = Mock() - ds.name = "test-dataset" - return ds - - -@pytest.fixture -def dataset_local(filepath_json, mock_credentials, mock_opik, mock_remote_dataset): - """OpikEvaluationDataset with local sync policy.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - return OpikEvaluationDataset( - dataset_name="test-dataset", - credentials=mock_credentials, - filepath=filepath_json, - sync_policy="local", - ) - - -@pytest.fixture -def dataset_remote(mock_credentials, mock_opik, mock_remote_dataset): - """OpikEvaluationDataset with remote sync policy and no filepath.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - return OpikEvaluationDataset( - dataset_name="test-dataset", - credentials=mock_credentials, - sync_policy="remote", - ) - - -class TestOpikEvaluationDatasetInit: - """Test OpikEvaluationDataset initialisation.""" - - def test_init_minimal_params(self, mock_credentials, mock_opik): - """Minimal required params store expected defaults.""" - ds = OpikEvaluationDataset( - dataset_name="my-dataset", - credentials=mock_credentials, - ) - assert ds._dataset_name == "my-dataset" - assert ds._filepath is None - assert ds._sync_policy == "local" - assert ds._metadata is None - - def test_init_all_params(self, filepath_json, mock_credentials, mock_opik): - """All params are stored correctly.""" - meta = {"project": "test"} - ds = OpikEvaluationDataset( - dataset_name="my-dataset", - credentials=mock_credentials, - filepath=filepath_json, - sync_policy="remote", - metadata=meta, - ) - assert ds._sync_policy == "remote" - assert ds._metadata == meta - assert ds._filepath is not None - - def test_init_missing_api_key(self, mock_opik): - """Missing api_key raises DatasetError.""" - with pytest.raises(DatasetError, match="Missing required Opik credential: 'api_key'"): - OpikEvaluationDataset( - dataset_name="ds", - credentials={"workspace": "w"}, - ) - - @pytest.mark.parametrize("empty_value", ["", " "]) - def test_init_empty_api_key(self, mock_opik, empty_value): - """Empty api_key raises DatasetError.""" - with pytest.raises(DatasetError, match="Opik credential 'api_key' cannot be empty"): - OpikEvaluationDataset( - dataset_name="ds", - credentials={"api_key": empty_value}, - ) - - def test_init_empty_optional_credential(self, mock_opik): - """Empty optional credential (workspace) raises DatasetError.""" - with pytest.raises(DatasetError, match="Opik credential 'workspace' cannot be empty if provided"): - OpikEvaluationDataset( - dataset_name="ds", - credentials={"api_key": "key", "workspace": ""}, # pragma: allowlist secret - ) - - def test_init_invalid_sync_policy(self, mock_credentials, mock_opik): - """Invalid sync_policy raises DatasetError.""" - with pytest.raises(DatasetError, match="Invalid sync_policy 'invalid'"): - OpikEvaluationDataset( - dataset_name="ds", - credentials=mock_credentials, - sync_policy="invalid", - ) - - def test_init_unsupported_filepath_extension(self, tmp_path, mock_credentials, mock_opik): - """Unsupported file extension raises DatasetError.""" - bad_file = tmp_path / "data.txt" - bad_file.write_text("content") - with pytest.raises(DatasetError, match="Unsupported file extension '.txt'"): - OpikEvaluationDataset( - dataset_name="ds", - credentials=mock_credentials, - filepath=str(bad_file), - ) - - def test_init_client_failure_raises_dataset_error(self, mock_credentials): - 
"""Opik client construction failure is wrapped in DatasetError.""" - with patch("kedro_datasets_experimental.opik.opik_evaluation_dataset.Opik") as mock_class: - mock_class.side_effect = Exception("Connection refused") - with pytest.raises(DatasetError, match="Failed to initialise Opik client"): - OpikEvaluationDataset( - dataset_name="ds", - credentials=mock_credentials, - ) - - -class TestFiledatasetProperty: - """Test the file_dataset lazy property.""" - - def test_json_returns_json_dataset(self, dataset_local): - """JSON filepath resolves to JSONDataset.""" - assert dataset_local.file_dataset.__class__.__name__ == "JSONDataset" - - def test_yaml_returns_yaml_dataset(self, filepath_yaml, mock_credentials, mock_opik, mock_remote_dataset): - """YAML filepath resolves to YAMLDataset.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - ds = OpikEvaluationDataset( - dataset_name="ds", - credentials=mock_credentials, - filepath=filepath_yaml, - ) - assert ds.file_dataset.__class__.__name__ == "YAMLDataset" - - def test_is_cached(self, dataset_local): - """Repeated access returns the same object.""" - assert dataset_local.file_dataset is dataset_local.file_dataset - - def test_no_filepath_raises(self, dataset_remote): - """Accessing file_dataset without a filepath raises DatasetError.""" - with pytest.raises(DatasetError, match="filepath must be provided"): - _ = dataset_remote.file_dataset - - -class TestGetOrCreateRemoteDataset: - """Test the _get_or_create_remote_dataset helper.""" - - def test_returns_existing_dataset(self, dataset_local, mock_opik, mock_remote_dataset): - """Returns the dataset when it already exists.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - result = dataset_local._get_or_create_remote_dataset() - assert result is mock_remote_dataset - mock_opik.create_dataset.assert_not_called() - - def test_creates_dataset_on_404(self, dataset_local, mock_opik, mock_remote_dataset): - """Creates a new dataset when get_dataset raises 404.""" - mock_opik.get_dataset.side_effect = make_api_error(404) - mock_opik.create_dataset.return_value = mock_remote_dataset - - result = dataset_local._get_or_create_remote_dataset() - - mock_opik.create_dataset.assert_called_once() - assert result is mock_remote_dataset - - def test_non_404_api_error_raises_dataset_error(self, dataset_local, mock_opik): - """Non-404 API error from get_dataset is wrapped in DatasetError.""" - mock_opik.get_dataset.side_effect = make_api_error(500) - with pytest.raises(DatasetError, match="Opik API error while fetching dataset"): - dataset_local._get_or_create_remote_dataset() - - def test_create_dataset_api_error_raises_dataset_error(self, dataset_local, mock_opik): - """API error during create_dataset is wrapped in DatasetError.""" - mock_opik.get_dataset.side_effect = make_api_error(404) - mock_opik.create_dataset.side_effect = make_api_error(400) - with pytest.raises(DatasetError, match="Opik API error while creating dataset"): - dataset_local._get_or_create_remote_dataset() - - def test_connection_error_on_get_raises_dataset_error(self, dataset_local, mock_opik): - """Non-ApiError on get_dataset (e.g. connection refused) is wrapped in DatasetError.""" - mock_opik.get_dataset.side_effect = ConnectionRefusedError("Connection refused") - with pytest.raises(DatasetError, match="Failed to connect to Opik"): - dataset_local._get_or_create_remote_dataset() - - def test_connection_error_on_create_raises_dataset_error(self, dataset_local, mock_opik): - """Non-ApiError on create_dataset (e.g. 
connection refused) is wrapped in DatasetError.""" - mock_opik.get_dataset.side_effect = make_api_error(404) - mock_opik.create_dataset.side_effect = ConnectionRefusedError("Connection refused") - with pytest.raises(DatasetError, match="Failed to connect to Opik"): - dataset_local._get_or_create_remote_dataset() - - - -class TestValidateItems: - """Test the _validate_items static method.""" - - def test_valid_items_pass(self, eval_items): - """Items with 'input' keys pass validation without error.""" - OpikEvaluationDataset._validate_items(eval_items) # no exception - - def test_empty_list_passes(self): - """Empty item list is valid.""" - OpikEvaluationDataset._validate_items([]) - - def test_missing_input_raises_dataset_error(self): - """Item missing 'input' raises DatasetError with index.""" - items = [{"input": {"q": "ok"}}, {"expected_output": "missing input"}] - with pytest.raises(DatasetError, match="index 1.*missing required 'input'"): - OpikEvaluationDataset._validate_items(items) - - -class TestUploadItems: - """Test the _upload_items method.""" - - @pytest.mark.parametrize("bad_id,label", [ - ("item_001", "human-readable"), - ("550e8400-e29b-41d4-a716-446655440000", "UUID v4"), - ]) - def test_non_uuid_v7_id_is_stripped(self, dataset_local, mock_remote_dataset, bad_id, label): - """Non-UUID-v7 IDs (human-readable or other UUID versions) are stripped before upload.""" - items = [{"id": bad_id, "input": {"question": "What is AI?"}}] - dataset_local._upload_items(mock_remote_dataset, items) - - inserted = mock_remote_dataset.insert.call_args[0][0] - assert "id" not in inserted[0] - - def test_non_id_fields_are_preserved(self, dataset_local, mock_remote_dataset, eval_items): - """input and expected_output fields are passed through unchanged.""" - dataset_local._upload_items(mock_remote_dataset, eval_items) - - inserted = mock_remote_dataset.insert.call_args[0][0] - assert inserted[0]["input"] == eval_items[0]["input"] - assert inserted[0]["expected_output"] == eval_items[0]["expected_output"] - - def test_valid_uuidv7_ids_are_preserved(self, dataset_local, mock_remote_dataset, eval_items_uuid): - """Valid UUID IDs are forwarded to Opik unchanged.""" - dataset_local._upload_items(mock_remote_dataset, eval_items_uuid) - - inserted = mock_remote_dataset.insert.call_args[0][0] - assert inserted[0]["id"] == eval_items_uuid[0]["id"] - assert inserted[1]["id"] == eval_items_uuid[1]["id"] - - def test_mixed_ids_uuid_preserved_non_uuid_stripped( - self, dataset_local, mock_remote_dataset, eval_items_mixed - ): - """UUID IDs are preserved; human-readable IDs are stripped in the same batch.""" - dataset_local._upload_items(mock_remote_dataset, eval_items_mixed) - - inserted = mock_remote_dataset.insert.call_args[0][0] - assert inserted[0]["id"] == eval_items_mixed[0]["id"] # UUID preserved - assert "id" not in inserted[1] # non-UUID stripped - - def test_uuid_v7_id_preserved_when_content_changes(self, dataset_local, mock_remote_dataset): - """A UUID v7 id is forwarded on both uploads even when item content changes.""" - uuid_v7_id = "018e2f1a-dead-7abc-8def-000000000001" - first_version = [{"id": uuid_v7_id, "input": {"question": "What is AI?"}}] - second_version = [{"id": uuid_v7_id, "input": {"question": "What is Artificial Intelligence?"}}] - - dataset_local._upload_items(mock_remote_dataset, first_version) - assert mock_remote_dataset.insert.call_args[0][0][0]["id"] == uuid_v7_id - - dataset_local._upload_items(mock_remote_dataset, second_version) - assert 
mock_remote_dataset.insert.call_args[0][0][0]["id"] == uuid_v7_id - - def test_items_without_id_are_passed_unchanged(self, dataset_local, mock_remote_dataset, eval_items_no_id): - """Items that have no 'id' key are inserted as-is.""" - dataset_local._upload_items(mock_remote_dataset, eval_items_no_id) - - inserted = mock_remote_dataset.insert.call_args[0][0] - assert inserted == eval_items_no_id - - @pytest.mark.parametrize("bad_id", [None, ""]) - def test_none_or_empty_id_is_stripped(self, dataset_local, mock_remote_dataset, bad_id): - """Items with id=None or id='' have the id key stripped before upload.""" - items = [{"id": bad_id, "input": {"question": "What is AI?"}}] - dataset_local._upload_items(mock_remote_dataset, items) - - inserted = mock_remote_dataset.insert.call_args[0][0] - assert "id" not in inserted[0] - - @pytest.mark.parametrize("error,match", [ - (make_api_error(500), "Opik API error while inserting items"), - (ConnectionRefusedError("Connection refused"), "Failed to insert items into Opik dataset"), - ]) - def test_insert_error_raises_dataset_error(self, dataset_local, mock_remote_dataset, eval_items, error, match): - """SDK errors from dataset.insert() are wrapped in DatasetError.""" - mock_remote_dataset.insert.side_effect = error - with pytest.raises(DatasetError, match=match): - dataset_local._upload_items(mock_remote_dataset, eval_items) - - -class TestSyncLocalToRemote: - """Test the _sync_local_to_remote helper.""" - - def test_returns_dataset_unchanged_when_no_filepath(self, dataset_remote, mock_remote_dataset): - """No-op when filepath is not configured.""" - result = dataset_remote._sync_local_to_remote(mock_remote_dataset) - assert result is mock_remote_dataset - - def test_returns_dataset_unchanged_when_file_missing(self, tmp_path, mock_credentials, mock_opik, mock_remote_dataset): - """No-op when local file does not exist.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - ds = OpikEvaluationDataset( - dataset_name="ds", - credentials=mock_credentials, - filepath=str(tmp_path / "nonexistent.json"), - ) - result = ds._sync_local_to_remote(mock_remote_dataset) - assert result is mock_remote_dataset - - def test_returns_dataset_unchanged_for_empty_file(self, tmp_path, mock_credentials, mock_opik, mock_remote_dataset): - """No-op when local file contains an empty list.""" - empty_file = tmp_path / "empty.json" - empty_file.write_text("[]") - mock_opik.get_dataset.return_value = mock_remote_dataset - - ds = OpikEvaluationDataset( - dataset_name="ds", - credentials=mock_credentials, - filepath=str(empty_file), - ) - result = ds._sync_local_to_remote(mock_remote_dataset) - assert result is mock_remote_dataset - mock_remote_dataset.insert.assert_not_called() - - def test_calls_upload_items(self, dataset_local, mock_opik, mock_remote_dataset, eval_items): - """Loads local items and passes them to _upload_items.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - - with patch.object(dataset_local, "_upload_items") as mock_upload: - dataset_local._sync_local_to_remote(mock_remote_dataset) - mock_upload.assert_called_once_with(mock_remote_dataset, eval_items) - - def test_returns_refreshed_dataset(self, dataset_local, mock_opik, mock_remote_dataset): - """Returns the result of a fresh get_dataset call after upload.""" - refreshed = Mock() - mock_opik.get_dataset.return_value = refreshed - - with patch.object(dataset_local, "_upload_items"): - result = dataset_local._sync_local_to_remote(mock_remote_dataset) - - assert result is refreshed - - 
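The upload and sync behaviour exercised above only treats an item as stable across syncs when its ``id`` is a UUID v7; any other value is stripped so Opik generates its own id. A minimal sketch of producing such ids locally, assuming either Python 3.14+ (``uuid.uuid7``) or the third-party ``uuid6`` package — neither is a dependency declared by this patch series:

```python
import uuid

# Prefer the standard-library generator where available (Python 3.14+),
# otherwise fall back to the third-party "uuid6" package (assumed installed).
try:
    make_id = uuid.uuid7
except AttributeError:
    from uuid6 import uuid7 as make_id

item = {
    "id": str(make_id()),
    "input": {"question": "What is AI?"},
    "expected_output": {"answer": "Artificial Intelligence"},
}

# The same check _upload_items applies before preserving an id.
assert uuid.UUID(item["id"]).version == 7
```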
def test_flushes_client_after_upload(self, dataset_local, mock_opik, mock_remote_dataset): - """Calls client.flush() after insert to ensure items are committed before evaluate().""" - mock_opik.get_dataset.return_value = mock_remote_dataset - - with patch.object(dataset_local, "_upload_items"): - dataset_local._sync_local_to_remote(mock_remote_dataset) - - mock_opik.flush.assert_called_once() - - def test_flush_error_raises_dataset_error(self, dataset_local, mock_opik, mock_remote_dataset): - """Errors from client.flush() during sync are wrapped in DatasetError.""" - mock_opik.flush.side_effect = Exception("flush failed") - mock_opik.get_dataset.return_value = mock_remote_dataset - - with patch.object(dataset_local, "_upload_items"): - with pytest.raises(DatasetError, match="Failed to flush items"): - dataset_local._sync_local_to_remote(mock_remote_dataset) - - def test_refresh_api_error_raises_dataset_error(self, dataset_local, mock_opik, mock_remote_dataset): - """ApiError from get_dataset() after sync is wrapped in DatasetError.""" - mock_opik.flush.return_value = None - mock_opik.get_dataset.side_effect = make_api_error(500) - - with patch.object(dataset_local, "_upload_items"): - with pytest.raises(DatasetError, match="Opik API error while refreshing dataset"): - dataset_local._sync_local_to_remote(mock_remote_dataset) - - def test_refresh_connection_error_raises_dataset_error(self, dataset_local, mock_opik, mock_remote_dataset): - """Connection error from get_dataset() after sync is wrapped in DatasetError.""" - mock_opik.flush.return_value = None - mock_opik.get_dataset.side_effect = ConnectionRefusedError("Connection refused") - - with patch.object(dataset_local, "_upload_items"): - with pytest.raises(DatasetError, match="Failed to refresh dataset"): - dataset_local._sync_local_to_remote(mock_remote_dataset) - - @pytest.mark.parametrize("items", [ - [{"input": {"question": "What is AI?"}}], - [{"id": None, "input": {"question": "What is AI?"}}], - [{"id": "", "input": {"question": "What is AI?"}}], - ]) - def test_warns_when_id_missing_or_empty( - self, tmp_path, mock_credentials, mock_opik, mock_remote_dataset, items - ): - """Logs a warning when items have no id, id=None, or id=''.""" - filepath = tmp_path / "eval.json" - filepath.write_text(json.dumps(items)) - mock_opik.get_dataset.return_value = mock_remote_dataset - - ds = OpikEvaluationDataset( - dataset_name="ds", - credentials=mock_credentials, - filepath=str(filepath), - ) - - with patch("kedro_datasets_experimental.opik.opik_evaluation_dataset.logger") as mock_logger: - with patch.object(ds, "_upload_items"): - ds._sync_local_to_remote(mock_remote_dataset) - warning_messages = [c[0][0] for c in mock_logger.warning.call_args_list] - assert any("missing, None, or empty" in msg for msg in warning_messages) - - -class TestMergeItems: - """Test the _merge_items static method.""" - - def test_new_item_replaces_existing_by_id(self): - """New item with existing ID replaces the old entry in place.""" - existing = [{"id": "a", "input": {"v": 1}}, {"id": "b", "input": {"v": 2}}] - new = [{"id": "a", "input": {"v": 99}}] - result = OpikEvaluationDataset._merge_items(existing, new) - assert result[0]["input"]["v"] == 99 - assert len(result) == 2 - - def test_new_item_without_id_is_appended(self): - """New item without ID is always appended, never deduped.""" - existing = [{"id": "a", "input": {"v": 1}}] - new = [{"input": {"v": 2}}] - result = OpikEvaluationDataset._merge_items(existing, new) - assert len(result) == 2 - assert 
result[1]["input"]["v"] == 2 - - def test_new_item_with_new_id_is_appended(self): - """New item with a novel ID is appended after existing items.""" - existing = [{"id": "a", "input": {"v": 1}}] - new = [{"id": "b", "input": {"v": 2}}] - result = OpikEvaluationDataset._merge_items(existing, new) - assert len(result) == 2 - assert result[1]["id"] == "b" - - def test_empty_existing_returns_new(self): - """Merging into empty list returns a copy of new items.""" - new = [{"id": "a", "input": {"v": 1}}] - result = OpikEvaluationDataset._merge_items([], new) - assert result == new - - def test_empty_new_returns_existing(self): - """Merging empty new list returns existing unchanged.""" - existing = [{"id": "a", "input": {"v": 1}}] - result = OpikEvaluationDataset._merge_items(existing, []) - assert result == existing - - def test_order_preserved_with_replacement(self): - """Replacement keeps the item at its original position.""" - existing = [{"id": "a", "input": {"v": 1}}, {"id": "b", "input": {"v": 2}}] - new = [{"id": "b", "input": {"v": 99}}] - result = OpikEvaluationDataset._merge_items(existing, new) - assert result[0]["id"] == "a" - assert result[1]["input"]["v"] == 99 - - def test_duplicate_no_id_items_both_appended(self): - """Two new items without ID are both appended (no dedup possible).""" - existing = [] - new = [{"input": {"v": 1}}, {"input": {"v": 1}}] - result = OpikEvaluationDataset._merge_items(existing, new) - assert len(result) == 2 - - -class TestLoad: - """Test the load() method.""" - - def test_load_remote_mode_returns_dataset(self, dataset_remote, mock_opik, mock_remote_dataset): - """Remote mode fetches and returns the dataset without syncing.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - result = dataset_remote.load() - assert result is mock_remote_dataset - - def test_load_remote_mode_does_not_sync(self, dataset_remote, mock_opik, mock_remote_dataset): - """Remote mode does not call _sync_local_to_remote.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - with patch.object(dataset_remote, "_sync_local_to_remote") as mock_sync: - dataset_remote.load() - mock_sync.assert_not_called() - - def test_load_local_mode_calls_sync(self, dataset_local, mock_opik, mock_remote_dataset): - """Local mode calls _sync_local_to_remote.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - with patch.object(dataset_local, "_sync_local_to_remote", return_value=mock_remote_dataset) as mock_sync: - dataset_local.load() - mock_sync.assert_called_once_with(mock_remote_dataset) - - def test_load_creates_dataset_if_missing(self, dataset_local, mock_opik, mock_remote_dataset): - """Creates the remote dataset if it does not exist.""" - mock_opik.get_dataset.side_effect = [make_api_error(404), mock_remote_dataset] - mock_opik.create_dataset.return_value = mock_remote_dataset - - with patch.object(dataset_local, "_sync_local_to_remote", return_value=mock_remote_dataset): - result = dataset_local.load() - - mock_opik.create_dataset.assert_called_once() - assert result is mock_remote_dataset - - def test_load_api_error_raises_dataset_error(self, dataset_local, mock_opik): - """Non-404 API error from load is wrapped in DatasetError.""" - mock_opik.get_dataset.side_effect = make_api_error(503) - with pytest.raises(DatasetError, match="Opik API error while fetching dataset"): - dataset_local.load() - - -class TestSave: - """Test the save() method.""" - - def test_save_local_mode_uploads_to_remote(self, dataset_local, mock_opik, mock_remote_dataset, eval_items): - 
"""Local mode uploads items to the remote dataset.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - dataset_local.save(eval_items) - mock_remote_dataset.insert.assert_called_once() - - def test_save_flushes_client_after_upload(self, dataset_local, mock_opik, mock_remote_dataset, eval_items): - """Calls client.flush() after insert to ensure items are committed before evaluate().""" - mock_opik.get_dataset.return_value = mock_remote_dataset - dataset_local.save(eval_items) - mock_opik.flush.assert_called_once() - - def test_save_flush_error_raises_dataset_error(self, dataset_local, mock_opik, mock_remote_dataset, eval_items): - """Errors from client.flush() during save are wrapped in DatasetError.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - mock_opik.flush.side_effect = Exception("flush failed") - - with pytest.raises(DatasetError, match="Failed to flush items"): - dataset_local.save(eval_items) - - def test_save_local_mode_merges_into_file(self, dataset_local, mock_opik, mock_remote_dataset, eval_items): - """Local mode merges new items into the local file.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - new_item = [{"id": "item_003", "input": {"question": "What is DL?"}}] - dataset_local.save(new_item) - - written = json.loads(dataset_local._filepath.read_text()) - ids = [i.get("id") for i in written] - assert "item_003" in ids - - def test_save_local_mode_replaces_existing_id(self, dataset_local, mock_opik, mock_remote_dataset, eval_items): - """Local mode replaces an existing item when IDs match.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - updated = [{"id": "item_001", "input": {"question": "Updated?"}}] - dataset_local.save(updated) - - written = json.loads(dataset_local._filepath.read_text()) - item_001 = next(i for i in written if i.get("id") == "item_001") - assert item_001["input"]["question"] == "Updated?" 
- - def test_save_local_mode_creates_file_if_missing(self, tmp_path, mock_credentials, mock_opik, mock_remote_dataset): - """Creates the local file if it does not exist yet.""" - missing = tmp_path / "new.json" - mock_opik.get_dataset.return_value = mock_remote_dataset - - ds = OpikEvaluationDataset( - dataset_name="ds", - credentials=mock_credentials, - filepath=str(missing), - sync_policy="local", - ) - ds.save([{"id": "x", "input": {"q": "hello"}}]) - assert missing.exists() - - def test_save_remote_mode_uploads_to_remote(self, dataset_remote, mock_opik, mock_remote_dataset, eval_items): - """Remote mode uploads items to the remote dataset.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - dataset_remote.save(eval_items) - mock_remote_dataset.insert.assert_called_once() - - def test_save_remote_mode_does_not_write_local_file(self, dataset_remote, mock_opik, mock_remote_dataset, eval_items): - """Remote mode does not create or modify a local file.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - dataset_remote.save(eval_items) - assert dataset_remote._filepath is None - - def test_save_remote_mode_logs_warning(self, dataset_remote, mock_opik, mock_remote_dataset, eval_items): - """Remote mode logs a warning that the local file won't be updated.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - with patch("kedro_datasets_experimental.opik.opik_evaluation_dataset.logger") as mock_logger: - dataset_remote.save(eval_items) - warning_messages = [c[0][0] for c in mock_logger.warning.call_args_list] - assert any("uploads to remote only" in msg for msg in warning_messages) - - def test_save_missing_input_raises_dataset_error(self, dataset_local, mock_opik, mock_remote_dataset): - """Item missing 'input' key raises DatasetError before any upload.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - bad_items = [{"expected_output": "no input here"}] - with pytest.raises(DatasetError, match="missing required 'input'"): - dataset_local.save(bad_items) - mock_remote_dataset.insert.assert_not_called() - - -class TestExists: - """Test the _exists() method.""" - - def test_returns_true_when_dataset_exists(self, dataset_local, mock_opik, mock_remote_dataset): - """Returns True when get_dataset succeeds.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - assert dataset_local._exists() is True - - def test_returns_false_on_404(self, dataset_local, mock_opik): - """Returns False when get_dataset raises a 404 ApiError.""" - mock_opik.get_dataset.side_effect = make_api_error(404) - assert dataset_local._exists() is False - - def test_non_404_api_error_raises_dataset_error(self, dataset_local, mock_opik): - """Non-404 ApiError is wrapped in DatasetError.""" - mock_opik.get_dataset.side_effect = make_api_error(500) - with pytest.raises(DatasetError, match="Opik API error while checking dataset"): - dataset_local._exists() - - def test_connection_error_raises_dataset_error(self, dataset_local, mock_opik): - """Connection-level errors are wrapped in DatasetError.""" - mock_opik.get_dataset.side_effect = ConnectionRefusedError("Connection refused") - with pytest.raises(DatasetError, match="Failed to connect to Opik while checking dataset"): - dataset_local._exists() - - -class TestDescribe: - """Test the _describe() method.""" - - def test_describe_returns_all_fields(self, dataset_local): - """_describe returns the expected keys.""" - desc = dataset_local._describe() - assert desc["dataset_name"] == "test-dataset" - assert desc["sync_policy"] == 
"local" - assert "filepath" in desc - assert "metadata" in desc - - def test_describe_filepath_none_when_not_set(self, dataset_remote): - """filepath is None in _describe when not configured.""" - assert dataset_remote._describe()["filepath"] is None - - def test_describe_metadata_included(self, mock_credentials, mock_opik, mock_remote_dataset): - """metadata dict is returned in _describe.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - ds = OpikEvaluationDataset( - dataset_name="ds", - credentials=mock_credentials, - metadata={"project": "evaluation"}, - ) - assert ds._describe()["metadata"] == {"project": "evaluation"} - - -class TestPreview: - """Test the preview() method.""" - - def test_preview_existing_json_file(self, dataset_local, eval_items): - """Returns a JSON-parseable preview for an existing file.""" - preview = dataset_local.preview() - parsed = json.loads(str(preview)) - assert isinstance(parsed, list) - assert len(parsed) == len(eval_items) - - def test_preview_nonexistent_file(self, tmp_path, mock_credentials, mock_opik, mock_remote_dataset): - """Returns a descriptive message when the local file does not exist.""" - mock_opik.get_dataset.return_value = mock_remote_dataset - ds = OpikEvaluationDataset( - dataset_name="ds", - credentials=mock_credentials, - filepath=str(tmp_path / "missing.json"), - ) - assert "does not exist" in str(ds.preview()) - - def test_preview_no_filepath(self, dataset_remote): - """Returns a descriptive message when no filepath is configured.""" - assert "No filepath configured" in str(dataset_remote.preview()) - - def test_preview_non_serialisable_data_returns_message( - self, tmp_path, mock_credentials, mock_opik - ): - """Non-JSON-serialisable local data returns a graceful error message instead of raising.""" - - filepath = tmp_path / "eval.json" - filepath.write_text(json.dumps([{"input": "x"}])) - - ds = OpikEvaluationDataset( - dataset_name="ds", - credentials=mock_credentials, - filepath=str(filepath), - ) - - with patch.object(ds.file_dataset, "load", return_value=[{"input": datetime.date(2024, 1, 1)}]): - result = str(ds.preview()) - - assert "Could not serialise" in result From adbf9a48eb059a95620aa75d75fe8132b89eee09 Mon Sep 17 00:00:00 2001 From: iwhalen Date: Tue, 12 May 2026 16:16:18 -0500 Subject: [PATCH 20/21] Add huggingface install options, remove chromadb fixes. 
Signed-off-by: iwhalen --- kedro-datasets/pyproject.toml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index ad0e3911c..454a7bdb0 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -80,7 +80,11 @@ holoviews = ["kedro-datasets[holoviews-holoviewswriter]"] huggingface-hfdataset = ["datasets", "huggingface_hub"] huggingface-hftransformerpipelinedataset = ["transformers"] -huggingface = ["kedro-datasets[huggingface-hfdataset,huggingface-hftransformerpipelinedataset]"] +huggingface-arrowdataset = ["datasets"] +huggingface-csvdataset = ["datasets"] +huggingface-jsondataset = ["datasets"] +huggingface-parquetdataset = ["datasets"] +huggingface = ["kedro-datasets[huggingface-hfdataset,huggingface-hftransformerpipelinedataset,huggingface-arrowdataset,huggingface-csvdataset,huggingface-jsondataset,huggingface-parquetdataset]"] ibis-athena = ["ibis-framework[athena]"] ibis-bigquery = ["ibis-framework[bigquery]"] @@ -210,10 +214,7 @@ yaml-yamldataset = ["kedro-datasets[pandas-base]", "PyYAML>=4.2, <7.0"] yaml = ["kedro-datasets[yaml-yamldataset]"] # Experimental Datasets -chromadb-chromadbdataset = [ - "chromadb>=1.0.0", - "opentelemetry-exporter-otlp-proto-grpc>=1.40.0; python_version >= '3.14'", -] +chromadb-chromadbdataset = ["chromadb>=1.0.0"] chromadb = ["kedro-datasets[chromadb-chromadbdataset]"] darts-torch-model-dataset = ["u8darts-all"] @@ -276,7 +277,6 @@ test = [ "adlfs~=2023.1", "biopython~=1.73", "chromadb>=1.0.0", - "opentelemetry-exporter-otlp-proto-grpc>=1.40.0; python_version >= '3.14'", "cloudpickle~=2.2.1; python_version < '3.14'", "cloudpickle>=3.1.2; python_version >= '3.14'", "compress-pickle[lz4]~=2.1.0", From d57b65cc950797adee423228d70849b68be741c1 Mon Sep 17 00:00:00 2001 From: iwhalen Date: Wed, 13 May 2026 17:04:42 -0500 Subject: [PATCH 21/21] Update handling. Signed-off-by: iwhalen --- .../kedro_datasets/huggingface/_base.py | 76 +++++++++++----- .../kedro_datasets/huggingface/csv_dataset.py | 15 ++- .../huggingface/json_dataset.py | 15 ++- .../huggingface/parquet_dataset.py | 15 ++- .../jsonschema/kedro-catalog-1.0.0.json | 8 +- .../huggingface/test_filesystem_datasets.py | 91 +++++++++++++++++-- 6 files changed, 161 insertions(+), 59 deletions(-) diff --git a/kedro-datasets/kedro_datasets/huggingface/_base.py b/kedro-datasets/kedro_datasets/huggingface/_base.py index dafb815c1..1e79375f4 100644 --- a/kedro-datasets/kedro_datasets/huggingface/_base.py +++ b/kedro-datasets/kedro_datasets/huggingface/_base.py @@ -39,6 +39,7 @@ def __init__( # noqa: PLR0913 *, path: str | os.PathLike, version: Version | None = None, + data_files: dict[str, str] | None = None, load_args: dict[str, Any] | None = None, save_args: dict[str, Any] | None = None, credentials: dict[str, Any] | None = None, @@ -53,16 +54,20 @@ def __init__( # noqa: PLR0913 and remote URIs (e.g. ``s3://bucket/data``). version: Optional versioning configuration (see :class:`~kedro.io.core.Version`). + data_files: Mapping of split name to filename for loading and + saving a ``DatasetDict`` from a directory + (e.g. ``{"train": "train.csv"}``). The keys must match + the split names of the ``DatasetDict`` being saved, and + the filenames must use the correct extension for the + format (e.g. ``.csv`` for ``CSVDataset``). load_args: Additional keyword arguments passed to the - underlying load function. 
When loading a ``DatasetDict`` - from a directory, supply ``data_files`` as a mapping of - split name to filename (e.g. ``{"train": "train.csv"}``). - The keys must match the split names of the ``DatasetDict`` - being saved, and the filenames must use the correct - extension for the format (e.g. ``.csv`` for - ``CSVDataset``). + underlying load function. For backwards compatibility, + this may include ``data_files`` if the top-level + ``data_files`` argument is not used. save_args: Additional keyword arguments passed to the - underlying save function. + underlying save function. For backwards compatibility, + this may include ``data_files`` if the top-level + ``data_files`` argument is not used. credentials: Credentials for the underlying filesystem (e.g. ``key``/``secret`` for S3). Passed to the ``storage_options`` parameter in the underlying @@ -84,8 +89,11 @@ def __init__( # noqa: PLR0913 self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) - self._load_args = load_args or {} - self._save_args = save_args or {} + self._load_args = deepcopy(load_args or {}) + self._save_args = deepcopy(save_args or {}) + self._load_data_files, self._save_data_files = self._resolve_data_files( + data_files + ) self.metadata = metadata self._storage_options = {**_credentials, **_fs_args} or None @@ -130,11 +138,10 @@ def save(self, data: DatasetLike) -> None: def _build_data_files(self) -> str | dict[str, str]: load_path = get_filepath_str(self._get_load_path(), self._protocol) - if "data_files" in self._load_args: - data_files = self._load_args["data_files"] + if self._load_data_files: return { split: os.path.join(load_path, filename) - for split, filename in data_files.items() + for split, filename in self._load_data_files.items() } # Otherwise, just return the path to the Dataset to be loaded. @@ -143,14 +150,11 @@ def _build_data_files(self) -> str | dict[str, str]: def _load_dataset(self, load_path: str) -> DatasetLike: data_files: str | dict[str, str] = self._build_data_files() - load_args = deepcopy(self._load_args) - load_args.pop("data_files", None) - return load_dataset( # nosec self.BUILDER, data_files=data_files, storage_options=self._storage_options, - **load_args, + **self._load_args, ) def _save_dataset(self, data: Dataset, save_path: str) -> None: @@ -166,13 +170,13 @@ def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None: As a result, we have to call ``to_`` per split. """ - if "data_files" in self._load_args: - data_files_keys = set(self._load_args["data_files"].keys()) + if self._save_data_files: + data_files_keys = set(self._save_data_files.keys()) split_names = set(data.keys()) if data_files_keys != split_names: msg = ( f"DatasetDict split names {sorted(split_names)} do not match " - f"``load_args['data_files']`` keys {sorted(data_files_keys)}. " + f"``data_files`` keys {sorted(data_files_keys)}. " "The data_files keys must match the DatasetDict split names " "so the saved files can be found on load." 
) @@ -180,15 +184,45 @@ def _save_dataset_dict(self, data: DatasetDict, save_path: str) -> None: self._fs.mkdirs(save_path, exist_ok=True) ext = self.EXTENSION for split, split_ds in data.items(): - split_path = f"{save_path}/{split}{ext}" + filename = ( + self._save_data_files[split] + if self._save_data_files + else f"{split}{ext}" + ) + split_path = f"{save_path}/{filename}" self._save_dataset(split_ds, split_path) + def _resolve_data_files( + self, data_files: dict[str, str] | None + ) -> tuple[dict[str, str] | None, dict[str, str] | None]: + if data_files is not None and ( + "data_files" in self._load_args or "data_files" in self._save_args + ): + msg = ( + f"{type(self).__name__} got ``data_files`` as a top-level " + "argument and in ``load_args`` or ``save_args``. Pass it " + "in only one place." + ) + raise DatasetError(msg) + + load_data_files, save_data_files = None, None + + if data_files is not None: + save_data_files = load_data_files = deepcopy(data_files) + + else: + load_data_files = deepcopy(self._load_args.pop("data_files", None)) + save_data_files = deepcopy(self._save_args.pop("data_files", None)) + + return load_data_files, save_data_files + def _describe(self) -> dict[str, Any]: return { "path": self._filepath, "file_format": self.BUILDER, "protocol": self._protocol, "version": self._version, + "data_files": self._load_data_files, "load_args": self._load_args, "save_args": self._save_args, } diff --git a/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py b/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py index 71503193d..73fc62a08 100644 --- a/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/csv_dataset.py @@ -82,10 +82,9 @@ class CSVDataset(FilesystemDataset): reviews: type: huggingface.CSVDataset path: data/01_raw/reviews - load_args: - data_files: - labels: labels.csv - data: data.csv + data_files: + labels: labels.csv + data: data.csv ``` Using the @@ -103,11 +102,9 @@ class CSVDataset(FilesystemDataset): ... }) >>> dataset = CSVDataset( ... path=tmp_path, - ... load_args={ - ... "data_files": { - ... "labels": "labels.csv", - ... "data": "data.csv", - ... } + ... data_files={ + ... "labels": "labels.csv", + ... "data": "data.csv", ... }, ... ) >>> dataset.save(dataset_dict) diff --git a/kedro-datasets/kedro_datasets/huggingface/json_dataset.py b/kedro-datasets/kedro_datasets/huggingface/json_dataset.py index e6c9d564b..6c591b177 100644 --- a/kedro-datasets/kedro_datasets/huggingface/json_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/json_dataset.py @@ -82,10 +82,9 @@ class JSONDataset(FilesystemDataset): reviews: type: huggingface.JSONDataset path: data/01_raw/reviews - load_args: - data_files: - labels: labels.json - data: data.json + data_files: + labels: labels.json + data: data.json ``` Using the @@ -103,11 +102,9 @@ class JSONDataset(FilesystemDataset): ... }) >>> dataset = JSONDataset( ... path=tmp_path, - ... load_args={ - ... "data_files": { - ... "labels": "labels.json", - ... "data": "data.json", - ... } + ... data_files={ + ... "labels": "labels.json", + ... "data": "data.json", ... }, ... 
) >>> dataset.save(dataset_dict) diff --git a/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py b/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py index 2542a3f34..ea1f97aab 100644 --- a/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/parquet_dataset.py @@ -83,10 +83,9 @@ class ParquetDataset(FilesystemDataset): reviews: type: huggingface.ParquetDataset path: data/01_raw/reviews - load_args: - data_files: - labels: labels.parquet - data: data.parquet + data_files: + labels: labels.parquet + data: data.parquet ``` Using the @@ -104,11 +103,9 @@ class ParquetDataset(FilesystemDataset): ... }) >>> dataset = ParquetDataset( ... path=tmp_path, - ... load_args={ - ... "data_files": { - ... "labels": "labels.parquet", - ... "data": "data.parquet", - ... } + ... data_files={ + ... "labels": "labels.parquet", + ... "data": "data.parquet", ... }, ... ) >>> dataset.save(dataset_dict) diff --git a/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json b/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json index bee9ca925..48c0a281a 100644 --- a/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json +++ b/kedro-datasets/static/jsonschema/kedro-catalog-1.0.0.json @@ -475,13 +475,17 @@ "type": "string", "description": "Path to a directory or file for persisting Hugging Face datasets. Supports local and remote filesystems (e.g. s3://)." }, + "data_files": { + "type": "object", + "description": "Mapping of split name to filename for loading and saving a DatasetDict from a directory." + }, "load_args": { "type": "object", - "description": "Additional arguments passed to the load method." + "description": "Additional arguments passed to the `datasets.load_dataset` function." }, "save_args": { "type": "object", - "description": "Additional arguments passed to the save method." + "description": "Additional arguments passed to the underlying dataset's save method." 
} } } diff --git a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py index cb1d1ae2c..8f60739b8 100644 --- a/kedro-datasets/tests/huggingface/test_filesystem_datasets.py +++ b/kedro-datasets/tests/huggingface/test_filesystem_datasets.py @@ -56,6 +56,14 @@ def dataset_dict_data_files(extension): } +@pytest.fixture +def custom_dataset_dict_data_files(extension): + return { + "data": f"my_data{extension}", + "labels": f"my_labels{extension}", + } + + @pytest.fixture def path_file(tmp_path, extension): return (tmp_path / f"test{extension}").as_posix() @@ -107,7 +115,7 @@ def test_build_data_files( self, kedro_dataset_cls, path_dir, dataset_dict_data_files ): kedro_dataset = kedro_dataset_cls( - path=path_dir, load_args={"data_files": dataset_dict_data_files} + path=path_dir, data_files=dataset_dict_data_files ) built_data_files = kedro_dataset._build_data_files() @@ -120,7 +128,7 @@ def test_save_and_load_dataset_dict( self, dataset_dict, kedro_dataset_cls, path_dir, dataset_dict_data_files ): kedro_dataset = kedro_dataset_cls( - path=path_dir, load_args={"data_files": dataset_dict_data_files} + path=path_dir, data_files=dataset_dict_data_files ) kedro_dataset.save(dataset_dict) @@ -130,23 +138,88 @@ def test_save_and_load_dataset_dict( for key in dataset_dict_data_files.keys(): assert reloaded[key].to_dict() == dataset_dict[key].to_dict() + def test_save_and_load_dataset_dict_with_custom_data_files( + self, + dataset_dict, + kedro_dataset_cls, + path_dir, + custom_dataset_dict_data_files, + ): + kedro_dataset = kedro_dataset_cls( + path=path_dir, data_files=custom_dataset_dict_data_files + ) + kedro_dataset.save(dataset_dict) + + for filename in custom_dataset_dict_data_files.values(): + assert os.path.exists(os.path.join(path_dir, filename)) + + reloaded = kedro_dataset.load() + assert isinstance(reloaded, DatasetDict) + assert set(reloaded.keys()) == custom_dataset_dict_data_files.keys() + for key in custom_dataset_dict_data_files.keys(): + assert reloaded[key].to_dict() == dataset_dict[key].to_dict() + + def test_load_and_save_data_files_can_differ( + self, + dataset_dict, + kedro_dataset_cls, + path_dir, + extension, + ): + load_data_files = { + "data": f"load_data{extension}", + "labels": f"load_labels{extension}", + } + save_data_files = { + "data": f"save_data{extension}", + "labels": f"save_labels{extension}", + } + kedro_dataset = kedro_dataset_cls( + path=path_dir, + load_args={"data_files": load_data_files}, + save_args={"data_files": save_data_files}, + ) + + built_data_files = kedro_dataset._build_data_files() + for split, filename in load_data_files.items(): + assert built_data_files[split] == os.path.join(path_dir, filename) + + kedro_dataset.save(dataset_dict) + for filename in save_data_files.values(): + assert os.path.exists(os.path.join(path_dir, filename)) + def test_save_dataset_dict_mismatched_data_files( self, dataset_dict, kedro_dataset_cls, path_dir, extension ): """Saving a DatasetDict whose split names don't match data_files keys raises DatasetError.""" kedro_dataset = kedro_dataset_cls( path=path_dir, - load_args={ - # In the test fixture, we expect "data" and "labels". Not "train" and "test". - "data_files": { - "train": f"train{extension}", - "test": f"test{extension}", - } + # In the test fixture, we expect "data" and "labels". Not "train" and "test". 
+ data_files={ + "train": f"train{extension}", + "test": f"test{extension}", }, ) with pytest.raises(DatasetError, match=r"do not match"): kedro_dataset.save(dataset_dict) + @pytest.mark.parametrize("args_name", ["load_args", "save_args"]) + def test_top_level_data_files_conflicts_with_data_files_in_args( + self, + kedro_dataset_cls, + path_dir, + dataset_dict_data_files, + args_name, + ): + args = { + "path": path_dir, + "data_files": dataset_dict_data_files, + args_name: {"data_files": dataset_dict_data_files}, + } + + with pytest.raises(DatasetError, match=r"top-level"): + kedro_dataset_cls(**args) + def test_save_and_load_iterable_dataset( self, iterable_dataset, kedro_dataset_cls, path_file ): @@ -162,7 +235,7 @@ def test_save_and_load_iterable_dataset_dict( dataset_dict_data_files, ): kedro_dataset = kedro_dataset_cls( - path=path_dir, load_args={"data_files": dataset_dict_data_files} + path=path_dir, data_files=dataset_dict_data_files ) with pytest.raises(DatasetError, match=r"got iterable dataset"): kedro_dataset.save(iterable_dataset_dict)
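Taken together, the final patch promotes ``data_files`` to a top-level constructor argument, keeps the nested ``load_args``/``save_args`` form for backwards compatibility, and rejects supplying both. A short sketch of the three configurations the new tests cover, using ``CSVDataset`` with illustrative filenames and a temporary directory:

```python
import tempfile

from datasets import Dataset, DatasetDict
from kedro.io import DatasetError
from kedro_datasets.huggingface.csv_dataset import CSVDataset

path = tempfile.mkdtemp()
splits = DatasetDict(
    {
        "data": Dataset.from_dict({"text": ["a", "b"]}),
        "labels": Dataset.from_dict({"label": [0, 1]}),
    }
)

# 1. Preferred: one top-level mapping shared by both load and save.
ds = CSVDataset(path=path, data_files={"data": "data.csv", "labels": "labels.csv"})
ds.save(splits)
assert set(ds.load().keys()) == {"data", "labels"}

# 2. Backwards compatible: data_files nested in load_args/save_args,
#    which may point at different filenames for reading and writing.
CSVDataset(
    path=path,
    load_args={"data_files": {"data": "data.csv", "labels": "labels.csv"}},
    save_args={"data_files": {"data": "out.csv", "labels": "out_labels.csv"}},
)

# 3. Supplying data_files in both places is rejected at construction time.
try:
    CSVDataset(
        path=path,
        data_files={"data": "data.csv", "labels": "labels.csv"},
        load_args={"data_files": {"data": "data.csv", "labels": "labels.csv"}},
    )
except DatasetError as err:
    assert "top-level" in str(err)
```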