Skip to content
8 changes: 7 additions & 1 deletion .ai/skills/check-upstream/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,17 @@ The user may specify an area via `$ARGUMENTS`. If no area is specified or "all"
- Python API: `python/datafusion/functions.py` — each function wraps a call to `datafusion._internal.functions`
- Rust bindings: `crates/core/src/functions.rs` — `#[pyfunction]` definitions registered via `init_module()`

**Evaluated and not requiring separate Python exposure:**
- `get_field_path` — already covered by `get_field(expr, *names)`, which takes a
variadic field path and dispatches to the same underlying
`functions::core::get_field` UDF as the upstream `get_field_path` helper.

**How to check:**
1. Fetch the upstream scalar function documentation page
2. Compare against functions listed in `python/datafusion/functions.py` (check the `__all__` list and function definitions)
3. A function is covered if it exists in the Python API — it does NOT need a dedicated Rust `#[pyfunction]`. Many functions are aliases that reuse another function's Rust binding.
4. Only report functions that are missing from the Python `__all__` list / function definitions
4. Check against the "Evaluated and not requiring separate Python exposure" list above before flagging a function as a gap
5. Only report functions that are missing from the Python `__all__` list / function definitions

### 2. Aggregate Functions

Expand Down
42 changes: 41 additions & 1 deletion crates/core/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use datafusion::datasource::listing::{
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::datasource::{MemTable, TableProvider};
use datafusion::execution::TaskContextProvider;
use datafusion::execution::context::{
DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
};
Expand All @@ -44,6 +43,7 @@ use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, Unboun
use datafusion::execution::options::{ArrowReadOptions, ReadOptions};
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::execution::{FunctionRegistry, TaskContextProvider};
use datafusion::prelude::{
AvroReadOptions, CsvReadOptions, DataFrame, JsonReadOptions, ParquetReadOptions,
};
Expand Down Expand Up @@ -847,6 +847,13 @@ impl PySessionContext {
Ok(())
}

pub fn read_batches(
&self,
batches: PyArrowType<Vec<RecordBatch>>,
) -> PyDataFusionResult<PyDataFrame> {
Ok(PyDataFrame::new(self.ctx.read_batches(batches.0)?))
}

#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (name, path, table_partition_cols=vec![],
parquet_pruning=true,
Expand Down Expand Up @@ -1065,6 +1072,39 @@ impl PySessionContext {
self.ctx.deregister_udwf(name);
}

/// Looks up the scalar UDF registered under `name`.
///
/// The function is cloned out of the handle returned by the context so the
/// wrapper owns its own copy. Lookup errors are propagated.
pub fn udf(&self, name: &str) -> PyDataFusionResult<PyScalarUDF> {
    let registered = self.ctx.udf(name)?;
    Ok(PyScalarUDF {
        function: (*registered).clone(),
    })
}

/// Looks up the aggregate UDF registered under `name`.
///
/// The function is cloned out of the handle returned by the context so the
/// wrapper owns its own copy. Lookup errors are propagated.
pub fn udaf(&self, name: &str) -> PyDataFusionResult<PyAggregateUDF> {
    let registered = self.ctx.udaf(name)?;
    Ok(PyAggregateUDF {
        function: (*registered).clone(),
    })
}

/// Looks up the window UDF registered under `name`.
///
/// The function is cloned out of the handle returned by the context so the
/// wrapper owns its own copy. Lookup errors are propagated.
pub fn udwf(&self, name: &str) -> PyDataFusionResult<PyWindowUDF> {
    let registered = self.ctx.udwf(name)?;
    Ok(PyWindowUDF {
        function: (*registered).clone(),
    })
}

/// Returns the names of the scalar UDFs reported by the context, sorted.
pub fn udfs(&self) -> Vec<String> {
    let mut sorted_names = self.ctx.udfs().into_iter().collect::<Vec<String>>();
    // Equal strings are indistinguishable, so an unstable sort yields the
    // same ordering as a stable one.
    sorted_names.sort_unstable();
    sorted_names
}

/// Returns the names of the aggregate UDFs reported by the context, sorted.
pub fn udafs(&self) -> Vec<String> {
    let mut sorted_names = self.ctx.udafs().into_iter().collect::<Vec<String>>();
    // Equal strings are indistinguishable, so an unstable sort yields the
    // same ordering as a stable one.
    sorted_names.sort_unstable();
    sorted_names
}

/// Returns the names of the window UDFs reported by the context, sorted.
pub fn udwfs(&self) -> Vec<String> {
    let mut sorted_names = self.ctx.udwfs().into_iter().collect::<Vec<String>>();
    // Equal strings are indistinguishable, so an unstable sort yields the
    // same ordering as a stable one.
    sorted_names.sort_unstable();
    sorted_names
}

#[pyo3(signature = (name="datafusion"))]
pub fn catalog(&self, py: Python, name: &str) -> PyResult<Py<PyAny>> {
let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(
Expand Down
8 changes: 4 additions & 4 deletions crates/core/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,10 +574,10 @@ expr_fn!(union_tag, arg1);
expr_fn!(random);

#[pyfunction]
fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr {
functions::core::get_field()
.call(vec![expr.into(), name.into()])
.into()
fn get_field(expr: PyExpr, names: Vec<PyExpr>) -> PyExpr {
let mut args = vec![expr.into()];
args.extend(names.into_iter().map(Into::into));
functions::core::get_field().call(args).into()
}

#[pyfunction]
Expand Down
5 changes: 2 additions & 3 deletions examples/datafusion-ffi-example/src/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@

use std::sync::Arc;

use datafusion_catalog::{TableFunctionImpl, TableProvider};
use datafusion_catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider};
use datafusion_common::error::Result as DataFusionResult;
use datafusion_expr::Expr;
use datafusion_ffi::udtf::FFI_TableFunction;
use datafusion_python_util::ffi_logical_codec_from_pycapsule;
use pyo3::types::PyCapsule;
Expand Down Expand Up @@ -59,7 +58,7 @@ impl MyTableFunction {
}

impl TableFunctionImpl for MyTableFunction {
fn call(&self, _args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
fn call_with_args(&self, _args: TableFunctionArgs) -> DataFusionResult<Arc<dyn TableProvider>> {
let provider = MyTableProvider::new(4, 3, 2).create_table()?;
Ok(Arc::new(provider))
}
Expand Down
117 changes: 113 additions & 4 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,16 @@

import pandas as pd
import polars as pl # type: ignore[import]
from _typeshed import CapsuleType as _PyCapsule

from datafusion.catalog import CatalogProvider, Table
from datafusion.common import DFSchema
from datafusion.expr import Expr, SortKey
from datafusion.plan import ExecutionPlan, LogicalPlan
from datafusion.user_defined import (
AggregateUDF,
LogicalExtensionCodecExportable,
PhysicalExtensionCodecExportable,
ScalarUDF,
TableFunction,
WindowUDF,
Expand Down Expand Up @@ -959,6 +962,45 @@ def register_record_batches(
"""
self.ctx.register_record_batches(name, partitions)

def read_batch(self, batch: pa.RecordBatch) -> DataFrame:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would consider it more Pythonic for `read_batches` to accept `RecordBatch | Iterable[RecordBatch]`.

"""Return a :py:class:`~datafusion.DataFrame` reading a single batch.

Convenience wrapper around :py:meth:`read_batches` for the single-batch
case. Unlike :py:meth:`register_batch`, this does not register the
batch as a named table; it returns an anonymous
:py:class:`~datafusion.DataFrame` directly.

Args:
batch: Record batch to wrap as a DataFrame.

Examples:
>>> ctx = dfn.SessionContext()
>>> batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
>>> ctx.read_batch(batch).to_pydict()
{'a': [1, 2, 3]}
"""
return self.read_batches([batch])

def read_batches(self, batches: list[pa.RecordBatch]) -> DataFrame:
    """Build a :py:class:`~datafusion.DataFrame` over the given record batches.

    The batches are passed to the internal context's ``read_batches`` and the
    result is wrapped in a :py:class:`~datafusion.DataFrame`. In contrast to
    :py:meth:`register_record_batches`, no table name is involved: the
    returned DataFrame is anonymous. All batches must share the same schema.

    Args:
        batches: Record batches to wrap as a DataFrame.

    Returns:
        A DataFrame reading the supplied batches.

    Examples:
        >>> ctx = dfn.SessionContext()
        >>> b1 = pa.RecordBatch.from_pydict({"a": [1, 2]})
        >>> b2 = pa.RecordBatch.from_pydict({"a": [3, 4]})
        >>> ctx.read_batches([b1, b2]).to_pydict()
        {'a': [1, 2, 3, 4]}
    """
    internal_df = self.ctx.read_batches(batches)
    return DataFrame(internal_df)

def register_parquet(
self,
name: str,
Expand Down Expand Up @@ -1268,6 +1310,65 @@ def deregister_udwf(self, name: str) -> None:
"""
self.ctx.deregister_udwf(name)

def udf(self, name: str) -> ScalarUDF:
"""Look up a registered scalar UDF by name.

Args:
name: Name of the registered scalar UDF.

Raises:
Exception: If no scalar UDF is registered under ``name``.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't recall if this is the convention across the code base but on quick look I'd expect to just return ScalarUDF | None

"""
from datafusion.user_defined import ScalarUDF as _ScalarUDF # noqa: PLC0415

wrapper = _ScalarUDF.__new__(_ScalarUDF)
wrapper._udf = self.ctx.udf(name)
return wrapper

def udaf(self, name: str) -> AggregateUDF:
    """Fetch a registered aggregate UDF by its name.

    Args:
        name: Name of the registered aggregate UDF.

    Returns:
        An :py:class:`AggregateUDF` wrapping the object returned by the
        internal context.

    Raises:
        Exception: If no aggregate UDF is registered under ``name``.
    """
    from datafusion.user_defined import (  # noqa: PLC0415
        AggregateUDF as _AggregateUDF,
    )

    internal = self.ctx.udaf(name)
    # Allocate the wrapper without running ``__init__`` and attach the
    # already-built internal object directly.
    result = _AggregateUDF.__new__(_AggregateUDF)
    result._udaf = internal
    return result

def udwf(self, name: str) -> WindowUDF:
    """Fetch a registered window UDF by its name.

    Args:
        name: Name of the registered window UDF.

    Returns:
        A :py:class:`WindowUDF` wrapping the object returned by the internal
        context.

    Raises:
        Exception: If no window UDF is registered under ``name``.
    """
    from datafusion.user_defined import WindowUDF as _WindowUDF  # noqa: PLC0415

    internal = self.ctx.udwf(name)
    # Allocate the wrapper without running ``__init__`` and attach the
    # already-built internal object directly.
    result = _WindowUDF.__new__(_WindowUDF)
    result._udwf = internal
    return result

def udfs(self) -> list[str]:
    """List the names of every registered scalar UDF, in sorted order.

    Returns:
        The name list exactly as produced by the internal context.
    """
    scalar_names = self.ctx.udfs()
    return scalar_names

def udafs(self) -> list[str]:
    """List the names of every registered aggregate UDF, in sorted order.

    Returns:
        The name list exactly as produced by the internal context.
    """
    aggregate_names = self.ctx.udafs()
    return aggregate_names

def udwfs(self) -> list[str]:
    """List the names of every registered window UDF, in sorted order.

    Returns:
        The name list exactly as produced by the internal context.
    """
    window_names = self.ctx.udwfs()
    return window_names

def catalog(self, name: str = "datafusion") -> Catalog:
"""Retrieve a catalog by name."""
return Catalog(self.ctx.catalog(name))
Expand Down Expand Up @@ -1744,11 +1845,15 @@ def __datafusion_logical_extension_codec__(self) -> Any:
"""Access the PyCapsule FFI_LogicalExtensionCodec."""
return self.ctx.__datafusion_logical_extension_codec__()

def with_logical_extension_codec(self, codec: Any) -> SessionContext:
def with_logical_extension_codec(
self, codec: LogicalExtensionCodecExportable | _PyCapsule
) -> SessionContext:
"""Create a new session context with specified codec.

This only supports codecs that have been implemented using the
FFI interface.
FFI interface. ``codec`` must either be a raw ``FFI_LogicalExtensionCodec``
``PyCapsule`` or an object exposing
``__datafusion_logical_extension_codec__``.
Comment on lines +1756 to +1758
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this addition? Isn't this redundant with the typing?

"""
new_internal = self.ctx.with_logical_extension_codec(codec)
new = SessionContext.__new__(SessionContext)
Expand All @@ -1759,11 +1864,15 @@ def __datafusion_physical_extension_codec__(self) -> Any:
"""Access the PyCapsule FFI_PhysicalExtensionCodec."""
return self.ctx.__datafusion_physical_extension_codec__()

def with_physical_extension_codec(self, codec: Any) -> SessionContext:
def with_physical_extension_codec(
self, codec: PhysicalExtensionCodecExportable | _PyCapsule
) -> SessionContext:
"""Create a new session context with the specified physical codec.

This only supports codecs that have been implemented using the
FFI interface.
FFI interface. ``codec`` must either be a raw
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto on shadowing type hint

``FFI_PhysicalExtensionCodec`` ``PyCapsule`` or an object exposing
``__datafusion_physical_extension_codec__``.
"""
new_internal = self.ctx.with_physical_extension_codec(codec)
new = SessionContext.__new__(SessionContext)
Expand Down
42 changes: 34 additions & 8 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2727,14 +2727,24 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
return Expr(f.arrow_metadata(expr.expr, key.expr))


def get_field(expr: Expr, name: Expr | str) -> Expr:
"""Extracts a field from a struct or map by name.
def get_field(expr: Expr, *names: Expr | str) -> Expr:
Comment thread
timsaucer marked this conversation as resolved.
"""Extracts a (possibly nested) field from a struct or map by name.

When the field name is a static string, the bracket operator
``expr["field"]`` is a convenient shorthand. Use ``get_field``
when the field name is a dynamic expression.
Pass one name for a single-level lookup, or several names to walk a path
of nested struct/map fields in a single ``get_field`` call. For a single
static-string name, ``expr["field"]`` is a convenient shorthand; use
``get_field`` when the field name is a dynamic
:py:class:`~datafusion.expr.Expr` or when traversing multiple levels at
once.

Args:
expr: The struct or map expression to read from.
*names: One or more field names (``str``) or expressions
(:py:class:`~datafusion.expr.Expr`).

Examples:
Single-level lookup:

>>> ctx = dfn.SessionContext()
>>> df = ctx.from_pydict({"a": [1], "b": [2]})
>>> df = df.with_column(
Expand All @@ -2756,10 +2766,26 @@ def get_field(expr: Expr, name: Expr | str) -> Expr:
... )
>>> result.collect_column("x_val")[0].as_py()
1

Multi-level lookup:

>>> df = df.with_column(
... "outer",
... dfn.functions.named_struct([("inner", dfn.col("s"))]),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Not required here, but the doctest namespace already imports `functions` as `F`, which would make this addition less verbose.

... )
>>> result = df.select(
... dfn.functions.get_field(
... dfn.col("outer"), "inner", "x"
... ).alias("x_val")
... )
>>> result.collect_column("x_val")[0].as_py()
1
"""
if isinstance(name, str):
name = Expr.string_literal(name)
return Expr(f.get_field(expr.expr, name.expr))
if not names:
msg = "get_field requires at least one field name"
raise ValueError(msg)
resolved = [Expr.string_literal(n) if isinstance(n, str) else n for n in names]
return Expr(f.get_field(expr.expr, [n.expr for n in resolved]))


def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr:
Expand Down
12 changes: 12 additions & 0 deletions python/datafusion/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]:
return value.__class__.__name__ == "PyCapsule"


class LogicalExtensionCodecExportable(Protocol):
    """Structural type for logical extension codec providers.

    Any object that defines ``__datafusion_logical_extension_codec__``
    satisfies this protocol. It exists purely as a type hint.
    """

    def __datafusion_logical_extension_codec__(self) -> object:  # noqa: D105
        ...


class PhysicalExtensionCodecExportable(Protocol):
    """Structural type for physical extension codec providers.

    Any object that defines ``__datafusion_physical_extension_codec__``
    satisfies this protocol. It exists purely as a type hint.
    """

    def __datafusion_physical_extension_codec__(self) -> object:  # noqa: D105
        ...


class ScalarUDF:
"""Class for performing scalar user-defined functions (UDF).

Expand Down
15 changes: 15 additions & 0 deletions python/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,21 @@ def test_register_batch_empty(ctx):
assert result[0].num_rows == 0


def test_read_batch_returns_dataframe(ctx):
    """``read_batch`` yields the batch contents and leaves the catalog empty."""
    columns = {"a": [1, 2, 3], "b": [4, 5, 6]}
    record_batch = pa.RecordBatch.from_pydict(columns)

    frame = ctx.read_batch(record_batch)

    assert frame.to_pydict() == {"a": [1, 2, 3], "b": [4, 5, 6]}
    # Reading a batch must not create a named table in the default schema.
    registered_tables = ctx.catalog().schema().names()
    assert registered_tables == set()


def test_read_batches_concatenates(ctx):
    """``read_batches`` exposes the rows of all batches, in order."""
    first_batch = pa.RecordBatch.from_pydict({"a": [1, 2]})
    second_batch = pa.RecordBatch.from_pydict({"a": [3, 4]})

    frame = ctx.read_batches([first_batch, second_batch])

    assert frame.to_pydict() == {"a": [1, 2, 3, 4]}


def test_create_sql_options():
SQLOptions()

Expand Down
Loading