Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 2 additions & 196 deletions cpp/src/arrow/record_batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
#include "arrow/type.h"
#include "arrow/util/iterator.h"
#include "arrow/util/logging.h"
#include "arrow/util/unreachable.h"
#include "arrow/util/vector.h"
#include "arrow/visit_type_inline.h"

Expand Down Expand Up @@ -286,204 +285,11 @@ Result<std::shared_ptr<StructArray>> RecordBatch::ToStructArray() const {
/*offset=*/0);
}

/// Type visitor that appends the values of one column to a contiguous output run.
///
/// `out_values` is held by reference: every successful visit writes
/// `in_data.length` elements through it and leaves it pointing just past the
/// last element written, so visiting columns one after another lays them out
/// back to back. Null slots are emitted as `static_cast<Out>(NAN)`.
template <typename Out>
struct ConvertColumnsToTensorVisitor {
  // Write cursor shared with the caller; advanced as values are emitted.
  Out*& out_values;
  // Column whose values are being converted.
  const ArrayData& in_data;

  template <typename T>
  Status Visit(const T&) {
    if constexpr (is_numeric(T::type_id)) {
      using In = typename T::c_type;
      auto column_values = ArraySpan(in_data).GetSpan<In>(1, in_data.length);
      const bool may_have_nulls = in_data.null_count != 0;

      if (may_have_nulls) {
        // Slow path: consult validity for every slot.
        for (int64_t row = 0; row < in_data.length; ++row) {
          if (in_data.IsNull(row)) {
            *out_values++ = static_cast<Out>(NAN);
          } else {
            *out_values++ = static_cast<Out>(column_values[row]);
          }
        }
      } else if constexpr (std::is_same_v<In, Out>) {
        // Same element type and no nulls: copy the whole run in one go.
        memcpy(out_values, column_values.data(), column_values.size_bytes());
        out_values += column_values.size();
      } else {
        // No nulls but a different element type: convert element by element.
        for (In value : column_values) {
          *out_values++ = static_cast<Out>(value);
        }
      }
      return Status::OK();
    }
    // Only numeric types are expected to be dispatched to this visitor.
    Unreachable();
  }
};

/// Type visitor that scatters the values of one column into an interleaved
/// output buffer.
///
/// The value of row `r` is stored at `out_values[r * num_cols + col_idx]`.
/// Unlike the contiguous visitor, `out_values` itself is never advanced here.
/// Null slots are emitted as `static_cast<Out>(NAN)`.
template <typename Out>
struct ConvertColumnsToTensorRowMajorVisitor {
  // Base of the output buffer (not modified by this visitor).
  Out*& out_values;
  // Column whose values are being converted.
  const ArrayData& in_data;
  // Number of output slots per row.
  int num_cols;
  // Slot within each row that this column occupies.
  int col_idx;

  template <typename T>
  Status Visit(const T&) {
    if constexpr (is_numeric(T::type_id)) {
      using In = typename T::c_type;
      auto column_values = ArraySpan(in_data).GetSpan<In>(1, in_data.length);
      // Widen once so the index arithmetic below is done in 64 bits.
      const int64_t row_stride = num_cols;

      if (in_data.null_count != 0) {
        for (int64_t row = 0; row < in_data.length; ++row) {
          Out& slot = out_values[row * row_stride + col_idx];
          if (in_data.IsNull(row)) {
            slot = static_cast<Out>(NAN);
          } else {
            slot = static_cast<Out>(column_values[row]);
          }
        }
      } else {
        for (int64_t row = 0; row < in_data.length; ++row) {
          out_values[row * row_stride + col_idx] = static_cast<Out>(column_values[row]);
        }
      }
      return Status::OK();
    }
    // Only numeric types are expected to be dispatched to this visitor.
    Unreachable();
  }
};

template <typename DataType>
inline void ConvertColumnsToTensor(const RecordBatch& batch, uint8_t* out,
bool row_major) {
using CType = typename arrow::TypeTraits<DataType>::CType;
auto* out_values = reinterpret_cast<CType*>(out);

int i = 0;
for (const auto& column : batch.columns()) {
if (row_major) {
ConvertColumnsToTensorRowMajorVisitor<CType> visitor{out_values, *column->data(),
batch.num_columns(), i++};
DCHECK_OK(VisitTypeInline(*column->type(), &visitor));
} else {
ConvertColumnsToTensorVisitor<CType> visitor{out_values, *column->data()};
DCHECK_OK(VisitTypeInline(*column->type(), &visitor));
}
}
}

/// \brief Convert this RecordBatch into a Tensor.
///
/// All work is delegated to internal::RecordBatchToTensor, which receives this
/// batch and every argument unchanged and fills in the resulting tensor.
///
/// Previously this function also validated the columns, allocated a buffer,
/// converted the data and built a Tensor by hand, only to have that Tensor
/// overwritten by the helper call below. That duplicate pass (whose buffer was
/// additionally sized with bit_width(), i.e. in bits rather than bytes) has
/// been removed.
///
/// NOTE(review): argument validation (empty schema, nulls without null_to_nan,
/// non-numeric or half-float mixes) is presumed to be performed by
/// internal::RecordBatchToTensor — confirm against its implementation.
///
/// \param[in] null_to_nan forwarded to internal::RecordBatchToTensor
/// \param[in] row_major   forwarded to internal::RecordBatchToTensor
/// \param[in] pool        memory pool forwarded to internal::RecordBatchToTensor
/// \return the tensor produced by the helper, or the error Status it reported
Result<std::shared_ptr<Tensor>> RecordBatch::ToTensor(bool null_to_nan, bool row_major,
                                                      MemoryPool* pool) const {
  std::shared_ptr<Tensor> tensor;
  ARROW_RETURN_NOT_OK(
      internal::RecordBatchToTensor(*this, null_to_nan, row_major, pool, &tensor));
  return tensor;
}

Expand Down
Loading