Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .ai/skills/check-upstream/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,17 @@ The user may specify an area via `$ARGUMENTS`. If no area is specified or "all"
- Python API: `python/datafusion/functions.py` — each function wraps a call to `datafusion._internal.functions`
- Rust bindings: `crates/core/src/functions.rs` — `#[pyfunction]` definitions registered via `init_module()`

**Evaluated and not requiring separate Python exposure:**
- `get_field_path` — already covered by `get_field(expr, *names)`, which takes a
variadic field path and dispatches to the same underlying
`functions::core::get_field` UDF as the upstream `get_field_path` helper.

**How to check:**
1. Fetch the upstream scalar function documentation page
2. Compare against functions listed in `python/datafusion/functions.py` (check the `__all__` list and function definitions)
3. A function is covered if it exists in the Python API — it does NOT need a dedicated Rust `#[pyfunction]`. Many functions are aliases that reuse another function's Rust binding.
4. Only report functions that are missing from the Python `__all__` list / function definitions
4. Check against the "evaluated and not requiring exposure" list before flagging as a gap
5. Only report functions that are missing from the Python `__all__` list / function definitions

### 2. Aggregate Functions

Expand Down
7 changes: 7 additions & 0 deletions crates/core/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,13 @@ impl PySessionContext {
Ok(())
}

pub fn read_batches(
&self,
batches: PyArrowType<Vec<RecordBatch>>,
) -> PyDataFusionResult<PyDataFrame> {
Ok(PyDataFrame::new(self.ctx.read_batches(batches.0)?))
}

#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (name, path, table_partition_cols=vec![],
parquet_pruning=true,
Expand Down
8 changes: 4 additions & 4 deletions crates/core/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,10 +574,10 @@ expr_fn!(union_tag, arg1);
expr_fn!(random);

#[pyfunction]
fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr {
functions::core::get_field()
.call(vec![expr.into(), name.into()])
.into()
fn get_field(expr: PyExpr, names: Vec<PyExpr>) -> PyExpr {
let mut args = vec![expr.into()];
args.extend(names.into_iter().map(Into::into));
functions::core::get_field().call(args).into()
}

#[pyfunction]
Expand Down
5 changes: 2 additions & 3 deletions examples/datafusion-ffi-example/src/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@

use std::sync::Arc;

use datafusion_catalog::{TableFunctionImpl, TableProvider};
use datafusion_catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider};
use datafusion_common::error::Result as DataFusionResult;
use datafusion_expr::Expr;
use datafusion_ffi::udtf::FFI_TableFunction;
use datafusion_python_util::ffi_logical_codec_from_pycapsule;
use pyo3::types::PyCapsule;
Expand Down Expand Up @@ -59,7 +58,7 @@ impl MyTableFunction {
}

impl TableFunctionImpl for MyTableFunction {
fn call(&self, _args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
fn call_with_args(&self, _args: TableFunctionArgs) -> DataFusionResult<Arc<dyn TableProvider>> {
let provider = MyTableProvider::new(4, 3, 2).create_table()?;
Ok(Arc::new(provider))
}
Expand Down
58 changes: 54 additions & 4 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,16 @@

import pandas as pd
import polars as pl # type: ignore[import]
from _typeshed import CapsuleType as _PyCapsule

from datafusion.catalog import CatalogProvider, Table
from datafusion.common import DFSchema
from datafusion.expr import Expr, SortKey
from datafusion.plan import ExecutionPlan, LogicalPlan
from datafusion.user_defined import (
AggregateUDF,
LogicalExtensionCodecExportable,
PhysicalExtensionCodecExportable,
ScalarUDF,
TableFunction,
WindowUDF,
Expand Down Expand Up @@ -959,6 +962,45 @@ def register_record_batches(
"""
self.ctx.register_record_batches(name, partitions)

def read_batch(self, batch: pa.RecordBatch) -> DataFrame:
Comment thread
timsaucer marked this conversation as resolved.
"""Return a :py:class:`~datafusion.DataFrame` reading a single batch.

Convenience wrapper around :py:meth:`read_batches` for the single-batch
case. Unlike :py:meth:`register_batch`, this does not register the
batch as a named table; it returns an anonymous
:py:class:`~datafusion.DataFrame` directly.

Args:
batch: Record batch to wrap as a DataFrame.

Examples:
>>> ctx = dfn.SessionContext()
>>> batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
>>> ctx.read_batch(batch).to_pydict()
{'a': [1, 2, 3]}
"""
return self.read_batches([batch])

def read_batches(self, batches: list[pa.RecordBatch]) -> DataFrame:
"""Return a :py:class:`~datafusion.DataFrame` reading the given batches.

All batches must share the same schema. Unlike
:py:meth:`register_record_batches`, this does not register the batches
as a named table; it returns an anonymous
:py:class:`~datafusion.DataFrame` directly.

Args:
batches: Record batches to wrap as a DataFrame.

Examples:
>>> ctx = dfn.SessionContext()
>>> b1 = pa.RecordBatch.from_pydict({"a": [1, 2]})
>>> b2 = pa.RecordBatch.from_pydict({"a": [3, 4]})
>>> ctx.read_batches([b1, b2]).to_pydict()
{'a': [1, 2, 3, 4]}
"""
return DataFrame(self.ctx.read_batches(batches))

def register_parquet(
self,
name: str,
Expand Down Expand Up @@ -1744,11 +1786,15 @@ def __datafusion_logical_extension_codec__(self) -> Any:
"""Access the PyCapsule FFI_LogicalExtensionCodec."""
return self.ctx.__datafusion_logical_extension_codec__()

def with_logical_extension_codec(self, codec: Any) -> SessionContext:
def with_logical_extension_codec(
self, codec: LogicalExtensionCodecExportable | _PyCapsule
) -> SessionContext:
"""Create a new session context with specified codec.

This only supports codecs that have been implemented using the
FFI interface.
FFI interface. ``codec`` must either be a raw ``FFI_LogicalExtensionCodec``
``PyCapsule`` or an object exposing
``__datafusion_logical_extension_codec__``.
Comment thread
timsaucer marked this conversation as resolved.
Outdated
"""
new_internal = self.ctx.with_logical_extension_codec(codec)
new = SessionContext.__new__(SessionContext)
Expand All @@ -1759,11 +1805,15 @@ def __datafusion_physical_extension_codec__(self) -> Any:
"""Access the PyCapsule FFI_PhysicalExtensionCodec."""
return self.ctx.__datafusion_physical_extension_codec__()

def with_physical_extension_codec(self, codec: Any) -> SessionContext:
def with_physical_extension_codec(
self, codec: PhysicalExtensionCodecExportable | _PyCapsule
) -> SessionContext:
"""Create a new session context with the specified physical codec.

This only supports codecs that have been implemented using the
FFI interface.
FFI interface. ``codec`` must either be a raw
Comment thread
timsaucer marked this conversation as resolved.
Outdated
``FFI_PhysicalExtensionCodec`` ``PyCapsule`` or an object exposing
``__datafusion_physical_extension_codec__``.
"""
new_internal = self.ctx.with_physical_extension_codec(codec)
new = SessionContext.__new__(SessionContext)
Expand Down
42 changes: 34 additions & 8 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2727,14 +2727,24 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
return Expr(f.arrow_metadata(expr.expr, key.expr))


def get_field(expr: Expr, name: Expr | str) -> Expr:
"""Extracts a field from a struct or map by name.
def get_field(expr: Expr, *names: Expr | str) -> Expr:
Comment thread
timsaucer marked this conversation as resolved.
"""Extracts a (possibly nested) field from a struct or map by name.

When the field name is a static string, the bracket operator
``expr["field"]`` is a convenient shorthand. Use ``get_field``
when the field name is a dynamic expression.
Pass one name for a single-level lookup, or several names to walk a path
of nested struct/map fields in a single ``get_field`` call. For a single
static-string name, ``expr["field"]`` is a convenient shorthand; use
``get_field`` when the field name is a dynamic
:py:class:`~datafusion.expr.Expr` or when traversing multiple levels at
once.

Args:
expr: The struct or map expression to read from.
*names: One or more field names (``str``) or expressions
(:py:class:`~datafusion.expr.Expr`).

Examples:
Single-level lookup:

>>> ctx = dfn.SessionContext()
>>> df = ctx.from_pydict({"a": [1], "b": [2]})
>>> df = df.with_column(
Expand All @@ -2756,10 +2766,26 @@ def get_field(expr: Expr, name: Expr | str) -> Expr:
... )
>>> result.collect_column("x_val")[0].as_py()
1

Multi-level lookup:

>>> df = df.with_column(
... "outer",
... dfn.functions.named_struct([("inner", dfn.col("s"))]),
Comment thread
timsaucer marked this conversation as resolved.
Outdated
... )
>>> result = df.select(
... dfn.functions.get_field(
... dfn.col("outer"), "inner", "x"
... ).alias("x_val")
... )
>>> result.collect_column("x_val")[0].as_py()
1
"""
if isinstance(name, str):
name = Expr.string_literal(name)
return Expr(f.get_field(expr.expr, name.expr))
if not names:
msg = "get_field requires at least one field name"
raise ValueError(msg)
resolved = [Expr.string_literal(n) if isinstance(n, str) else n for n in names]
return Expr(f.get_field(expr.expr, [n.expr for n in resolved]))


def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr:
Expand Down
12 changes: 12 additions & 0 deletions python/datafusion/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]:
return value.__class__.__name__ == "PyCapsule"


class LogicalExtensionCodecExportable(Protocol):
"""Type hint for objects exposing ``__datafusion_logical_extension_codec__``."""

def __datafusion_logical_extension_codec__(self) -> object: ... # noqa: D105


class PhysicalExtensionCodecExportable(Protocol):
"""Type hint for objects exposing ``__datafusion_physical_extension_codec__``."""

def __datafusion_physical_extension_codec__(self) -> object: ... # noqa: D105


class ScalarUDF:
"""Class for performing scalar user-defined functions (UDF).

Expand Down
15 changes: 15 additions & 0 deletions python/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,21 @@ def test_register_batch_empty(ctx):
assert result[0].num_rows == 0


def test_read_batch_returns_dataframe(ctx):
batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
df = ctx.read_batch(batch)
assert df.to_pydict() == {"a": [1, 2, 3], "b": [4, 5, 6]}
# read_batch should not register a named table.
assert ctx.catalog().schema().names() == set()


def test_read_batches_concatenates(ctx):
b1 = pa.RecordBatch.from_pydict({"a": [1, 2]})
b2 = pa.RecordBatch.from_pydict({"a": [3, 4]})
df = ctx.read_batches([b1, b2])
assert df.to_pydict() == {"a": [1, 2, 3, 4]}


def test_create_sql_options():
SQLOptions()

Expand Down
31 changes: 31 additions & 0 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1957,6 +1957,37 @@ def test_get_field(df):
assert result.column(1) == pa.array([4, 5, 6])


def test_get_field_path(df):
df = df.with_column(
"outer",
f.named_struct(
[
(
"inner",
f.named_struct(
[
("x", column("a")),
("y", column("b")),
]
),
),
]
),
)
result = df.select(
f.get_field(column("outer"), "inner", "x").alias("x_val"),
f.get_field(column("outer"), "inner", "y").alias("y_val"),
).collect()[0]

assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view())
assert result.column(1) == pa.array([4, 5, 6])


def test_get_field_requires_a_name():
with pytest.raises(ValueError, match="at least one field name"):
f.get_field(column("s"))


def test_arrow_metadata():
ctx = SessionContext()
field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"})
Expand Down