Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions python/ray/data/read_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.stats import DatasetStats
from ray.data._internal.tensor_extensions.utils import _create_possibly_ragged_ndarray
from ray.data._internal.tensor_extensions.utils import (
_create_possibly_ragged_ndarray,
create_ragged_ndarray,
)
from ray.data._internal.util import (
_autodetect_parallelism,
get_compute_strategy_for_read_api,
Expand Down Expand Up @@ -3835,8 +3838,37 @@ def from_tf(
Returns:
A :class:`MaterializedDataset` that contains the samples stored in the `TensorFlow Dataset`_.
""" # noqa: E501
# FIXME: `as_numpy_iterator` errors if `dataset` contains ragged tensors.
return from_items(list(dataset.as_numpy_iterator()))
import tensorflow as tf

def _contains_ragged_tensor_spec(spec: Any) -> bool:
    """Return whether any leaf of ``spec`` is a ``tf.RaggedTensorSpec``.

    Args:
        spec: A (possibly nested) element spec. It is flattened with
            ``tf.nest.flatten`` and every resulting leaf is inspected.

    Returns:
        ``True`` as soon as one ragged spec is found, ``False`` otherwise.
    """
    for leaf_spec in tf.nest.flatten(spec):
        if isinstance(leaf_spec, tf.RaggedTensorSpec):
            return True
    return False

def _convert_tf_value(value: Any) -> Any:
    """Recursively convert one TensorFlow dataset element to NumPy-backed data.

    The checks run in this order:

    * ``tf.RaggedTensor``: every item produced by iterating the tensor is
      converted recursively, and the resulting list is handed to
      ``create_ragged_ndarray``.
    * ``tf.SparseTensor``: rebuilt as a ``tf.compat.v1.SparseTensorValue``
      whose ``indices``, ``values`` and ``dense_shape`` are NumPy arrays.
    * ``tf.Tensor``: replaced by the result of its ``.numpy()`` method.
    * Mappings, tuples and lists: rebuilt as a ``dict``, ``tuple`` or ``list``
      with every contained value converted recursively.
    * Anything else is returned untouched.

    Args:
        value: The element (or nested part of an element) to convert.

    Returns:
        The converted counterpart of ``value``.
    """
    if isinstance(value, tf.RaggedTensor):
        converted_rows = [_convert_tf_value(row) for row in value]
        return create_ragged_ndarray(converted_rows)

    if isinstance(value, tf.SparseTensor):
        # Pull each component out of TensorFlow so the result holds only
        # NumPy arrays.
        components = {
            field: getattr(value, field).numpy()
            for field in ("indices", "values", "dense_shape")
        }
        return tf.compat.v1.SparseTensorValue(**components)

    if isinstance(value, tf.Tensor):
        return value.numpy()

    if isinstance(value, collections.abc.Mapping):
        return {name: _convert_tf_value(entry) for name, entry in value.items()}

    if isinstance(value, tuple):
        return tuple(map(_convert_tf_value, value))

    if isinstance(value, list):
        return list(map(_convert_tf_value, value))

    return value

if not _contains_ragged_tensor_spec(dataset.element_spec):
return from_items(list(dataset.as_numpy_iterator()))

return from_items([_convert_tf_value(item) for item in dataset])
Comment thread
weimingdiit marked this conversation as resolved.


@PublicAPI
Expand Down
57 changes: 57 additions & 0 deletions python/ray/data/tests/datasource/test_tensorflow_datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys

import numpy as np
import pytest

import ray
Expand All @@ -9,6 +10,20 @@
from ray.tests.conftest import * # noqa


def _to_nested_lists(value):
if isinstance(value, np.ndarray):
if value.dtype == object:
return [_to_nested_lists(item) for item in value.tolist()]
return value.tolist()
if isinstance(value, tuple):
return tuple(_to_nested_lists(item) for item in value)
if isinstance(value, list):
return [_to_nested_lists(item) for item in value]
if isinstance(value, dict):
return {key: _to_nested_lists(item) for key, item in value.items()}
return value


def test_from_tf_e2e(ray_start_regular_shared_2_cpus):
import tensorflow as tf
import tensorflow_datasets as tfds
Expand All @@ -34,5 +49,47 @@ def test_from_tf_e2e(ray_start_regular_shared_2_cpus):
_check_usage_record(["FromItems"])


def test_from_tf_ragged_tensor(ray_start_regular_shared_2_cpus):
    """Check that ``from_tf`` round-trips a dataset of ragged tensors.

    Two ragged tensors are concatenated into one TensorFlow dataset, loaded
    through ``ray.data.from_tf``, and the ``"item"`` values read back are
    compared element by element with the original tensors.
    """
    import tensorflow as tf

    first_part = tf.data.Dataset.from_tensors(
        tf.ragged.constant([[1, 2, 3], [4, 5]])
    )
    second_part = tf.data.Dataset.from_tensors(tf.ragged.constant([[6], [7, 8]]))
    tf_dataset = first_part.concatenate(second_part)

    ray_dataset = ray.data.from_tf(tf_dataset)

    rows = ray_dataset.take_all()
    actual_data = extract_values("item", rows)
    expected_data = list(tf_dataset)

    assert len(actual_data) == len(expected_data)
    # Lengths match, so comparing the converted lists covers every pair.
    actual_lists = [_to_nested_lists(item) for item in actual_data]
    expected_lists = [tensor.to_list() for tensor in expected_data]
    assert actual_lists == expected_lists


def test_from_tf_ragged_and_sparse_tensor(ray_start_regular_shared_2_cpus):
    """Check ``from_tf`` on an element holding a ragged and a sparse tensor.

    The ragged entry must match the source tensor's nested lists. The sparse
    entry must come back as a ``tf.compat.v1.SparseTensorValue`` whose
    ``indices``, ``values`` and ``dense_shape`` equal those of the source.
    """
    import tensorflow as tf

    ragged_tensor = tf.ragged.constant([[1, 2, 3], [4, 5]])
    sparse_tensor = tf.sparse.from_dense([[1, 0, 0], [0, 2, 3]])
    element = {"ragged": ragged_tensor, "sparse": sparse_tensor}
    tf_dataset = tf.data.Dataset.from_tensors(element)

    rows = ray.data.from_tf(tf_dataset).take_all()
    actual_item = rows[0]

    assert _to_nested_lists(actual_item["ragged"]) == ragged_tensor.to_list()

    actual_sparse = actual_item["sparse"]
    assert isinstance(actual_sparse, tf.compat.v1.SparseTensorValue)
    for component in ("indices", "values", "dense_shape"):
        np.testing.assert_array_equal(
            getattr(actual_sparse, component),
            getattr(sparse_tensor, component).numpy(),
        )


if __name__ == "__main__":
    # Run this file's tests verbosely and use pytest's result as the exit code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
Loading