@@ -1499,78 +1499,23 @@ cdef class FixedShapeTensorType(BaseExtensionType):
14991499 """
15001500 Concrete class for fixed shape tensor extension type.
15011501
1502- Parameters
1503- ----------
1504- value_type : DataType
1505- Data type of individual tensor elements.
1506- shape : tuple
1507- The physical shape of the contained tensors.
1508- dim_names : tuple
1509- Explicit names to tensor dimensions.
1510- permutation : tuple
1511- Indices of the desired ordering of the original dimensions.
1512-
15131502 Examples
15141503 --------
1515- >>> import pyarrow as pa
1516-
1517- Create fixed shape tensor extension type:
1504+ Create an instance of fixed shape tensor extension type:
15181505
1519- >>> tensor_type = pa.FixedShapeTensorType(pa.int32(), [2, 2])
1520- >>> tensor_type
1506+ >>> import pyarrow as pa
1507+ >>> pa.fixedshapetensor(pa.int32(), [2, 2])
15211508 FixedShapeTensorType(extension<arrow.fixed_shape_tensor>)
15221509
1523- Inspect the data type:
1524-
1525- >>> tensor_type.value_type
1526- DataType(int32)
1527- >>> tensor_type.shape
1528- [2, 2]
1529-
1530- Create a fixed shape tensor extension type with names of tensor dimensions:
1531-
1532- >>> tensor_type = pa.FixedShapeTensorType(pa.int8(), (2, 2, 3), dim_names=['C', 'H', 'W'])
1533- >>> tensor_type.dim_names
1534- [b'C', b'H', b'W']
1510+ Create an instance of fixed shape tensor extension type with
1511+ permutation:
15351512
1536- Create a fixed shape tensor extension type with permutation:
1537- >>> tensor_type = pa.FixedShapeTensorType(pa.int8(), (2, 2, 3), permutation=[0, 2, 1])
1513+ >>> tensor_type = pa.fixedshapetensor(pa.int8(), (2, 2, 3),
1514+ ... permutation=[0, 2, 1])
15381515 >>> tensor_type.permutation
15391516 [0, 2, 1]
15401517 """
15411518
1542- def __init__ (self , DataType value_type , shape , dim_names = None , permutation = None ):
1543- """
1544- Initialize an fixed shape tensor extension type instance.
1545-
1546- This should be called at the end of the subclass'
1547- ``__init__`` method.
1548- """
1549- cdef:
1550- vector[int64_t] c_shape
1551- vector[int64_t] c_permutation
1552- vector[c_string] c_dim_names
1553- shared_ptr[CDataType] tensor_ext_type
1554-
1555- assert value_type is not None
1556- assert shape is not None
1557-
1558- for i in shape:
1559- c_shape.push_back(i)
1560-
1561- if permutation is not None :
1562- for i in permutation:
1563- c_permutation.push_back(i)
1564-
1565- if dim_names is not None :
1566- for x in dim_names:
1567- c_dim_names.push_back(tobytes(x))
1568-
1569- tensor_ext_type = GetResultValue(CFixedShapeTensorType.Make(
1570- value_type.sp_type, c_shape, c_permutation, c_dim_names))
1571-
1572- self .init(tensor_ext_type)
1573-
15741519 cdef void init(self , const shared_ptr[CDataType]& type ) except * :
15751520 BaseExtensionType.init(self , type )
15761521 self .tensor_ext_type = < const CFixedShapeTensorType* > type .get()
@@ -1607,17 +1552,15 @@ cdef class FixedShapeTensorType(BaseExtensionType):
16071552 """
16081553 Serialized representation of metadata to reconstruct the type object.
16091554 """
1610- metadata = self .tensor_ext_type.Serialize()
1611- return metadata
1555+ return self .tensor_ext_type.Serialize()
16121556
16131557 @classmethod
16141558 def __arrow_ext_deserialize__ (self , storage_type , serialized ):
16151559 """
16161560 Return an FixedShapeTensor type instance from the storage type and serialized
16171561 metadata.
16181562 """
1619- tensor_ext_type = self .tensor_ext_type.Deserialize(storage_type, serialized)
1620- return tensor_ext_type
1563+ return self .tensor_ext_type.Deserialize(storage_type, serialized)
16211564
16221565 def __arrow_ext_class__ (self ):
16231566 return FixedShapeTensorArray
@@ -4672,6 +4615,100 @@ def run_end_encoded(run_end_type, value_type):
46724615 return pyarrow_wrap_data_type(ree_type)
46734616
46744617
4618+ def fixedshapetensor (DataType value_type , shape , dim_names = None , permutation = None ):
4619+ """
4620+ Create instance of fixed shape tensor extension type with shape and optional
4621+ names of tensor dimensions and indices of the desired ordering.
4622+
4623+ Parameters
4624+ ----------
4625+ value_type : DataType
4626+ Data type of individual tensor elements.
4627+ shape : tuple
4628+ The physical shape of the contained tensors.
4629+ dim_names : tuple, default None
4630+ Explicit names to tensor dimensions.
4631+ permutation : tuple, default None
4632+ Indices of the desired ordering of the original dimensions.
4633+
4634+ Examples
4635+ --------
4636+ Create an instance of fixed shape tensor extension type:
4637+
4638+ >>> import pyarrow as pa
4639+ >>> tensor_type = pa.fixedshapetensor(pa.int32(), [2, 2])
4640+ >>> tensor_type
4641+ FixedShapeTensorType(extension<arrow.fixed_shape_tensor>)
4642+
4643+ Inspect the data type:
4644+
4645+ >>> tensor_type.value_type
4646+ DataType(int32)
4647+ >>> tensor_type.shape
4648+ [2, 2]
4649+
4650+ Create a table with fixed shape tensor extension array:
4651+
4652+ >>> arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]]
4653+ >>> storage = pa.array(arr, pa.list_(pa.int32(), 4))
4654+ >>> tensor = pa.ExtensionArray.from_storage(tensor_type, storage)
4655+ >>> pa.table([tensor], names=["tensor_array"])
4656+ pyarrow.Table
4657+ tensor_array: extension<arrow.fixed_shape_tensor>
4658+ ----
4659+ tensor_array: [[[1,2,3,4],[10,20,30,40],[100,200,300,400]]]
4660+
4661+ Create an instance of fixed shape tensor extension type with names
4662+ of tensor dimensions:
4663+
4664+ >>> tensor_type = pa.fixedshapetensor(pa.int8(), (2, 2, 3),
4665+ ... dim_names=['C', 'H', 'W'])
4666+ >>> tensor_type.dim_names
4667+ [b'C', b'H', b'W']
4668+
4669+ Create an instance of fixed shape tensor extension type with
4670+ permutation:
4671+
4672+ >>> tensor_type = pa.fixedshapetensor(pa.int8(), (2, 2, 3),
4673+ ... permutation=[0, 2, 1])
4674+ >>> tensor_type.permutation
4675+ [0, 2, 1]
4676+
4677+ Returns
4678+ -------
4679+ type : FixedShapeTensorType
4680+ """
4681+
4682+ cdef:
4683+ vector[int64_t] c_shape
4684+ vector[int64_t] c_permutation
4685+ vector[c_string] c_dim_names
4686+ shared_ptr[CDataType] c_tensor_ext_type
4687+
4688+ assert value_type is not None
4689+ assert shape is not None
4690+
4691+ for i in shape:
4692+ c_shape.push_back(i)
4693+
4694+ if permutation is not None :
4695+ for i in permutation:
4696+ c_permutation.push_back(i)
4697+
4698+ if dim_names is not None :
4699+ for x in dim_names:
4700+ c_dim_names.push_back(tobytes(x))
4701+
4702+ cdef FixedShapeTensorType out = FixedShapeTensorType.__new__ (FixedShapeTensorType)
4703+
4704+ c_tensor_ext_type = GetResultValue(CFixedShapeTensorType.Make(
4705+ value_type.sp_type, c_shape, c_permutation, c_dim_names))
4706+
4707+ out.init(c_tensor_ext_type)
4708+
4709+ return out
4710+
4711+
46754712cdef dict _type_aliases = {
46764713 ' null' : null,
46774714 ' bool' : bool_,
0 commit comments