Commit 9fc3b0c

feat: add dictionary_columns to scan API for memory-efficient string reads
1 parent 1a54e9c

4 files changed

Lines changed: 270 additions & 3 deletions

pyiceberg/io/pyarrow.py

Lines changed: 20 additions & 1 deletion
@@ -1614,8 +1614,13 @@ def _task_to_record_batches(
     partition_spec: PartitionSpec | None = None,
     format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION,
     downcast_ns_timestamp_to_us: bool | None = None,
+    dictionary_columns: tuple[str, ...] | None = None,
 ) -> Iterator[pa.RecordBatch]:
-    arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
+    # Only pass dictionary_columns for Parquet — ORC does not support this kwarg.
+    format_kwargs: dict[str, Any] = {"pre_buffer": True, "buffer_size": ONE_MEGABYTE * 8}
+    if dictionary_columns and task.file.file_format == FileFormat.PARQUET:
+        format_kwargs["dictionary_columns"] = dictionary_columns
+    arrow_format = _get_file_format(task.file.file_format, **format_kwargs)
     with io.new_input(task.file.file_path).open() as fin:
         fragment = arrow_format.make_fragment(fin)
         physical_schema = fragment.physical_schema
@@ -1718,6 +1723,7 @@ class ArrowScan:
     _case_sensitive: bool
     _limit: int | None
     _downcast_ns_timestamp_to_us: bool | None
+    _dictionary_columns: tuple[str, ...] | None
     """Scan the Iceberg Table and create an Arrow construct.

     Attributes:
@@ -1737,6 +1743,8 @@ def __init__(
         row_filter: BooleanExpression,
         case_sensitive: bool = True,
         limit: int | None = None,
+        *,
+        dictionary_columns: tuple[str, ...] | None = None,
     ) -> None:
         self._table_metadata = table_metadata
         self._io = io
@@ -1745,6 +1753,7 @@ def __init__(
         self._case_sensitive = case_sensitive
         self._limit = limit
         self._downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE)
+        self._dictionary_columns = dictionary_columns

     @property
     def _projected_field_ids(self) -> set[int]:
@@ -1773,6 +1782,15 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
             ValueError: When a field type in the file cannot be projected to the schema type
         """
         arrow_schema = schema_to_pyarrow(self._projected_schema, include_field_ids=False)
+        if self._dictionary_columns:
+            dict_cols_set = set(self._dictionary_columns)
+            arrow_schema = pa.schema(
+                [
+                    field.with_type(pa.dictionary(pa.int32(), field.type)) if field.name in dict_cols_set else field
+                    for field in arrow_schema
+                ],
+                metadata=arrow_schema.metadata,
+            )

         batches = self.to_record_batches(tasks)
         try:
@@ -1855,6 +1873,7 @@ def _record_batches_from_scan_tasks_and_deletes(
                 self._table_metadata.specs().get(task.file.spec_id),
                 self._table_metadata.format_version,
                 self._downcast_ns_timestamp_to_us,
+                dictionary_columns=self._dictionary_columns,
             )
             for batch in batches:
                 if self._limit is not None:

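The new kwarg is forwarded to PyArrow's Parquet dataset format via _get_file_format. For orientation, here is a minimal standalone sketch of the same PyArrow mechanism, independent of PyIceberg; the file path and column name are made-up examples, and the kwarg routing is assumed to match pyarrow.dataset.ParquetFileFormat, which splits keyword arguments between its read options and fragment scan options:

import pyarrow.dataset as ds

# Hypothetical local Parquet file and column name, for illustration only.
# dictionary_columns is routed into ParquetReadOptions; pre_buffer into the scan options.
parquet_format = ds.ParquetFileFormat(dictionary_columns={"json_col"}, pre_buffer=True)
dataset = ds.dataset("example.parquet", format=parquet_format)
table = dataset.to_table()
print(table.schema.field("json_col").type)  # a dictionary type rather than plain string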
pyiceberg/table/__init__.py

Lines changed: 69 additions & 2 deletions
@@ -1121,6 +1121,7 @@ def scan(
         snapshot_id: int | None = None,
         options: Properties = EMPTY_DICT,
         limit: int | None = None,
+        dictionary_columns: tuple[str, ...] | None = None,
     ) -> DataScan:
         """Fetch a DataScan based on the table's current metadata.

@@ -1147,6 +1148,13 @@ def scan(
                 An integer representing the number of rows to
                 return in the scan result. If None, fetches all
                 matching rows.
+            dictionary_columns:
+                A tuple of column names that PyArrow should read as
+                dictionary-encoded (DictionaryArray). Reduces memory
+                usage for columns with large or repeated string values
+                (e.g. large JSON blobs). Only applies to Parquet files;
+                silently ignored for ORC. Columns absent from the file
+                are silently skipped. Default is None (no dictionary encoding).

         Returns:
             A DataScan based on the table's current metadata.
@@ -1162,6 +1170,7 @@ def scan(
             limit=limit,
             catalog=self.catalog,
             table_identifier=self._identifier,
+            dictionary_columns=dictionary_columns,
         )

     @property
@@ -1664,6 +1673,7 @@ def scan(
         snapshot_id: int | None = None,
         options: Properties = EMPTY_DICT,
         limit: int | None = None,
+        dictionary_columns: tuple[str, ...] | None = None,
     ) -> DataScan:
         raise ValueError("Cannot scan a staged table")

@@ -1916,6 +1926,36 @@ def _min_sequence_number(manifests: list[ManifestFile]) -> int:


 class DataScan(TableScan):
+    dictionary_columns: tuple[str, ...] | None
+
+    def __init__(
+        self,
+        table_metadata: TableMetadata,
+        io: FileIO,
+        row_filter: str | BooleanExpression = ALWAYS_TRUE,
+        selected_fields: tuple[str, ...] = ("*",),
+        case_sensitive: bool = True,
+        snapshot_id: int | None = None,
+        options: Properties = EMPTY_DICT,
+        limit: int | None = None,
+        catalog: Catalog | None = None,
+        table_identifier: Identifier | None = None,
+        dictionary_columns: tuple[str, ...] | None = None,
+    ) -> None:
+        super().__init__(
+            table_metadata=table_metadata,
+            io=io,
+            row_filter=row_filter,
+            selected_fields=selected_fields,
+            case_sensitive=case_sensitive,
+            snapshot_id=snapshot_id,
+            options=options,
+            limit=limit,
+            catalog=catalog,
+            table_identifier=table_identifier,
+        )
+        self.dictionary_columns = dictionary_columns
+
     def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
         project = inclusive_projection(self.table_metadata.schema(), self.table_metadata.specs()[spec_id], self.case_sensitive)
         return project(self.row_filter)
@@ -2113,7 +2153,13 @@ def to_arrow(self) -> pa.Table:
         from pyiceberg.io.pyarrow import ArrowScan

         return ArrowScan(
-            self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
+            self.table_metadata,
+            self.io,
+            self.projection(),
+            self.row_filter,
+            self.case_sensitive,
+            self.limit,
+            dictionary_columns=self.dictionary_columns,
         ).to_table(self.plan_files())

     def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
@@ -2132,8 +2178,29 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
         from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow

         target_schema = schema_to_pyarrow(self.projection())
+
+        # When dictionary_columns is set, PyArrow returns DictionaryArray for those columns.
+        # target_schema uses plain string types, so .cast(target_schema) would silently decode
+        # them back to plain strings. Rebuild target_schema with dictionary types for the listed
+        # columns so from_batches and cast both preserve the encoding.
+        if self.dictionary_columns:
+            dict_cols_set = set(self.dictionary_columns)
+            target_schema = pa.schema(
+                [
+                    field.with_type(pa.dictionary(pa.int32(), field.type)) if field.name in dict_cols_set else field
+                    for field in target_schema
+                ],
+                metadata=target_schema.metadata,
+            )
+
         batches = ArrowScan(
-            self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
+            self.table_metadata,
+            self.io,
+            self.projection(),
+            self.row_filter,
+            self.case_sensitive,
+            self.limit,
+            dictionary_columns=self.dictionary_columns,
         ).to_record_batches(self.plan_files())

         return pa.RecordBatchReader.from_batches(

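Taken together, the user-facing change is the new dictionary_columns keyword on Table.scan. A hedged usage sketch follows; the catalog, table, and column names are hypothetical, and only the dictionary_columns keyword comes from this commit:

from pyiceberg.catalog import load_catalog

# Hypothetical catalog/table/column names, for illustration only.
catalog = load_catalog("default")
table = catalog.load_table("db.events")

scan = table.scan(
    selected_fields=("id", "json_col"),
    dictionary_columns=("json_col",),  # read json_col as a DictionaryArray
)

arrow_table = scan.to_arrow()
print(arrow_table.schema.field("json_col").type)  # a dictionary type, not plain string

# The batch-reader path preserves the encoding as well (see the target_schema rebuild above).
for batch in scan.to_arrow_batch_reader():
    assert batch.schema.field("json_col").type == arrow_table.schema.field("json_col").type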
tests/io/test_pyarrow.py

Lines changed: 162 additions & 0 deletions
@@ -3152,6 +3152,168 @@ def _expected_batch(unit: str) -> pa.RecordBatch:
     assert _expected_batch("ns" if format_version > 2 else "us").equals(actual_result)


+def test_task_to_record_batches_dictionary_columns(tmpdir: str) -> None:
+    """dictionary_columns causes the column to be read as DictionaryArray, saving memory."""
+    arrow_table = pa.table(
+        {"json_col": pa.array(["large-json-1", "large-json-2", "large-json-1"], type=pa.string())},
+        schema=pa.schema([pa.field("json_col", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]),
+    )
+    data_file = _write_table_to_data_file(f"{tmpdir}/test_dictionary_columns.parquet", arrow_table.schema, arrow_table)
+    table_schema = Schema(NestedField(1, "json_col", StringType(), required=False))
+
+    batches = list(
+        _task_to_record_batches(
+            PyArrowFileIO(),
+            FileScanTask(data_file),
+            bound_row_filter=AlwaysTrue(),
+            projected_schema=table_schema,
+            table_schema=table_schema,
+            projected_field_ids={1},
+            positional_deletes=None,
+            case_sensitive=True,
+            dictionary_columns=("json_col",),
+        )
+    )
+
+    assert len(batches) == 1, "Expected exactly one record batch"
+    col = batches[0].column("json_col")
+    assert pa.types.is_dictionary(col.type), (
+        f"Expected DictionaryArray for 'json_col' when dictionary_columns is set, got {col.type}"
+    )
+
+
+def test_task_to_record_batches_no_dictionary_columns_by_default(tmpdir: str) -> None:
+    """Without dictionary_columns, string columns are returned as plain StringArray — default unchanged."""
+    arrow_table = pa.table(
+        {"json_col": pa.array(["a", "b", "c"], type=pa.string())},
+        schema=pa.schema([pa.field("json_col", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]),
+    )
+    data_file = _write_table_to_data_file(f"{tmpdir}/test_no_dictionary_default.parquet", arrow_table.schema, arrow_table)
+    table_schema = Schema(NestedField(1, "json_col", StringType(), required=False))
+
+    batches = list(
+        _task_to_record_batches(
+            PyArrowFileIO(),
+            FileScanTask(data_file),
+            bound_row_filter=AlwaysTrue(),
+            projected_schema=table_schema,
+            table_schema=table_schema,
+            projected_field_ids={1},
+            positional_deletes=None,
+            case_sensitive=True,
+            # dictionary_columns intentionally omitted — must not change behavior
+        )
+    )
+
+    assert len(batches) == 1, "Expected exactly one record batch"
+    col = batches[0].column("json_col")
+    assert not pa.types.is_dictionary(col.type), f"Expected plain StringArray by default, got {col.type}"
+
+
+def test_arrow_scan_to_table_with_dictionary_columns(tmpdir: str) -> None:
+    """ArrowScan.to_table() with dictionary_columns: named column is DictionaryArray, others are not."""
+    import pyarrow.parquet as pq
+
+    arrow_schema = pa.schema(
+        [
+            pa.field("id", pa.int32(), metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"}),
+            pa.field("json_col", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "2"}),
+        ]
+    )
+    arrow_table = pa.table(
+        {
+            "id": pa.array([1, 2, 3], type=pa.int32()),
+            "json_col": pa.array(['{"x": 1}', '{"x": 2}', '{"x": 1}'], type=pa.string()),
+        },
+        schema=arrow_schema,
+    )
+    filepath = f"{tmpdir}/test_e2e_dictionary.parquet"
+    with pq.ParquetWriter(filepath, arrow_schema) as writer:
+        writer.write_table(arrow_table)
+
+    iceberg_schema = Schema(
+        NestedField(1, "id", IntegerType(), required=False),
+        NestedField(2, "json_col", StringType(), required=False),
+    )
+    data_file = DataFile.from_args(
+        content=DataFileContent.DATA,
+        file_path=filepath,
+        file_format=FileFormat.PARQUET,
+        partition={},
+        record_count=3,
+        file_size_in_bytes=100,
+    )
+    data_file.spec_id = 0
+
+    result = ArrowScan(
+        TableMetadataV2(
+            location="file://a/b/",
+            last_column_id=2,
+            format_version=2,
+            schemas=[iceberg_schema],
+            partition_specs=[PartitionSpec()],
+        ),
+        PyArrowFileIO(),
+        iceberg_schema,
+        AlwaysTrue(),
+        dictionary_columns=("json_col",),
+    ).to_table(tasks=[FileScanTask(data_file)])
+
+    assert pa.types.is_dictionary(result.schema.field("json_col").type), (
+        f"Expected DictionaryArray for 'json_col', got {result.schema.field('json_col').type}"
+    )
+    assert not pa.types.is_dictionary(result.schema.field("id").type), "Non-listed column 'id' should NOT be dictionary-encoded"
+
+
+def test_arrow_scan_to_record_batches_preserves_dictionary_encoding(tmpdir: str) -> None:
+    """ArrowScan.to_record_batches() must preserve DictionaryArray — not decode back to plain string."""
+    import pyarrow.parquet as pq
+
+    arrow_schema = pa.schema(
+        [
+            pa.field("json_col", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"}),
+        ]
+    )
+    arrow_table = pa.table(
+        {"json_col": pa.array(['{"a": 1}', '{"b": 2}'], type=pa.string())},
+        schema=arrow_schema,
+    )
+    filepath = f"{tmpdir}/test_batch_reader_dict.parquet"
+    with pq.ParquetWriter(filepath, arrow_schema) as writer:
+        writer.write_table(arrow_table)
+
+    iceberg_schema = Schema(NestedField(1, "json_col", StringType(), required=False))
+    data_file = DataFile.from_args(
+        content=DataFileContent.DATA,
+        file_path=filepath,
+        file_format=FileFormat.PARQUET,
+        partition={},
+        record_count=2,
+        file_size_in_bytes=100,
+    )
+    data_file.spec_id = 0
+
+    batches = list(
+        ArrowScan(
+            TableMetadataV2(
+                location="file://a/b/",
+                last_column_id=1,
+                format_version=2,
+                schemas=[iceberg_schema],
+                partition_specs=[PartitionSpec()],
+            ),
+            PyArrowFileIO(),
+            iceberg_schema,
+            AlwaysTrue(),
+            dictionary_columns=("json_col",),
+        ).to_record_batches(tasks=[FileScanTask(data_file)])
+    )
+
+    assert len(batches) >= 1, "Expected at least one record batch"
+    col = batches[0].column("json_col")
+    assert pa.types.is_dictionary(col.type), f"DictionaryArray must be preserved through to_record_batches, got {col.type}"
+
+
 def test_parse_location_defaults() -> None:
     """Test that parse_location uses defaults."""

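The memory-saving claim these tests reference can be sanity-checked with plain PyArrow. A small illustrative sketch on synthetic data (not taken from the commit), comparing a repetitive string column with its dictionary-encoded form:

import pyarrow as pa

# Synthetic, highly repetitive values standing in for large JSON blobs.
blob_a = '{"config": "' + "x" * 1000 + '"}'
blob_b = '{"config": "' + "y" * 1000 + '"}'
plain = pa.array([blob_a, blob_b] * 5_000, type=pa.string())
encoded = plain.dictionary_encode()  # dictionary<values=string, indices=int32>

print(plain.nbytes)    # roughly 10 MB: every row stores a full copy of a blob
print(encoded.nbytes)  # tens of KB: two unique blobs plus a 4-byte index per row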
tests/table/test_init.py

Lines changed: 19 additions & 0 deletions
@@ -274,6 +274,25 @@ def test_table_scan_select(table_fixture: Table) -> None:
     assert scan.select("a", "c").select("a").selected_fields == ("a",)


+def test_table_scan_dictionary_columns_default(table_v2: Table) -> None:
+    scan = table_v2.scan()
+    assert scan.dictionary_columns is None, "dictionary_columns should default to None"
+
+
+def test_table_scan_dictionary_columns_set(table_v2: Table) -> None:
+    scan = table_v2.scan(dictionary_columns=("json_col", "other_col"))
+    assert scan.dictionary_columns == ("json_col", "other_col"), "dictionary_columns should be stored on the scan"
+
+
+def test_table_scan_dictionary_columns_preserved_on_update(table_v2: Table) -> None:
+    scan = table_v2.scan(dictionary_columns=("json_col",))
+    updated = scan.update(limit=10)
+    assert updated.dictionary_columns == ("json_col",), (
+        "dictionary_columns must survive .update() — TableScan.update() uses inspect.signature "
+        "so DataScan.__init__ must declare and store it"
+    )
+
+
 def test_table_scan_row_filter(table_v2: Table) -> None:
     scan = table_v2.scan()
     assert scan.row_filter == AlwaysTrue()

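The last test above pins down why DataScan.__init__ must declare and store the new parameter: TableScan.update() rebuilds the scan by inspecting the constructor signature and re-reading each parameter from the instance, so an option that is not a constructor parameter with a matching attribute would be silently dropped. A rough sketch of that pattern follows; this is the general shape, not the actual PyIceberg implementation:

import inspect


class Scan:
    def __init__(self, limit: int | None = None, dictionary_columns: tuple[str, ...] | None = None) -> None:
        self.limit = limit
        self.dictionary_columns = dictionary_columns

    def update(self, **overrides):
        # Re-read every constructor parameter from the instance, then apply the overrides.
        params = inspect.signature(type(self).__init__).parameters
        kwargs = {name: getattr(self, name) for name in params if name != "self"}
        kwargs.update(overrides)
        return type(self)(**kwargs)


scan = Scan(dictionary_columns=("json_col",))
assert scan.update(limit=10).dictionary_columns == ("json_col",)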