Skip to content

Commit 1bdba1d

Browse files
committed
Add pa.fixedshapetensor factory function and update docstring examples
1 parent a6292f8 commit 1bdba1d

2 files changed

Lines changed: 104 additions & 66 deletions

File tree

python/pyarrow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def print_entry(label, value):
170170
union, sparse_union, dense_union,
171171
dictionary,
172172
run_end_encoded,
173+
fixedshapetensor,
173174
field,
174175
type_for_alias,
175176
DataType, DictionaryType, StructType,

python/pyarrow/types.pxi

Lines changed: 103 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,78 +1499,23 @@ cdef class FixedShapeTensorType(BaseExtensionType):
14991499
"""
15001500
Concrete class for fixed shape tensor extension type.
15011501
1502-
Parameters
1503-
----------
1504-
value_type : DataType
1505-
Data type of individual tensor elements.
1506-
shape : tuple
1507-
The physical shape of the contained tensors.
1508-
dim_names : tuple
1509-
Explicit names to tensor dimensions.
1510-
permutation : tuple
1511-
Indices of the desired ordering of the original dimensions.
1512-
15131502
Examples
15141503
--------
1515-
>>> import pyarrow as pa
1516-
1517-
Create fixed shape tensor extension type:
1504+
Create an instance of fixed shape tensor extension type:
15181505
1519-
>>> tensor_type = pa.FixedShapeTensorType(pa.int32(), [2, 2])
1520-
>>> tensor_type
1506+
>>> import pyarrow as pa
1507+
>>> pa.fixedshapetensor(pa.int32(), [2, 2])
15211508
FixedShapeTensorType(extension<arrow.fixed_shape_tensor>)
15221509
1523-
Inspect the data type:
1524-
1525-
>>> tensor_type.value_type
1526-
DataType(int32)
1527-
>>> tensor_type.shape
1528-
[2, 2]
1529-
1530-
Create a fixed shape tensor extension type with names of tensor dimensions:
1531-
1532-
>>> tensor_type = pa.FixedShapeTensorType(pa.int8(), (2, 2, 3), dim_names=['C', 'H', 'W'])
1533-
>>> tensor_type.dim_names
1534-
[b'C', b'H', b'W']
1510+
Create an instance of fixed shape tensor extension type with
1511+
permutation:
15351512
1536-
Create a fixed shape tensor extension type with permutation:
1537-
>>> tensor_type = pa.FixedShapeTensorType(pa.int8(), (2, 2, 3), permutation=[0, 2, 1])
1513+
>>> tensor_type = pa.fixedshapetensor(pa.int8(), (2, 2, 3),
1514+
... permutation=[0, 2, 1])
15381515
>>> tensor_type.permutation
15391516
[0, 2, 1]
15401517
"""
15411518

1542-
def __init__(self, DataType value_type, shape, dim_names=None, permutation=None):
1543-
"""
1544-
Initialize an fixed shape tensor extension type instance.
1545-
1546-
This should be called at the end of the subclass'
1547-
``__init__`` method.
1548-
"""
1549-
cdef:
1550-
vector[int64_t] c_shape
1551-
vector[int64_t] c_permutation
1552-
vector[c_string] c_dim_names
1553-
shared_ptr[CDataType] tensor_ext_type
1554-
1555-
assert value_type is not None
1556-
assert shape is not None
1557-
1558-
for i in shape:
1559-
c_shape.push_back(i)
1560-
1561-
if permutation is not None:
1562-
for i in permutation:
1563-
c_permutation.push_back(i)
1564-
1565-
if dim_names is not None:
1566-
for x in dim_names:
1567-
c_dim_names.push_back(tobytes(x))
1568-
1569-
tensor_ext_type = GetResultValue(CFixedShapeTensorType.Make(
1570-
value_type.sp_type, c_shape, c_permutation, c_dim_names))
1571-
1572-
self.init(tensor_ext_type)
1573-
15741519
cdef void init(self, const shared_ptr[CDataType]& type) except *:
15751520
BaseExtensionType.init(self, type)
15761521
self.tensor_ext_type = <const CFixedShapeTensorType*> type.get()
@@ -1607,17 +1552,15 @@ cdef class FixedShapeTensorType(BaseExtensionType):
16071552
"""
16081553
Serialized representation of metadata to reconstruct the type object.
16091554
"""
1610-
metadata = self.tensor_ext_type.Serialize()
1611-
return metadata
1555+
return self.tensor_ext_type.Serialize()
16121556

16131557
@classmethod
16141558
def __arrow_ext_deserialize__(self, storage_type, serialized):
16151559
"""
16161560
Return an FixedShapeTensor type instance from the storage type and serialized
16171561
metadata.
16181562
"""
1619-
tensor_ext_type = self.tensor_ext_type.Deserialize(storage_type, serialized)
1620-
return tensor_ext_type
1563+
return self.tensor_ext_type.Deserialize(storage_type, serialized)
16211564

16221565
def __arrow_ext_class__(self):
16231566
return FixedShapeTensorArray
@@ -4672,6 +4615,100 @@ def run_end_encoded(run_end_type, value_type):
46724615
return pyarrow_wrap_data_type(ree_type)
46734616

46744617

4618+
def fixedshapetensor(DataType value_type, shape, dim_names=None, permutation=None):
4619+
"""
4620+
Create instance of fixed shape tensor extension type with shape and optional
4621+
names of tensor dimensions and indices of the desired ordering.
4622+
4623+
Parameters
4624+
----------
4625+
value_type : DataType
4626+
Data type of individual tensor elements.
4627+
shape : tuple
4628+
The physical shape of the contained tensors.
4629+
dim_names : tuple, default None
4630+
Explicit names to tensor dimensions.
4631+
permutation : tuple, default None
4632+
Indices of the desired ordering of the original dimensions.
4633+
4634+
Examples
4635+
--------
4636+
Create an instance of fixed shape tensor extension type:
4637+
4638+
>>> import pyarrow as pa
4639+
>>> tensor_type = pa.fixedshapetensor(pa.int32(), [2, 2])
4640+
>>> tensor_type
4641+
FixedShapeTensorType(extension<arrow.fixed_shape_tensor>)
4642+
4643+
Inspect the data type:
4644+
4645+
>>> tensor_type.value_type
4646+
DataType(int32)
4647+
>>> tensor_type.shape
4648+
[2, 2]
4649+
4650+
Create a table with fixed shape tensor extension array:
4651+
4652+
>>> arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]]
4653+
>>> storage = pa.array(arr, pa.list_(pa.int32(), 4))
4654+
>>> tensor = pa.ExtensionArray.from_storage(tensor_type, storage)
4655+
>>> pa.table([tensor], names=["tensor_array"])
4656+
pyarrow.Table
4657+
tensor_array: extension<arrow.fixed_shape_tensor>
4658+
----
4659+
tensor_array: [[[1,2,3,4],[10,20,30,40],[100,200,300,400]]]
4660+
4661+
Create an instance of fixed shape tensor extension type with names
4662+
of tensor dimensions:
4663+
4664+
>>> tensor_type = pa.fixedshapetensor(pa.int8(), (2, 2, 3),
4665+
... dim_names=['C', 'H', 'W'])
4666+
>>> tensor_type.dim_names
4667+
[b'C', b'H', b'W']
4668+
4669+
Create an instance of fixed shape tensor extension type with
4670+
permutation:
4671+
4672+
>>> tensor_type = pa.fixedshapetensor(pa.int8(), (2, 2, 3),
4673+
... permutation=[0, 2, 1])
4674+
>>> tensor_type.permutation
4675+
[0, 2, 1]
4676+
4677+
Returns
4678+
-------
4679+
type : FixedShapeTensorType
4680+
"""
4681+
4682+
cdef:
4683+
vector[int64_t] c_shape
4684+
vector[int64_t] c_permutation
4685+
vector[c_string] c_dim_names
4686+
shared_ptr[CDataType] c_tensor_ext_type
4687+
4688+
assert value_type is not None
4689+
assert shape is not None
4690+
4691+
for i in shape:
4692+
c_shape.push_back(i)
4693+
4694+
if permutation is not None:
4695+
for i in permutation:
4696+
c_permutation.push_back(i)
4697+
4698+
if dim_names is not None:
4699+
for x in dim_names:
4700+
c_dim_names.push_back(tobytes(x))
4701+
4702+
cdef FixedShapeTensorType out = FixedShapeTensorType.__new__(FixedShapeTensorType)
4703+
4704+
c_tensor_ext_type = GetResultValue(CFixedShapeTensorType.Make(
4705+
value_type.sp_type, c_shape, c_permutation, c_dim_names))
4706+
4707+
out.init(c_tensor_ext_type)
4708+
4709+
return out
4710+
4711+
46754712
cdef dict _type_aliases = {
46764713
'null': null,
46774714
'bool': bool_,

0 commit comments

Comments
 (0)