Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
return Status::Invalid("Expected FixedSizeList storage type, got ",
storage_type->ToString());
}
auto value_type =
internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type();
auto fsl_type = internal::checked_pointer_cast<FixedSizeListType>(storage_type);
auto value_type = fsl_type->value_type();
rj::Document document;
if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() ||
!document.IsObject() || !document.HasMember("shape") ||
Expand All @@ -119,12 +119,23 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
}

std::vector<int64_t> shape;
for (auto& x : document["shape"].GetArray()) {
for (const auto& x : document["shape"].GetArray()) {
if (!x.IsInt64()) {
return Status::Invalid("shape must contain integers");
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we show the actual type to make failures easier to debug?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed. We now emit the value and its type on deserialization failure.

}
shape.emplace_back(x.GetInt64());
}

std::vector<int64_t> permutation;
if (document.HasMember("permutation")) {
for (auto& x : document["permutation"].GetArray()) {
const auto& json_permutation = document["permutation"];
if (!json_permutation.IsArray()) {
return Status::Invalid("permutation must be an array");
Comment thread
rok marked this conversation as resolved.
Outdated
}
for (const auto& x : json_permutation.GetArray()) {
if (!x.IsInt64()) {
return Status::Invalid("permutation must contain integers");
Comment thread
rok marked this conversation as resolved.
Outdated
}
permutation.emplace_back(x.GetInt64());
}
if (shape.size() != permutation.size()) {
Expand All @@ -133,15 +144,36 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
}
std::vector<std::string> dim_names;
if (document.HasMember("dim_names")) {
for (auto& x : document["dim_names"].GetArray()) {
const auto& json_dim_names = document["dim_names"];
if (!json_dim_names.IsArray()) {
return Status::Invalid("dim_names must be an array");
Comment thread
rok marked this conversation as resolved.
Outdated
}
for (const auto& x : json_dim_names.GetArray()) {
if (!x.IsString()) {
return Status::Invalid("dim_names must contain strings");
Comment thread
rok marked this conversation as resolved.
Outdated
}
dim_names.emplace_back(x.GetString());
}
if (shape.size() != dim_names.size()) {
return Status::Invalid("Invalid dim_names");
}
}

return fixed_shape_tensor(value_type, shape, permutation, dim_names);
// Validate product of shape dimensions matches storage type list_size.
// This check is intentionally after field parsing so that metadata-level errors
// (type mismatches, size mismatches) are reported first.
ARROW_ASSIGN_OR_RAISE(auto ext_type, FixedShapeTensorType::Make(
value_type, shape, permutation, dim_names));
const auto& fst_type = internal::checked_cast<const FixedShapeTensorType&>(*ext_type);
const int64_t expected_size =
std::accumulate(fst_type.shape().begin(), fst_type.shape().end(),
static_cast<int64_t>(1), std::multiplies<>());
Comment thread
rok marked this conversation as resolved.
Outdated
if (expected_size != fsl_type->list_size()) {
return Status::Invalid("Product of shape dimensions (", expected_size,
") does not match FixedSizeList size (", fsl_type->list_size(),
")");
}
return ext_type;
}

std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
Expand Down Expand Up @@ -330,6 +362,11 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) {
const size_t ndim = shape.size();
for (auto dim : shape) {
if (dim < 0) {
return Status::Invalid("shape must have non-negative values, got ", dim);
}
}
if (!permutation.empty() && ndim != permutation.size()) {
return Status::Invalid("permutation size must match shape size. Expected: ", ndim,
" Got: ", permutation.size());
Expand Down
53 changes: 53 additions & 0 deletions cpp/src/arrow/extension/tensor_extension_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,59 @@ TEST_F(TestFixedShapeTensorType, MetadataSerializationRoundtrip) {
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3],"dim_names":["x","y"]})",
"Invalid dim_names");

// Validate shape values must be integers
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3.5,4]})",
"shape must contain integers");
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":["3","4"]})",
"shape must contain integers");
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[null]})",
"shape must contain integers");

// Validate shape values must be non-negative
CheckDeserializationRaises(ext_type_, fixed_size_list(int64(), 1), R"({"shape":[-1]})",
"shape must have non-negative values");

// Validate product of shape matches storage list_size
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3,3]})",
"Product of shape dimensions");

// Validate permutation member must be an array with integer values
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"permutation":"invalid"})",
"permutation must be an array");
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"permutation":[1.5,0.5]})",
"permutation must contain integers");

// Validate permutation values must be unique integers in [0, N-1]
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"permutation":[0,0]})",
"Permutation indices");
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"permutation":[0,5]})",
"Permutation indices");
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"permutation":[-1,0]})",
"Permutation indices");

// Validate dim_names member must be an array with string values
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"dim_names":"invalid"})",
"dim_names must be an array");
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"dim_names":[1,2]})",
"dim_names must contain strings");
}

TEST_F(TestFixedShapeTensorType, MakeValidatesShape) {
  // FixedShapeTensorType::Make must refuse any shape holding a negative extent,
  // whether it is the only dimension or sits between valid ones.
  const std::vector<std::vector<int64_t>> invalid_shapes = {{-1}, {3, -1, 4}};
  for (const auto& invalid_shape : invalid_shapes) {
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        Invalid, testing::HasSubstr("shape must have non-negative values"),
        FixedShapeTensorType::Make(value_type_, invalid_shape));
  }
}

TEST_F(TestFixedShapeTensorType, RoundtripBatch) {
Expand Down
Loading