|
39 | 39 | from pydantic_core import ValidationError |
40 | 40 | from pyspark.sql import SparkSession |
41 | 41 | from pytest_mock.plugin import MockerFixture |
| 42 | +from sqlalchemy import Connection |
| 43 | +from sqlalchemy.sql.expression import text |
42 | 44 |
|
43 | 45 | from pyiceberg.catalog import Catalog, load_catalog |
44 | 46 | from pyiceberg.catalog.hive import HiveCatalog |
|
51 | 53 | from pyiceberg.table import TableProperties |
52 | 54 | from pyiceberg.table.refs import MAIN_BRANCH |
53 | 55 | from pyiceberg.table.sorting import SortDirection, SortField, SortOrder |
54 | | -from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform, Transform |
55 | | -from pyiceberg.types import ( |
56 | | - DateType, |
57 | | - DecimalType, |
58 | | - DoubleType, |
59 | | - IntegerType, |
60 | | - ListType, |
61 | | - LongType, |
62 | | - NestedField, |
63 | | - StringType, |
64 | | - UUIDType, |
65 | | -) |
| 56 | +from pyiceberg.transforms import BucketTransform, DayTransform, HourTransform, IdentityTransform, Transform |
| 57 | +from pyiceberg.types import DateType, DecimalType, DoubleType, IntegerType, ListType, LongType, NestedField, StringType, UUIDType |
66 | 58 | from utils import _create_table |
67 | 59 |
|
68 | 60 |
|
@@ -2014,6 +2006,7 @@ def test_read_write_decimals(session_catalog: Catalog) -> None: |
2014 | 2006 | assert tbl.scan().to_arrow() == arrow_table |
2015 | 2007 |
|
2016 | 2008 |
|
| 2009 | +@pytest.mark.skip("UUID BucketTransform is not supported in Spark Iceberg 1.9.2 yet") |
2017 | 2010 | @pytest.mark.integration |
2018 | 2011 | @pytest.mark.parametrize( |
2019 | 2012 | "transform", |
@@ -2067,6 +2060,64 @@ def test_uuid_partitioning(session_catalog: Catalog, spark: SparkSession, transf |
2067 | 2060 | assert lhs == rhs |
2068 | 2061 |
|
2069 | 2062 |
|
@pytest.mark.integration_trino
@pytest.mark.integration
@pytest.mark.parametrize(
    "transform",
    [IdentityTransform(), BucketTransform(32)],
)
@pytest.mark.parametrize(
    "catalog, trino_conn",
    [
        (pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("trino_hive_conn")),
        (pytest.lazy_fixture("session_catalog"), pytest.lazy_fixture("trino_rest_conn")),
    ],
)
def test_uuid_partitioning_with_trino(catalog: Catalog, trino_conn: Connection, transform: Transform) -> None:  # type: ignore
    """Round-trip a UUID-partitioned table: write with PyIceberg, read back via Trino.

    Appends two fixed UUIDs to a table partitioned on the ``uuid`` column and
    asserts that the rows Trino returns match what a PyIceberg scan yields.
    """
    # "bucket[32]" -> "bucket", "identity" stays as-is; used in both the table
    # identifier and the partition-field name.
    transform_label = str(transform).replace('[32]', '')
    identifier = f"default.test_uuid_partitioning_{transform_label}"

    schema = Schema(NestedField(field_id=1, name="uuid", field_type=UUIDType(), required=True))

    # Start from a clean slate; the table may linger from an earlier run.
    try:
        catalog.drop_table(identifier=identifier)
    except NoSuchTableError:
        pass

    partition_spec = PartitionSpec(
        PartitionField(source_id=1, field_id=1000, transform=transform, name=f"uuid_{transform_label}")
    )

    import pyarrow as pa

    # Uuid not yet supported, so we have to stick with `binary(16)`
    # https://github.com/apache/arrow/issues/46468
    arrow_schema = pa.schema([pa.field("uuid", pa.binary(16), nullable=False)])
    arr_table = pa.Table.from_pydict(
        {
            "uuid": [
                uuid.UUID("00000000-0000-0000-0000-000000000000").bytes,
                uuid.UUID("11111111-1111-1111-1111-111111111111").bytes,
            ],
        },
        schema=arrow_schema,
    )

    tbl = catalog.create_table(
        identifier=identifier,
        schema=schema,
        partition_spec=partition_spec,
    )
    tbl.append(arr_table)

    # Compare Trino's view of the table with PyIceberg's own scan.
    rows = trino_conn.execute(text(f"SELECT * FROM {identifier}")).fetchall()
    lhs = sorted(row[0] for row in rows)
    rhs = sorted(value.as_py() for value in tbl.scan().to_arrow()["uuid"].combine_chunks())
    assert lhs == rhs
| 2119 | + |
| 2120 | + |
2070 | 2121 | @pytest.mark.integration |
2071 | 2122 | def test_avro_compression_codecs(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
2072 | 2123 | identifier = "default.test_avro_compression_codecs" |
|
0 commit comments