Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 51 additions & 10 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

#include <limits>
#include <numeric>
#include <sstream>

Expand Down Expand Up @@ -109,8 +110,8 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
return Status::Invalid("Expected FixedSizeList storage type, got ",
storage_type->ToString());
}
auto value_type =
internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type();
auto fsl_type = internal::checked_pointer_cast<FixedSizeListType>(storage_type);
auto value_type = fsl_type->value_type();
rj::Document document;
if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() ||
!document.IsObject() || !document.HasMember("shape") ||
Expand All @@ -119,29 +120,61 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
}

std::vector<int64_t> shape;
for (auto& x : document["shape"].GetArray()) {
for (const auto& x : document["shape"].GetArray()) {
if (!x.IsInt64()) {
return Status::Invalid("shape must contain integers");
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we show the actual type to make failures easier to debug?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed. We now emit the value and its type on deserialization failure.

}
shape.emplace_back(x.GetInt64());
}

std::vector<int64_t> permutation;
if (document.HasMember("permutation")) {
for (auto& x : document["permutation"].GetArray()) {
const auto& json_permutation = document["permutation"];
if (!json_permutation.IsArray()) {
return Status::Invalid("permutation must be an array");
Comment thread
rok marked this conversation as resolved.
Outdated
}
for (const auto& x : json_permutation.GetArray()) {
if (!x.IsInt64()) {
return Status::Invalid("permutation must contain integers");
Comment thread
rok marked this conversation as resolved.
Outdated
}
permutation.emplace_back(x.GetInt64());
}
if (shape.size() != permutation.size()) {
return Status::Invalid("Invalid permutation");
}
RETURN_NOT_OK(internal::IsPermutationValid(permutation));
}
std::vector<std::string> dim_names;
if (document.HasMember("dim_names")) {
for (auto& x : document["dim_names"].GetArray()) {
const auto& json_dim_names = document["dim_names"];
if (!json_dim_names.IsArray()) {
return Status::Invalid("dim_names must be an array");
Comment thread
rok marked this conversation as resolved.
Outdated
}
for (const auto& x : json_dim_names.GetArray()) {
if (!x.IsString()) {
return Status::Invalid("dim_names must contain strings");
Comment thread
rok marked this conversation as resolved.
Outdated
}
dim_names.emplace_back(x.GetString());
}
if (shape.size() != dim_names.size()) {
return Status::Invalid("Invalid dim_names");
}
}

return fixed_shape_tensor(value_type, shape, permutation, dim_names);
// Validate product of shape dimensions matches storage type list_size.
// This check is intentionally after field parsing so that metadata-level errors
// (type mismatches, size mismatches) are reported first.
ARROW_ASSIGN_OR_RAISE(auto ext_type, FixedShapeTensorType::Make(
value_type, shape, permutation, dim_names));
const auto& fst_type = internal::checked_cast<const FixedShapeTensorType&>(*ext_type);
ARROW_ASSIGN_OR_RAISE(const int64_t expected_size,
internal::ComputeShapeProduct(fst_type.shape()));
if (expected_size != fsl_type->list_size()) {
return Status::Invalid("Product of shape dimensions (", expected_size,
") does not match FixedSizeList size (", fsl_type->list_size(),
")");
}
return ext_type;
}

std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
Expand Down Expand Up @@ -310,8 +343,7 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
}

std::vector<int64_t> shape = ext_type.shape();
auto cell_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
std::multiplies<>());
ARROW_ASSIGN_OR_RAISE(const int64_t cell_size, internal::ComputeShapeProduct(shape));
shape.insert(shape.begin(), 1, this->length());
internal::Permute<int64_t>(permutation, &shape);

Expand All @@ -330,6 +362,11 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) {
const size_t ndim = shape.size();
for (auto dim : shape) {
if (dim < 0) {
return Status::Invalid("shape must have non-negative values, got ", dim);
}
}
if (!permutation.empty() && ndim != permutation.size()) {
return Status::Invalid("permutation size must match shape size. Expected: ", ndim,
" Got: ", permutation.size());
Expand All @@ -342,8 +379,12 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
RETURN_NOT_OK(internal::IsPermutationValid(permutation));
}

const int64_t size = std::accumulate(shape.begin(), shape.end(),
static_cast<int64_t>(1), std::multiplies<>());
ARROW_ASSIGN_OR_RAISE(const int64_t size, internal::ComputeShapeProduct(shape));
if (size > std::numeric_limits<int32_t>::max()) {
return Status::Invalid("Product of shape dimensions (", size,
") exceeds maximum FixedSizeList size (",
std::numeric_limits<int32_t>::max(), ")");
}
return std::make_shared<FixedShapeTensorType>(value_type, static_cast<int32_t>(size),
shape, permutation, dim_names);
}
Expand Down
53 changes: 53 additions & 0 deletions cpp/src/arrow/extension/tensor_extension_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,59 @@ TEST_F(TestFixedShapeTensorType, MetadataSerializationRoundtrip) {
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3],"dim_names":["x","y"]})",
"Invalid dim_names");

// Validate shape values must be integers
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3.5,4]})",
"shape must contain integers");
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":["3","4"]})",
"shape must contain integers");
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[null]})",
"shape must contain integers");

// Validate shape values must be non-negative
CheckDeserializationRaises(ext_type_, fixed_size_list(int64(), 1), R"({"shape":[-1]})",
"shape must have non-negative values");

// Validate product of shape matches storage list_size
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3,3]})",
"Product of shape dimensions");

// Validate permutation member must be an array with integer values
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"permutation":"invalid"})",
"permutation must be an array");
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"permutation":[1.5,0.5]})",
"permutation must contain integers");

// Validate permutation values must be unique integers in [0, N-1]
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"permutation":[0,0]})",
"Permutation indices");
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"permutation":[0,5]})",
"Permutation indices");
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"permutation":[-1,0]})",
"Permutation indices");

// Validate dim_names member must be an array with string values
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"dim_names":"invalid"})",
"dim_names must be an array");
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"dim_names":[1,2]})",
"dim_names must contain strings");
}

TEST_F(TestFixedShapeTensorType, MakeValidatesShape) {
  // A negative dimension must make Make() fail, whether it is the only
  // dimension or sits between valid ones.
  const std::vector<std::vector<int64_t>> invalid_shapes = {{-1}, {3, -1, 4}};
  for (const auto& invalid_shape : invalid_shapes) {
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        Invalid, testing::HasSubstr("shape must have non-negative values"),
        FixedShapeTensorType::Make(value_type_, invalid_shape));
  }
}

TEST_F(TestFixedShapeTensorType, RoundtripBatch) {
Expand Down
18 changes: 12 additions & 6 deletions cpp/src/arrow/extension/tensor_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@

namespace arrow::internal {

// Multiply the dimensions together one at a time, returning Status::Invalid as
// soon as MultiplyWithOverflow reports that a partial product does not fit in
// int64_t. An empty shape yields 1. Dimension signs are not inspected here.
Result<int64_t> ComputeShapeProduct(std::span<const int64_t> shape) {
  int64_t running = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    int64_t next = 0;
    if (MultiplyWithOverflow(running, shape[i], &next)) {
      return Status::Invalid(
          "Product of tensor shape dimensions would not fit in 64-bit integer");
    }
    running = next;
  }
  return running;
}

bool IsPermutationTrivial(std::span<const int64_t> permutation) {
for (size_t i = 1; i < permutation.size(); ++i) {
if (permutation[i - 1] + 1 != permutation[i]) {
Expand Down Expand Up @@ -105,12 +116,7 @@ Result<std::shared_ptr<Buffer>> SliceTensorBuffer(const Array& data_array,
const DataType& value_type,
std::span<const int64_t> shape) {
const int64_t byte_width = value_type.byte_width();
int64_t size = 1;
for (const auto dim : shape) {
if (MultiplyWithOverflow(size, dim, &size)) {
return Status::Invalid("Tensor size would not fit in 64-bit integer");
}
}
ARROW_ASSIGN_OR_RAISE(const int64_t size, ComputeShapeProduct(shape));
if (size != data_array.length()) {
return Status::Invalid("Expected data array of length ", size, ", got ",
data_array.length());
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/arrow/extension/tensor_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@

namespace arrow::internal {

/// \brief Compute the product of the given shape dimensions.
///
/// \param[in] shape dimension sizes to multiply together
/// \return the product of all dimensions. An empty shape returns 1
/// (the multiplicative identity).
///
/// Returns Status::Invalid if the product would overflow int64_t.
///
/// \note The implementation multiplies dimensions in order and only checks for
/// overflow; it does not reject negative dimensions. Callers that need
/// non-negative shapes (e.g. FixedShapeTensorType::Make) validate sign
/// themselves.
ARROW_EXPORT
Result<int64_t> ComputeShapeProduct(std::span<const int64_t> shape);

ARROW_EXPORT
bool IsPermutationTrivial(std::span<const int64_t> permutation);

Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/extension/variable_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ Result<std::shared_ptr<DataType>> VariableShapeTensorType::Deserialize(
}
permutation.emplace_back(x.GetInt64());
}
RETURN_NOT_OK(internal::IsPermutationValid(permutation));
}
std::vector<std::string> dim_names;
if (document.HasMember("dim_names")) {
Expand Down
Loading