diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py
index a924de9f4710..6da9f6e1690e 100644
--- a/python/ray/data/read_api.py
+++ b/python/ray/data/read_api.py
@@ -79,7 +79,10 @@
 from ray.data._internal.plan import ExecutionPlan
 from ray.data._internal.remote_fn import cached_remote_fn
 from ray.data._internal.stats import DatasetStats
-from ray.data._internal.tensor_extensions.utils import _create_possibly_ragged_ndarray
+from ray.data._internal.tensor_extensions.utils import (
+    _create_possibly_ragged_ndarray,
+    create_ragged_ndarray,
+)
 from ray.data._internal.util import (
     _autodetect_parallelism,
     get_compute_strategy_for_read_api,
@@ -3835,8 +3838,37 @@ def from_tf(
     Returns:
         A :class:`MaterializedDataset` that contains the samples stored in the `TensorFlow Dataset`_.
     """  # noqa: E501
-    # FIXME: `as_numpy_iterator` errors if `dataset` contains ragged tensors.
-    return from_items(list(dataset.as_numpy_iterator()))
+    import tensorflow as tf
+
+    def _contains_ragged_tensor_spec(spec: Any) -> bool:
+        return any(
+            isinstance(type_spec, tf.RaggedTensorSpec)
+            for type_spec in tf.nest.flatten(spec)
+        )
+
+    def _convert_tf_value(value: Any) -> Any:
+        if isinstance(value, tf.RaggedTensor):
+            return create_ragged_ndarray([_convert_tf_value(v) for v in value])
+        if isinstance(value, tf.SparseTensor):
+            return tf.compat.v1.SparseTensorValue(
+                indices=value.indices.numpy(),
+                values=value.values.numpy(),
+                dense_shape=value.dense_shape.numpy(),
+            )
+        if isinstance(value, tf.Tensor):
+            return value.numpy()
+        if isinstance(value, collections.abc.Mapping):
+            return {key: _convert_tf_value(item) for key, item in value.items()}
+        if isinstance(value, tuple):
+            return tuple(_convert_tf_value(item) for item in value)
+        if isinstance(value, list):
+            return [_convert_tf_value(item) for item in value]
+        return value
+
+    if not _contains_ragged_tensor_spec(dataset.element_spec):
+        return from_items(list(dataset.as_numpy_iterator()))
+
+    return from_items([_convert_tf_value(item) for item in dataset])
 
 
 @PublicAPI
diff --git a/python/ray/data/tests/datasource/test_tensorflow_datasets.py b/python/ray/data/tests/datasource/test_tensorflow_datasets.py
index 5b7f3213c53c..d835da8644ea 100644
--- a/python/ray/data/tests/datasource/test_tensorflow_datasets.py
+++ b/python/ray/data/tests/datasource/test_tensorflow_datasets.py
@@ -1,5 +1,6 @@
 import sys
 
+import numpy as np
 import pytest
 
 import ray
@@ -9,6 +10,20 @@
 from ray.tests.conftest import *  # noqa
 
 
+def _to_nested_lists(value):
+    if isinstance(value, np.ndarray):
+        if value.dtype == object:
+            return [_to_nested_lists(item) for item in value.tolist()]
+        return value.tolist()
+    if isinstance(value, tuple):
+        return tuple(_to_nested_lists(item) for item in value)
+    if isinstance(value, list):
+        return [_to_nested_lists(item) for item in value]
+    if isinstance(value, dict):
+        return {key: _to_nested_lists(item) for key, item in value.items()}
+    return value
+
+
 def test_from_tf_e2e(ray_start_regular_shared_2_cpus):
     import tensorflow as tf
     import tensorflow_datasets as tfds
@@ -34,5 +49,47 @@ def test_from_tf_e2e(ray_start_regular_shared_2_cpus):
     _check_usage_record(["FromItems"])
 
 
+def test_from_tf_ragged_tensor(ray_start_regular_shared_2_cpus):
+    import tensorflow as tf
+
+    tf_dataset = tf.data.Dataset.from_tensors(
+        tf.ragged.constant([[1, 2, 3], [4, 5]])
+    ).concatenate(tf.data.Dataset.from_tensors(tf.ragged.constant([[6], [7, 8]])))
+
+    ray_dataset = ray.data.from_tf(tf_dataset)
+
+    actual_data = extract_values("item", ray_dataset.take_all())
+    expected_data = list(tf_dataset)
+
+    assert len(actual_data) == len(expected_data)
+    for actual_item, expected_item in zip(actual_data, expected_data):
+        assert _to_nested_lists(actual_item) == expected_item.to_list()
+
+
+def test_from_tf_ragged_and_sparse_tensor(ray_start_regular_shared_2_cpus):
+    import tensorflow as tf
+
+    ragged_tensor = tf.ragged.constant([[1, 2, 3], [4, 5]])
+    sparse_tensor = tf.sparse.from_dense([[1, 0, 0], [0, 2, 3]])
+    tf_dataset = tf.data.Dataset.from_tensors(
+        {"ragged": ragged_tensor, "sparse": sparse_tensor}
+    )
+
+    ray_dataset = ray.data.from_tf(tf_dataset)
+    actual_item = ray_dataset.take_all()[0]
+
+    assert _to_nested_lists(actual_item["ragged"]) == ragged_tensor.to_list()
+    assert isinstance(actual_item["sparse"], tf.compat.v1.SparseTensorValue)
+    np.testing.assert_array_equal(
+        actual_item["sparse"].indices, sparse_tensor.indices.numpy()
+    )
+    np.testing.assert_array_equal(
+        actual_item["sparse"].values, sparse_tensor.values.numpy()
+    )
+    np.testing.assert_array_equal(
+        actual_item["sparse"].dense_shape, sparse_tensor.dense_shape.numpy()
+    )
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))