Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-38007: [C++] Add VariableShapeTensor implementation #38008

Open
wants to merge 62 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
ed51aeb
Initial commit
rok Aug 15, 2023
6af58b1
Python wrapper
rok Aug 25, 2023
15ebefd
Add VariableShapeTensorArray::ToTensor(i)
rok Sep 3, 2023
511d687
:Add ragged_dimensions
rok Sep 12, 2023
96f2dec
Replace ragged_dimensions with uniform_dimensions
rok Sep 15, 2023
1e2eef6
Add example for explanation
rok Sep 15, 2023
d0a2651
Add uniform_shape parameter
rok Sep 24, 2023
3c93765
Apply suggestions from code review
rok Sep 25, 2023
9273986
Post rebase
rok Oct 11, 2023
a09c1bd
Remove uniform_dimensions, fix python test
rok Oct 12, 2023
f24fa7d
lint
rok Oct 12, 2023
935c84b
uniform_shape values are optional
rok Oct 12, 2023
7652d6d
Add scalar test
rok Oct 29, 2023
522fa59
Create Tensor from scalar
rok Oct 30, 2023
831c99d
Move get_tensor logic to cpp
rok Nov 28, 2023
1da9790
slice buffer with array offset
rok Nov 28, 2023
2edf17b
Update cpp/src/arrow/extension/variable_shape_tensor.h
rok Nov 28, 2023
95eea74
Update cpp/src/arrow/extension/variable_shape_tensor.cc
rok Nov 28, 2023
7b0ab6c
Update cpp/src/arrow/extension/variable_shape_tensor.cc
rok Nov 28, 2023
58e9365
Update cpp/src/arrow/extension/variable_shape_tensor.cc
rok Nov 28, 2023
ae3a5c6
Update cpp/src/arrow/extension/variable_shape_tensor.cc
rok Nov 28, 2023
9a90eb1
Review feedback
rok Nov 28, 2023
75019c8
Update cpp/src/arrow/extension/variable_shape_tensor.cc
rok Nov 29, 2023
585699d
Review feedback
rok Nov 29, 2023
b3945eb
import and uint32->int32
rok Nov 29, 2023
f2e33a2
permutation check
rok Nov 29, 2023
d569a2d
Remove serialization from cython, lint
rok Nov 29, 2023
4ec4039
Review feedback
rok Nov 30, 2023
5d52992
ndim initializer
rok Nov 30, 2023
58ab37c
Test null values
rok Nov 30, 2023
8177874
Remove one GetTensor code paths, permutation handling
rok Dec 2, 2023
e75905c
Allow arbitrary memory layout
rok Dec 3, 2023
86c6931
fix permutation check
rok Dec 3, 2023
c59e7dd
lint
rok Dec 3, 2023
567e483
lint
rok Dec 3, 2023
76b113e
roundtrip strided
rok Dec 4, 2023
6ba0369
Apply suggestions from code review
rok Dec 13, 2023
e67c3a2
remove array.gettensor, simlify
rok Dec 13, 2023
09afbc5
work
rok Dec 14, 2023
f9aaa28
Add repr
rok Dec 14, 2023
a9684ee
Review feedback
rok Dec 14, 2023
f178e4b
GetTensor->MakeTensor, static
rok Dec 23, 2023
80cc733
Better permutations check
rok Dec 23, 2023
252a643
post rebase changes
rok Feb 8, 2024
39e89eb
work
rok Feb 9, 2024
7476321
ToString new parameter
rok Mar 4, 2024
c9dd006
Remove Python bindings
rok Mar 4, 2024
b716708
Review feedback
rok Mar 16, 2024
977d19a
Use TensorFromJSON
rok Mar 16, 2024
0e113ac
lint
rok Mar 17, 2024
33ceee6
Apply suggestions from code review
rok Mar 27, 2024
26b7698
Update cpp/src/arrow/extension/variable_shape_tensor.cc
rok Mar 27, 2024
b49021e
fix
rok Mar 27, 2024
8c2f0a3
Review feedback
rok Mar 27, 2024
aa3d29a
mingw64 issue
rok Mar 28, 2024
3741041
refactor ComputeStrides
rok Mar 29, 2024
dfd5fbe
Change to ComputeStrides
rok Apr 1, 2024
341757b
Change ToTensor
rok Apr 1, 2024
64d73fb
Refactoring ComputeStrides
rok Apr 2, 2024
c3dab58
Move RoundtripBatch to gtest_util.cc
rok Apr 14, 2024
2cfc678
Post rebase changes
rok Jun 6, 2024
f391ab6
Post rebase changes
rok Sep 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,8 @@ if(ARROW_JSON)
arrow_add_object_library(ARROW_JSON
extension/fixed_shape_tensor.cc
extension/opaque.cc
extension/tensor_internal.cc
extension/variable_shape_tensor.cc
json/options.cc
json/chunked_builder.cc
json/chunker.cc
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/extension/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc)

if(ARROW_JSON)
list(APPEND CANONICAL_EXTENSION_TESTS fixed_shape_tensor_test.cc opaque_test.cc)
list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc opaque_test.cc)
endif()

add_arrow_test(test
Expand Down
61 changes: 7 additions & 54 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,52 +37,7 @@

namespace rj = arrow::rapidjson;

namespace arrow {

namespace extension {

namespace {

/// \brief Compute byte strides for a tensor from its shape and a dimension
/// permutation.
///
/// If `permutation` is empty the work is delegated to
/// internal::ComputeRowMajorStrides. Otherwise the total extent is accumulated
/// by multiplying the element byte width by shape[i] for every entry i > 0 of
/// `permutation`; strides are then produced by successively dividing that
/// extent by the same shape[i] values, and finally reordered with
/// internal::Permute(permutation, strides).
///
/// \param[in] type element type; only type.byte_width() is used here
/// \param[in] shape tensor shape, indexed by the entries of `permutation`
/// \param[in] permutation dimension indices; entries equal to 0 are skipped
///            when accumulating and dividing
/// \param[out] strides receives the computed strides; any previous content is
///             discarded on every path of this function's own logic
/// \return Status::Invalid if the accumulated extent overflows int64_t,
///         Status::OK otherwise
Status ComputeStrides(const FixedWidthType& type, const std::vector<int64_t>& shape,
                      const std::vector<int64_t>& permutation,
                      std::vector<int64_t>* strides) {
  if (permutation.empty()) {
    return internal::ComputeRowMajorStrides(type, shape, strides);
  }

  const int byte_width = type.byte_width();

  // `remaining` stays 0 when the shape is empty or its leading dimension is
  // not positive; it also becomes 0 if any multiplied dimension is 0.
  int64_t remaining = 0;
  if (!shape.empty() && shape.front() > 0) {
    remaining = byte_width;
    for (auto i : permutation) {
      // NOTE(review): index 0 is skipped by value, not by position in
      // `permutation` -- confirm this is the intended dimension to exclude.
      if (i > 0) {
        if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
          return Status::Invalid(
              "Strides computed from shape would not fit in 64-bit integer");
        }
      }
    }
  }

  // Degenerate (zero-sized) tensor: every stride is just the element width.
  if (remaining == 0) {
    strides->assign(shape.size(), byte_width);
    return Status::OK();
  }

  // The degenerate path above overwrites *strides; do the same here so that a
  // non-empty output vector does not keep stale leading entries that would
  // then be permuted together with the freshly computed strides.
  strides->clear();
  strides->push_back(remaining);
  for (auto i : permutation) {
    if (i > 0) {
      // shape[i] is non-zero here: a zero factor would have made `remaining`
      // zero and taken the early return above.
      remaining /= shape[i];
      strides->push_back(remaining);
    }
  }
  internal::Permute(permutation, strides);

  return Status::OK();
}

} // namespace
namespace arrow::extension {

bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
if (extension_name() != other.extension_name()) {
Expand Down Expand Up @@ -237,7 +192,8 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
}

std::vector<int64_t> strides;
RETURN_NOT_OK(ComputeStrides(value_type, shape, permutation, &strides));
RETURN_NOT_OK(
internal::ComputeStrides(ext_type.value_type(), shape, permutation, &strides));
const auto start_position = array->offset() * byte_width;
const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
std::multiplies<>());
Expand Down Expand Up @@ -376,9 +332,8 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
internal::Permute<int64_t>(permutation, &shape);

std::vector<int64_t> tensor_strides;
const auto* fw_value_type = internal::checked_cast<FixedWidthType*>(value_type.get());
ARROW_RETURN_NOT_OK(
ComputeStrides(*fw_value_type, shape, permutation, &tensor_strides));
internal::ComputeStrides(value_type, shape, permutation, &tensor_strides));

const auto& raw_buffer = this->storage()->data()->child_data[0]->buffers[1];
ARROW_ASSIGN_OR_RAISE(
Expand Down Expand Up @@ -412,10 +367,9 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(

const std::vector<int64_t>& FixedShapeTensorType::strides() {
if (strides_.empty()) {
auto value_type = internal::checked_cast<FixedWidthType*>(this->value_type_.get());
std::vector<int64_t> tensor_strides;
ARROW_CHECK_OK(
ComputeStrides(*value_type, this->shape(), this->permutation(), &tensor_strides));
ARROW_CHECK_OK(internal::ComputeStrides(this->value_type_, this->shape(),
this->permutation(), &tensor_strides));
strides_ = tensor_strides;
}
return strides_;
Expand All @@ -430,5 +384,4 @@ std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& va
return maybe_type.MoveValueUnsafe();
}

} // namespace extension
} // namespace arrow
} // namespace arrow::extension
6 changes: 2 additions & 4 deletions cpp/src/arrow/extension/fixed_shape_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

#include "arrow/extension_type.h"

namespace arrow {
namespace extension {
namespace arrow::extension {

class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
public:
Expand Down Expand Up @@ -126,5 +125,4 @@ ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
const std::vector<int64_t>& permutation = {},
const std::vector<std::string>& dim_names = {});

} // namespace extension
} // namespace arrow
} // namespace arrow::extension
Loading
Loading