Skip to content

Commit

Permalink
Add VariableShapeTensorScalar
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Oct 29, 2023
1 parent c6bbe04 commit 19e46a2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
22 changes: 22 additions & 0 deletions python/pyarrow/scalar.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,28 @@ cdef class ExtensionScalar(Scalar):
return pyarrow_wrap_scalar(<shared_ptr[CScalar]> sp_scalar)


class VariableShapeTensorScalar(ExtensionScalar):
"""
Concrete class for variable shape tensor extension scalar.
"""

def to_numpy_ndarray(self):
# TODO: allow any permutation
"""
Convert variable shape tensor extension scalar to a numpy array.
Note: ``permutation`` should be trivial (``None`` or ``[0, 1, ..., len(shape)-1]``).
"""

if self.type.permutation is None or self.type.permutation == list(range(len(self.type.shape))):
shape = self.get("shape")
np_flat = np.asarray(self.get("values").flatten())
numpy_tensor = np_flat.reshape(tuple(shape))
return numpy_tensor
else:
raise ValueError(
'Only non-permuted tensors can be converted to numpy tensors.')


cdef dict _scalar_classes = {
_Type_BOOL: BooleanScalar,
_Type_UINT8: UInt8Scalar,
Expand Down
4 changes: 4 additions & 0 deletions python/pyarrow/types.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1692,6 +1692,10 @@ cdef class VariableShapeTensorType(BaseExtensionType):
self.dim_names, self.permutation,
self.uniform_shape)

def __variable_ext_scalar_class__(self):
return VariableShapeTensorScalar


cdef class FixedShapeTensorType(BaseExtensionType):
"""
Concrete class for fixed shape tensor extension type.
Expand Down

0 comments on commit 19e46a2

Please sign in to comment.