Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
428 changes: 394 additions & 34 deletions crates/core/src/codec.rs

Large diffs are not rendered by default.

82 changes: 74 additions & 8 deletions crates/core/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use crate::expr::PyExpr;
/// This struct holds the Python written function that is a
/// ScalarUDF.
#[derive(Debug)]
struct PythonFunctionScalarUDF {
pub(crate) struct PythonFunctionScalarUDF {
name: String,
func: Py<PyAny>,
signature: Signature,
Expand All @@ -67,6 +67,37 @@ impl PythonFunctionScalarUDF {
return_field: Arc::new(return_field),
}
}

/// Stored Python callable. Consumed by the codec to cloudpickle
/// the function body across process boundaries.
pub(crate) fn func(&self) -> &Py<PyAny> {
&self.func
}

pub(crate) fn return_field(&self) -> &FieldRef {
&self.return_field
}

/// Reconstruct a `PythonFunctionScalarUDF` from the parts emitted
/// by the codec. Inputs collapse to `Vec<DataType>` because
/// `Signature::exact` cannot carry per-input nullability or
/// metadata — the encoder is free to discard that side of the
/// schema. `return_field` is kept as a `Field` so the post-decode
/// nullability and metadata match the sender's instance.
pub(crate) fn from_parts(
name: String,
func: Py<PyAny>,
input_types: Vec<DataType>,
return_field: Field,
volatility: Volatility,
) -> Self {
Self {
name,
func,
signature: Signature::exact(input_types, volatility),
return_field: Arc::new(return_field),
}
}
}

impl Eq for PythonFunctionScalarUDF {}
Expand All @@ -75,21 +106,51 @@ impl PartialEq for PythonFunctionScalarUDF {
self.name == other.name
&& self.signature == other.signature
&& self.return_field == other.return_field
&& Python::attach(|py| self.func.bind(py).eq(other.func.bind(py)).unwrap_or(false))
// Identical pointers ⇒ same Python object. Most equality
// checks compare `Arc`-shared clones of the same UDF
// (e.g. expression rewriting), so the pointer match short-
// circuits before touching the GIL.
&& (self.func.as_ptr() == other.func.as_ptr()
|| Python::attach(|py| {
// Rust's `PartialEq` cannot return `Result`, so we
// have to pick a side when Python `__eq__` raises.
// `false` is the conservative choice — better to
// report two UDFs as distinct than to wrongly
// merge them — but the silent miss can still
// surface as expression-dedup or cache-lookup
// anomalies. Log at `debug` so the failure is
// observable without flooding production logs.
// FIXME: revisit if upstream `ScalarUDFImpl`
// exposes a fallible `PartialEq`.
self.func
.bind(py)
.eq(other.func.bind(py))
.unwrap_or_else(|e| {
log::debug!(
target: "datafusion_python::udf",
"PythonFunctionScalarUDF {:?} __eq__ raised; treating as unequal: {e}",
self.name,
);
false
})
}))
}
}

impl Hash for PythonFunctionScalarUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
// Hash only the identifying header (name + signature + return
// field). Skipping `func` is intentional: the Rust `Hash`
// contract requires `a == b ⇒ hash(a) == hash(b)`, not the
// converse, so a coarser hash is sound — `PartialEq` still
// disambiguates two UDFs with the same header but distinct
// callables. Falling back to a sentinel on `py_hash` failure
// (as a prior revision did) silently mapped every unhashable
// closure to the same bucket; that is the worst case for a
// hashmap and is what this rewrite avoids.
self.name.hash(state);
self.signature.hash(state);
self.return_field.hash(state);

Python::attach(|py| {
let py_hash = self.func.bind(py).hash().unwrap_or(0); // Handle unhashable objects

state.write_isize(py_hash);
});
}
}

Expand Down Expand Up @@ -220,4 +281,9 @@ impl PyScalarUDF {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("ScalarUDF({})", self.function.name()))
}

#[getter]
fn name(&self) -> &str {
self.function.name()
}
}
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ classifiers = [
"Programming Language :: Rust",
]
dependencies = [
# cloudpickle is invoked by the Rust-side PythonLogicalCodec /
# PythonPhysicalCodec via pyo3 to serialize Python UDF callables —
# scalar, aggregate, and window — into the proto wire format.
# Lazy-imported on the encode / decode hot paths (and cached after
# the first import), so users who never serialize a plan or
# expression incur no runtime cost beyond the install footprint.
"cloudpickle>=2.0",
"pyarrow>=16.0.0;python_version<'3.14'",
"pyarrow>=22.0.0;python_version>='3.14'",
"typing-extensions;python_version<'3.13'",
Expand Down
3 changes: 2 additions & 1 deletion python/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
import importlib_metadata # type: ignore[import]

# Public submodules
from . import functions, object_store, substrait, unparser
from . import functions, ipc, object_store, substrait, unparser

# The following imports are okay to remain as opaque to the user.
from ._internal import Config
Expand Down Expand Up @@ -142,6 +142,7 @@
"configure_formatter",
"expr",
"functions",
"ipc",
"lit",
"literal",
"object_store",
Expand Down
58 changes: 47 additions & 11 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,23 +434,59 @@ def variant_name(self) -> str:
return self.expr.variant_name()

def to_bytes(self, ctx: SessionContext | None = None) -> bytes:
"""Serialize this expression to protobuf bytes.
"""Serialize this expression to bytes for shipping to another process.

When ``ctx`` is supplied, encoding routes through the session's
installed :class:`LogicalExtensionCodec`. Without ``ctx`` a
default codec is used.
Use this — or :func:`pickle.dumps` — to send an expression to a
worker process for distributed evaluation.

When ``ctx`` is supplied, encoding routes through that session's
installed :class:`LogicalExtensionCodec`. When ``ctx`` is
``None``, the default codec is used.

Built-in functions and Python scalar UDFs travel inside the
returned bytes; the worker does not need to pre-register them.
UDFs imported via the FFI capsule protocol travel by name only
and must be registered on the worker.
"""
ctx_arg = ctx.ctx if ctx is not None else None
return self.expr.to_bytes(ctx_arg)

@staticmethod
def from_bytes(ctx: SessionContext, data: bytes) -> Expr:
"""Decode an expression from serialized protobuf bytes.

``ctx`` provides the function registry for resolving UDF
references and the logical codec for in-band Python payloads.
@classmethod
def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr:
"""Reconstruct an expression from serialized bytes.

Accepts output of :meth:`to_bytes` or :func:`pickle.dumps`.
``ctx`` is the :class:`SessionContext` used to resolve any
function references that travel by name (e.g. FFI UDFs). When
``ctx`` is ``None`` the worker context installed via
:func:`datafusion.ipc.set_worker_ctx` is consulted; if no worker
context is installed, the global :class:`SessionContext` is used
(sufficient for built-ins and Python scalar UDFs, plus any UDFs
registered on the global context).
"""
from datafusion.ipc import _resolve_ctx

resolved = _resolve_ctx(ctx)
return cls(expr_internal.RawExpr.from_bytes(resolved.ctx, buf))

def __reduce__(self) -> tuple:
"""Pickle protocol hook.

Lets expressions be shipped to worker processes via
:func:`pickle.dumps` / :func:`pickle.loads`. Built-in functions
and Python scalar UDFs travel inside the pickle bytes; only
FFI-capsule UDFs require pre-registration on the worker. The
worker's :class:`SessionContext` for resolving those references
is looked up via :func:`datafusion.ipc.set_worker_ctx`, falling
back to the global :class:`SessionContext` if none has been
installed on the worker.
"""
return Expr(expr_internal.RawExpr.from_bytes(ctx.ctx, data))
return (Expr._reconstruct, (self.to_bytes(),))

@classmethod
def _reconstruct(cls, proto_bytes: bytes) -> Expr:
"""Internal entry point used by :meth:`__reduce__` on unpickle."""
return cls.from_bytes(proto_bytes)

def __richcmp__(self, other: Expr, op: int) -> Expr:
"""Comparison operator."""
Expand Down
113 changes: 113 additions & 0 deletions python/datafusion/ipc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Worker-side setup for distributing DataFusion expressions.

When a :class:`Expr` is shipped to a worker process (e.g. through
:func:`multiprocessing.Pool` or a Ray actor), the worker reconstructs the
expression against a :class:`SessionContext`. If the expression references
UDFs imported via the FFI capsule protocol — or any UDF the worker would
otherwise resolve from its registered functions rather than from inside
the shipped expression — install a configured :class:`SessionContext`
once per worker:

.. code-block:: python

from datafusion import SessionContext
from datafusion.ipc import set_worker_ctx

def init_worker():
ctx = SessionContext()
ctx.register_udaf(my_ffi_aggregate)
set_worker_ctx(ctx)

Built-in functions and Python scalar UDFs travel inside the shipped
expression itself and do not need pre-registration on the worker.
"""

from __future__ import annotations

import threading
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from datafusion.context import SessionContext


__all__ = [
"clear_worker_ctx",
"get_worker_ctx",
"set_worker_ctx",
]


_local = threading.local()


def set_worker_ctx(ctx: SessionContext) -> None:
"""Install this worker's :class:`SessionContext` for shipped expressions.

Call once per worker — typically from a ``multiprocessing.Pool``
initializer or a Ray actor ``__init__``. Idempotent: overwrites any
previous value. Stored in a thread-local slot, so each thread within a
worker may install its own context independently.
"""
_local.ctx = ctx


def clear_worker_ctx() -> None:
"""Remove this worker's installed :class:`SessionContext`.

After clearing, expressions reconstructed in this worker fall back to
the global :class:`SessionContext` — adequate for built-ins and Python
scalar UDFs, but anything imported via the FFI capsule protocol must
be registered on the global context to resolve.
"""
if hasattr(_local, "ctx"):
del _local.ctx


def get_worker_ctx() -> SessionContext | None:
"""Return this worker's installed :class:`SessionContext`, or ``None``."""
return getattr(_local, "ctx", None)


def _resolve_ctx(
explicit_ctx: SessionContext | None = None,
) -> SessionContext:
"""Resolve a context for Expr reconstruction.

Priority: explicit argument > worker context > global context.
Falling back to the global :class:`SessionContext` (instead of a
freshly constructed one) preserves any registrations the user has
installed on it.
"""
if explicit_ctx is not None:
return explicit_ctx
worker = get_worker_ctx()
if worker is not None:
return worker
# Lazy import: `datafusion/__init__.py` imports `datafusion.ipc`
# before `datafusion.context`, so a module-top import would force
# `datafusion.context` to load mid-init of `datafusion.ipc`. The
# cycle is benign today (context.py only pulls expr.py at module
# scope, neither pulls ipc.py back), but a single new import in
# context.py's transitive deps could turn it into a real cycle.
# Deferring keeps `datafusion.ipc` import-order-independent.
from datafusion.context import SessionContext # noqa: PLC0415

return SessionContext.global_ctx()
10 changes: 10 additions & 0 deletions python/datafusion/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,16 @@ def __init__(
name, func, input_fields, return_field, str(volatility)
)

@property
def name(self) -> str:
"""Return the registered name of this UDF.

For UDFs imported via the FFI capsule protocol, this is the
name the capsule itself reports — not the ``name`` argument
passed to the constructor (which is ignored on the FFI path).
"""
return self._udf.name

def __repr__(self) -> str:
"""Print a string representation of the Scalar UDF."""
return self._udf.__repr__()
Expand Down
4 changes: 2 additions & 2 deletions python/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,7 @@ def test_expr_to_bytes_roundtrip(ctx: SessionContext) -> None:

original = col("a") + lit(1)
blob = original.to_bytes(ctx)
restored = Expr.from_bytes(ctx, blob)
restored = Expr.from_bytes(blob, ctx=ctx)

# Canonical name preserves the structure of the expression even
# though the underlying PyExpr instances are different.
Expand All @@ -1201,6 +1201,6 @@ def test_expr_to_bytes_no_ctx_default_codec() -> None:
fresh = SessionContext()
original = col("a") * lit(2)
blob = original.to_bytes() # encode side: default codec
restored = Expr.from_bytes(fresh, blob)
restored = Expr.from_bytes(blob, ctx=fresh)

assert restored.canonical_name() == original.canonical_name()
Loading
Loading