diff --git a/src/tracksdata/_test/test_attrs.py b/src/tracksdata/_test/test_attrs.py index 8b29bf7d..cc403723 100644 --- a/src/tracksdata/_test/test_attrs.py +++ b/src/tracksdata/_test/test_attrs.py @@ -115,6 +115,23 @@ def test_attr_expr_method_delegation() -> None: assert result.to_list() == expected.to_list() +def test_attr_expr_struct_field_method_delegation() -> None: + df = pl.DataFrame({"s": [{"x": 1}, {"x": 2}, {"x": 3}]}, schema={"s": pl.Struct({"x": pl.Int64})}) + expr = NodeAttr("s").struct.field("x") + result = expr.evaluate(df) + assert isinstance(expr, NodeAttr) + assert result.to_list() == [1, 2, 3] + + +def test_attr_comparison_struct_field() -> None: + df = pl.DataFrame({"s": [{"x": 1}, {"x": 2}, {"x": 1}]}, schema={"s": pl.Struct({"x": pl.Int64})}) + comp = NodeAttr("s").struct.field("x") == 1 + result = comp.to_attr().evaluate(df) + assert comp.column == "s" + assert comp.attr.field_path == ("x",) + assert result.to_list() == [True, False, True] + + def test_attr_expr_complex_expression() -> None: df = pl.DataFrame({"iou": [0.5, 0.7, 0.9], "distance": [10, 20, 30]}) expr = (1 - Attr("iou")) * Attr("distance") diff --git a/src/tracksdata/attrs.py b/src/tracksdata/attrs.py index 60f82db8..be223211 100644 --- a/src/tracksdata/attrs.py +++ b/src/tracksdata/attrs.py @@ -129,7 +129,7 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr raise ValueError(f"Comparison operators are not supported for multiple columns. Found {columns}.") self.attr = attr - self.column = columns[0] + self.column = attr.root_column if attr.root_column is not None else columns[0] self.op = op # casting numpy scalars to python scalars @@ -144,14 +144,18 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr self.other = other def __repr__(self) -> str: - return f"{type(self.attr).__name__}({self.column}) {_OPS_MATH_SYMBOLS[self.op]} {self.other}" + if self.attr.field_path: + column = ".".join([str(self.column), *self.attr.field_path]) + else: + column = str(self.column) + return f"{type(self.attr).__name__}({column}) {_OPS_MATH_SYMBOLS[self.op]} {self.other}" def to_attr(self) -> "Attr": """ Transform the comparison back to an [Attr][tracksdata.attrs.Attr] object. This is useful for evaluating the expression on a DataFrame. """ - return Attr(self.op(pl.col(self.column), self.other)) + return Attr(self.op(self.attr.expr, self.other)) def __getattr__(self, attr: str) -> Any: return getattr(self.to_attr(), attr) @@ -198,6 +202,31 @@ def __ge__(self, other: ExprInput) -> "Attr": ... def __rge__(self, other: ExprInput) -> "Attr": ... +class _StructNamespace: + """Wrapper around polars struct namespace that preserves Attr semantics.""" + + def __init__(self, attr: "Attr") -> None: + self._attr = attr + self._namespace = attr.expr.struct + + def field(self, name: str) -> "Attr": + out = self._attr._wrap(self._namespace.field(name), preserve_field_path=True) + if isinstance(out, Attr): + out._append_field_path(name) + return out + + def __getattr__(self, name: str) -> Any: + namespace_attr = getattr(self._namespace, name) + if callable(namespace_attr): + + @functools.wraps(namespace_attr) + def _wrapped(*args, **kwargs): + return self._attr._wrap(namespace_attr(*args, **kwargs)) + + return _wrapped + return namespace_attr + + class Attr: """ A class to compose an attribute expression for attribute filtering or value evaluation. 
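For quick reference, here is the round trip these tests and the `_StructNamespace` wrapper exercise, written out as a standalone snippet. This is a sketch based on the behaviour added in this diff; the `from tracksdata.attrs import NodeAttr` import path is assumed from the module layout shown here.

```python
import polars as pl

from tracksdata.attrs import NodeAttr  # assumed import path (module: src/tracksdata/attrs.py)

df = pl.DataFrame(
    {"measurements": [{"score": 1, "name": "A"}, {"score": 2, "name": "B"}]},
    schema={"measurements": pl.Struct({"score": pl.Int64, "name": pl.String})},
)

# Accessing a struct field through the new namespace returns another NodeAttr,
# so the usual operator/method chaining keeps working.
expr = NodeAttr("measurements").struct.field("score")

# Comparisons resolve the top-level column and record the nested field path,
# which is what the graph backends use to translate the filter.
comp = expr == 1
assert comp.column == "measurements"
assert comp.attr.field_path == ("score",)

# Evaluating against a polars DataFrame applies the full struct-field expression.
assert comp.to_attr().evaluate(df).to_list() == [True, False]
```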
@@ -222,30 +251,43 @@ class Attr: def __init__(self, value: ExprInput) -> None: self._inf_exprs = [] # expressions multiplied by +inf self._neg_inf_exprs = [] # expressions multiplied by -inf + # Path-tracking for backend filters: + # - root_column: top-level column used to store the value. + # - field_path: nested struct path from that root column. + self._root_column: str | None = None + self._field_path: tuple[str, ...] = () if isinstance(value, str): self.expr = pl.col(value) + self._root_column = value elif isinstance(value, Attr): self.expr = value.expr # Copy infinity tracking from the other AttrExpr self._inf_exprs = value.inf_exprs self._neg_inf_exprs = value.neg_inf_exprs + self._root_column = value.root_column + self._field_path = value.field_path elif isinstance(value, AttrComparison): attr = value.to_attr() self.expr = attr.expr self._inf_exprs = attr.inf_exprs self._neg_inf_exprs = attr.neg_inf_exprs + self._root_column = attr.root_column + self._field_path = attr.field_path elif isinstance(value, Expr): self.expr = value else: self.expr = pl.lit(value) - def _wrap(self, expr: ExprInput) -> Union["Attr", Any]: + def _wrap(self, expr: ExprInput, *, preserve_field_path: bool = False) -> Union["Attr", Any]: if isinstance(expr, Expr): - result = Attr(expr) + result = type(self)(expr) # Propagate infinity tracking result._inf_exprs = self._inf_exprs.copy() result._neg_inf_exprs = self._neg_inf_exprs.copy() + if preserve_field_path: + result._root_column = self._root_column + result._field_path = self._field_path return result return expr @@ -377,6 +419,33 @@ def evaluate(self, df: DataFrame) -> Series: def columns(self) -> list[str]: return list(dict.fromkeys(self.expr_columns + self.inf_columns + self.neg_inf_columns)) + @property + def root_column(self) -> str | None: + """ + Top-level column name from which this expression originates. + + Examples + -------- + `Attr("t").root_column == "t"` + `NodeAttr("measurements").struct.field("score").root_column == "measurements"` + """ + return self._root_column + + @property + def field_path(self) -> tuple[str, ...]: + """ + Nested struct-field path relative to [root_column][tracksdata.attrs.Attr.root_column]. + + Empty tuple means no nested access. 
+ + Examples + -------- + `Attr("t").field_path == ()` + `NodeAttr("measurements").struct.field("score").field_path == ("score",)` + `NodeAttr("meta").struct.field("det").struct.field("conf").field_path == ("det", "conf")` + """ + return self._field_path + @property def inf_exprs(self) -> list["Attr"]: """Get the expressions multiplied by positive infinity.""" @@ -464,6 +533,9 @@ def __getattr__(self, attr: str) -> Any: if attr.startswith("_"): raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") + if attr == "struct": + return _StructNamespace(self) + # To auto generate operator methods such as `.log()`` expr_attr = getattr(self.expr, attr) if callable(expr_attr): @@ -475,6 +547,12 @@ def _wrapped(*args, **kwargs): return _wrapped return expr_attr + def _append_field_path(self, field_name: str) -> None: + if self._root_column is None: + self._field_path = () + else: + self._field_path = (*self._field_path, field_name) + def __repr__(self) -> str: return f"Attr({self.expr})" @@ -733,4 +811,4 @@ def polars_reduce_attr_comps( # Return True for all rows by using the first column as a reference raise ValueError("No attribute comparisons provided.") - return pl.reduce(reduce_op, [attr_comp.op(df[str(attr_comp.column)], attr_comp.other) for attr_comp in attr_comps]) + return pl.reduce(reduce_op, [attr_comp.op(attr_comp.attr.expr, attr_comp.other) for attr_comp in attr_comps]) diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index a7d55377..637356e2 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -88,9 +88,30 @@ def _create_filter_func( ) -> Callable[[dict[str, Any]], bool]: LOG.info(f"Creating filter function for {attr_comps}") + def _extract_field_path(value: Any, field_path: tuple[str, ...]) -> Any: + for field in field_path: + if value is None: + return None + + if isinstance(value, dict): + value = value.get(field, None) + continue + + try: + value = value[field] + except (KeyError, IndexError, TypeError): + try: + value = getattr(value, field) + except AttributeError: + return None + + return value + def _filter(attrs: dict[str, Any]) -> bool: for attr_op in attr_comps: value = attrs.get(attr_op.column, schema[attr_op.column].default_value) + if attr_op.attr.field_path: + value = _extract_field_path(value, attr_op.attr.field_path) if not attr_op.op(value, attr_op.other): return False return True diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index f0b9ab86..f268efbf 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -20,8 +20,11 @@ from tracksdata.utils._cache import cache_method from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_bytes_columns from tracksdata.utils._dtypes import ( + STRUCT_FIELD_SEP, AttrSchema, deserialize_attr_schema, + flatten_struct_dtype, + flatten_struct_value, polars_dtype_to_sqlalchemy_type, process_attr_key_args, serialize_attr_schema, @@ -56,6 +59,23 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None: data[k] = v.item() +def _resolve_attr_filter_column( + table: type[DeclarativeBase], + attr_filter: AttrComparison, +) -> Any: + """Return the SQLAlchemy column expression for an AttrComparison. + + For struct field paths (e.g. ``NodeAttr("m").struct.field("score")``), the + field path is joined with ``STRUCT_FIELD_SEP`` to form the physical flat + column name (e.g. ``m__score``), which is a native SQL column. 
+ """ + if not attr_filter.attr.field_path: + return getattr(table, str(attr_filter.column)) + + flat_col = STRUCT_FIELD_SEP.join([str(attr_filter.column), *attr_filter.attr.field_path]) + return getattr(table, flat_col) + + def _filter_query( query: sa.Select, table: type[DeclarativeBase], @@ -80,7 +100,13 @@ def _filter_query( """ LOG.info("Filter query:\n%s", attr_filters) query = query.filter( - *[attr_filter.op(getattr(table, str(attr_filter.column)), attr_filter.other) for attr_filter in attr_filters] + *[ + attr_filter.op( + _resolve_attr_filter_column(table, attr_filter), + attr_filter.other, + ) + for attr_filter in attr_filters + ] ) return query @@ -123,7 +149,11 @@ def __init__( if self._node_attr_comps: node_filtered = True # filtering nodes by attributes - self._node_query = _filter_query(self._node_query, self._graph.Node, self._node_attr_comps) + self._node_query = _filter_query( + self._node_query, + self._graph.Node, + self._node_attr_comps, + ) # if both node and edge attributes are filtered # we need to select subset of edges that belong to the filtered nodes @@ -139,17 +169,29 @@ def __init__( SourceNode, self._graph.Edge.source_id == SourceNode.node_id, ) - self._edge_query = _filter_query(self._edge_query, SourceNode, self._node_attr_comps) + self._edge_query = _filter_query( + self._edge_query, + SourceNode, + self._node_attr_comps, + ) if self._include_sources or include_none: self._edge_query = self._edge_query.join( TargetNode, self._graph.Edge.target_id == TargetNode.node_id, ) - self._edge_query = _filter_query(self._edge_query, TargetNode, self._node_attr_comps) + self._edge_query = _filter_query( + self._edge_query, + TargetNode, + self._node_attr_comps, + ) if self._edge_attr_comps: - self._edge_query = _filter_query(self._edge_query, self._graph.Edge, self._edge_attr_comps) + self._edge_query = _filter_query( + self._edge_query, + self._graph.Edge, + self._edge_attr_comps, + ) # we haven't filtered the nodes by attributes # so we only return the nodes that are in the edges @@ -228,15 +270,15 @@ def node_attrs( nodes_attrs = nodes_attrs.select(attr_keys) nodes_attrs = unpickle_bytes_columns(nodes_attrs) - nodes_attrs = self._graph._cast_array_columns(self._graph.Node, nodes_attrs) + nodes_attrs = self._graph._cast_columns(self._graph.Node, nodes_attrs) if unpack: nodes_attrs = unpack_array_attrs(nodes_attrs) return nodes_attrs - @staticmethod def _query_from_attr_keys( + self, query: sa.Select, table: type[DeclarativeBase], attr_keys: list[str] | None = None, @@ -250,14 +292,23 @@ def _query_from_attr_keys( LOG.info("Query attr_keys: %s", attr_keys) + schemas = self._graph._attr_schemas_for_table(table) + flat_names: list[str] = [] + for key in attr_keys: + schema = schemas.get(key) + if schema is not None and isinstance(schema.dtype, pl.Struct): + flat_names.extend(fc for fc, _ in flatten_struct_dtype(key, schema.dtype)) + else: + flat_names.append(key) + if isinstance(query, sa.CompoundSelect): union_query = query.alias("u") query = sa.select( - *[getattr(union_query.c, key) for key in attr_keys], + *[getattr(union_query.c, name) for name in flat_names], ) else: query = query.with_only_columns( - *[getattr(table, key) for key in attr_keys], + *[getattr(table, name) for name in flat_names], ) LOG.info("Query after attr_keys selection:\n%s", query) @@ -285,7 +336,7 @@ def edge_attrs(self, attr_keys: list[str] | None = None, unpack: bool = False) - ) edges_df = unpickle_bytes_columns(edges_df) - edges_df = self._graph._cast_array_columns(self._graph.Edge, 
edges_df) + edges_df = self._graph._cast_columns(self._graph.Edge, edges_df) if unpack: edges_df = unpack_array_attrs(edges_df) @@ -585,27 +636,24 @@ def _attr_schemas_from_metadata( {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} ) + # Compute the set of flat physical columns that belong to known struct schemas, + # so the legacy fallback below does not register them as independent logical keys. + known_flat_cols: set[str] = set() + for schema in schemas.values(): + if isinstance(schema.dtype, pl.Struct): + known_flat_cols.update(fc for fc, _ in flatten_struct_dtype(schema.key, schema.dtype)) + # Legacy databases may not have schema metadata for all columns. for column_name, column in table_class.__table__.columns.items(): - if column_name not in schemas: + if column_name not in schemas and column_name not in known_flat_cols: schemas[column_name] = AttrSchema( key=column_name, dtype=sqlalchemy_type_to_polars_dtype(column.type), ) - result = {} - - # return dictionary in preferred order - for source in ( - preferred_order, - table_class.__table__.columns.keys(), - schemas, - ): - for key in source: - if key in schemas: - result.setdefault(key, schemas[key]) - - return result + ordered_keys = [key for key in preferred_order if key in schemas] + ordered_keys.extend(key for key in schemas if key not in ordered_keys) + return {key: schemas[key] for key in ordered_keys} def _attr_schemas_for_table(self, table_class: type[DeclarativeBase]) -> dict[str, AttrSchema]: if table_class.__tablename__ == self.Node.__tablename__: @@ -666,38 +714,92 @@ def _restore_pickled_column_types(self, table: sa.Table) -> None: column.type = sa.PickleType() def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaDict: + """Return polars dtype overrides for physical columns in *table_class*. + + Flat struct leaf columns are included with their native leaf dtypes. + Pickled columns are excluded here and handled in a second pass by + ``_cast_columns``. + """ + overrides: SchemaDict = {} schemas = self._attr_schemas_for_table(table_class) + table_cols = table_class.__table__.columns - # Return schema overrides for columns safely represented in SQL. - # Pickled columns are unpickled and casted in a second pass. - return { - key: schema.dtype - for key, schema in schemas.items() - if ( - key in table_class.__table__.columns - and not self._is_pickled_sql_type(table_class.__table__.columns[key].type) - ) - } + for key, schema in schemas.items(): + if isinstance(schema.dtype, pl.Struct): + # Emit overrides for each leaf physical column. 
+ for flat_col, leaf_dtype in flatten_struct_dtype(key, schema.dtype): + if flat_col in table_cols and not self._is_pickled_sql_type(table_cols[flat_col].type): + overrides[flat_col] = leaf_dtype + elif key in table_cols and not self._is_pickled_sql_type(table_cols[key].type): + overrides[key] = schema.dtype + + return overrides - def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: + @staticmethod + def _build_struct_expr(key: str, dtype: pl.Struct) -> pl.Expr: + """Recursively build a ``pl.struct`` expression from flat leaf columns.""" + fields: list[pl.Expr] = [] + for field_name, field_dtype in dtype.to_schema().items(): + flat_col = f"{key}{STRUCT_FIELD_SEP}{field_name}" + if isinstance(field_dtype, pl.Struct): + fields.append(SQLGraph._build_struct_expr(flat_col, field_dtype).alias(field_name)) + else: + fields.append(pl.col(flat_col).alias(field_name)) + return pl.struct(fields) + + def _cast_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: + """Cast pickled columns to their target dtype and reconstruct struct columns.""" schemas = self._attr_schemas_for_table(table_class) + table_cols = table_class.__table__.columns casts: list[pl.Series] = [] + struct_keys: list[tuple[str, pl.Struct]] = [] + for key, schema in schemas.items(): - if key not in df.columns or key not in table_class.__table__.columns: + if isinstance(schema.dtype, pl.Struct): + # Cast any pickled flat leaf columns to their proper dtypes before + # reconstruction so Array/List fields have correct dtype. + for flat_col, leaf_dtype in flatten_struct_dtype(key, schema.dtype): + if flat_col not in df.columns or flat_col not in table_cols: + continue + if not self._is_pickled_sql_type(table_cols[flat_col].type): + continue + try: + casts.append(pl.Series(flat_col, df[flat_col].to_list(), dtype=leaf_dtype)) + except Exception: + continue + struct_keys.append((key, schema.dtype)) continue - if not self._is_pickled_sql_type(table_class.__table__.columns[key].type): + if key not in df.columns or key not in table_cols: + continue + + if not self._is_pickled_sql_type(table_cols[key].type): continue try: casts.append(pl.Series(key, df[key].to_list(), dtype=schema.dtype)) except Exception: - # Keep original dtype when values cannot be casted to the target schema. + # Keep original dtype when values cannot be cast to the target schema. continue if casts: df = df.with_columns(casts) + + # Reconstruct struct columns from their flat physical columns. + for key, dtype in struct_keys: + flat_cols = [fc for fc, _ in flatten_struct_dtype(key, dtype)] + present = [fc for fc in flat_cols if fc in df.columns] + if not present: + continue # struct was not part of this query; skip + missing = [fc for fc in flat_cols if fc not in df.columns] + if missing: + raise ValueError( + f"Struct attribute '{key}' is partially present in the DataFrame " + f"(missing: {missing}). Cannot reconstruct the struct column." + ) + df = df.with_columns(self._build_struct_expr(key, dtype).alias(key)).drop(flat_cols) + return df def _update_max_id_per_time(self) -> None: @@ -727,6 +829,24 @@ def filter( include_sources=include_sources, ) + def _flatten_attrs_for_write( + self, + attrs: dict[str, Any], + schemas: dict[str, AttrSchema], + ) -> dict[str, Any]: + """Expand struct-typed values into flat ``{leaf_col: value}`` pairs. + + Non-struct values are passed through unchanged. 
+ """ + result: dict[str, Any] = {} + for key, value in attrs.items(): + schema = schemas.get(key) + if schema is not None and isinstance(schema.dtype, pl.Struct) and isinstance(value, dict): + result.update(flatten_struct_value(key, value, schema.dtype)) + else: + result[key] = value + return result + def add_node( self, attrs: dict[str, Any], @@ -786,9 +906,10 @@ def add_node( else: node_id = index + write_attrs = self._flatten_attrs_for_write(attrs, self._node_attr_schemas()) node = self.Node( node_id=node_id, - **attrs, + **write_attrs, ) with Session(self._engine) as session: @@ -864,7 +985,9 @@ def bulk_add_nodes( node[DEFAULT_ATTR_KEYS.NODE_ID] = node_id node_ids.append(node_id) - self._chunked_sa_write(Session.bulk_insert_mappings, nodes, self.Node) + node_schemas = self._node_attr_schemas() + write_nodes = [self._flatten_attrs_for_write(node, node_schemas) for node in nodes] + self._chunked_sa_write(Session.bulk_insert_mappings, write_nodes, self.Node) if is_signal_on(self.node_added): for node_id, node_attrs in zip(node_ids, nodes, strict=True): @@ -900,7 +1023,10 @@ def remove_node(self, node_id: int) -> None: self.node_removed.emit(node_id, old_attrs) if is_signal_on(self.node_removed): - old_attrs = {key: getattr(node, key) for key in self.node_attr_keys()} + attr_keys = self.node_attr_keys() + old_df = self.filter(node_ids=[node_id]).node_attrs(attr_keys=attr_keys) + old_row = old_df.row(0, named=True) + old_attrs = {key: old_row[key] for key in attr_keys} # Remove all edges where this node is source or target session.query(self.Edge).filter( @@ -966,6 +1092,7 @@ def add_edge( if hasattr(target_id, "item"): target_id = target_id.item() + attrs = self._flatten_attrs_for_write(attrs, self._edge_attr_schemas()) edge = self.Edge( source_id=source_id, target_id=target_id, @@ -1019,9 +1146,12 @@ def bulk_add_edges( return [] return None + edge_schemas = self._edge_attr_schemas() for edge in edges: _data_numpy_to_native(edge) + edges = [self._flatten_attrs_for_write(edge, edge_schemas) for edge in edges] + if return_ids: with Session(self._engine) as session: result = session.execute(sa.insert(self.Edge).returning(self.Edge.edge_id), edges) @@ -1178,7 +1308,8 @@ def _get_neighbors( # all columns node_columns = [self.Node] else: - node_columns = [getattr(self.Node, key) for key in attr_keys] + # Expand struct logical keys to their flat physical columns. + node_columns = self._physical_cols_for_query(attr_keys, self.Node) query = session.query(getattr(self.Edge, node_key), *node_columns) query = query.join(self.Edge, getattr(self.Edge, neighbor_key) == self.Node.node_id) @@ -1196,7 +1327,7 @@ def _get_neighbors( self.Node, ) node_df = unpickle_bytes_columns(node_df) - node_df = self._cast_array_columns(self.Node, node_df) + node_df = self._cast_columns(self.Node, node_df) if single_node: if not return_attrs: @@ -1378,9 +1509,9 @@ def node_attrs( if attr_keys is not None: # making them unique attr_keys = list(dict.fromkeys(attr_keys)) - + # Expand struct logical keys to their flat physical columns. 
query = query.with_only_columns( - *[getattr(self.Node, key) for key in attr_keys], + *self._physical_cols_for_query(attr_keys, self.Node), ) nodes_df = pl.read_database( @@ -1389,9 +1520,9 @@ def node_attrs( schema_overrides=self._polars_schema_override(self.Node), ) nodes_df = unpickle_bytes_columns(nodes_df) - nodes_df = self._cast_array_columns(self.Node, nodes_df) + nodes_df = self._cast_columns(self.Node, nodes_df) - # indices are included by default and must be removed + # Select using logical keys (struct columns are now reconstructed). if attr_keys is not None: nodes_df = nodes_df.select([pl.col(c) for c in attr_keys]) else: @@ -1424,8 +1555,9 @@ def edge_attrs( LOG.info("Edge attribute keys: %s", attr_keys) + # Expand struct logical keys to their flat physical columns. query = query.with_only_columns( - *[getattr(self.Edge, key) for key in attr_keys], + *self._physical_cols_for_query(attr_keys, self.Edge), ) edges_df = pl.read_database( @@ -1434,7 +1566,7 @@ def edge_attrs( schema_overrides=self._polars_schema_override(self.Edge), ) edges_df = unpickle_bytes_columns(edges_df) - edges_df = self._cast_array_columns(self.Edge, edges_df) + edges_df = self._cast_columns(self.Edge, edges_df) if unpack: edges_df = unpack_array_attrs(edges_df) @@ -1449,6 +1581,24 @@ def _node_attr_schemas(self) -> dict[str, AttrSchema]: def _edge_attr_schemas(self) -> dict[str, AttrSchema]: return self.__edge_attr_schemas + def _physical_cols_for_query( + self, + logical_keys: Sequence[str], + table_class: type[DeclarativeBase], + ) -> list[Any]: + """Return SQLAlchemy column objects for *logical_keys*, expanding struct keys + into their flat physical leaf columns so the SQL query fetches all necessary data.""" + schemas = self._attr_schemas_for_table(table_class) + cols: list[Any] = [] + for key in logical_keys: + schema = schemas.get(key) + if schema is not None and isinstance(schema.dtype, pl.Struct): + for flat_col, _ in flatten_struct_dtype(key, schema.dtype): + cols.append(getattr(table_class, flat_col)) + else: + cols.append(getattr(table_class, key)) + return cols + def node_attr_keys(self, return_ids: bool = False) -> list[str]: """ Get the keys of the attributes of the nodes. @@ -1459,7 +1609,7 @@ def node_attr_keys(self, return_ids: bool = False) -> list[str]: Whether to include NODE_ID in the returned keys. Defaults to False. If True, NODE_ID will be included in the list. """ - keys = list(self.Node.__table__.columns.keys()) + keys = list(self._node_attr_schemas().keys()) if not return_ids and DEFAULT_ATTR_KEYS.NODE_ID in keys: keys.remove(DEFAULT_ATTR_KEYS.NODE_ID) return keys @@ -1474,7 +1624,7 @@ def edge_attr_keys(self, return_ids: bool = False) -> list[str]: Whether to include EDGE_ID, EDGE_SOURCE, and EDGE_TARGET in the returned keys. Defaults to False. If True, these ID fields will be included in the list. 
""" - keys = list(self.Edge.__table__.columns.keys()) + keys = list(self._edge_attr_schemas().keys()) if not return_ids: for id_key in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: if id_key in keys: @@ -1507,13 +1657,19 @@ def _resolve_attr_keys( if len(attr_keys) == 0: raise ValueError("attr_keys must contain at least one column name") - missing = [key for key in attr_keys if key not in table_class.__table__.columns] + schemas = self._attr_schemas_for_table(table_class) + physical_names: list[str] = [] + for key in attr_keys: + schema = schemas.get(key) + if schema is not None and isinstance(schema.dtype, pl.Struct): + physical_names.extend(fc for fc, _ in flatten_struct_dtype(key, schema.dtype)) + else: + physical_names.append(key) + + missing = [name for name in physical_names if name not in table_class.__table__.columns] if missing: raise ValueError(f"Columns {missing} do not exist on table {table_class.__tablename__}") - resolved_columns = [getattr(table_class, key) for key in attr_keys] - - if isinstance(attr_keys, str): - attr_keys = [attr_keys] + resolved_columns = [getattr(table_class, name) for name in physical_names] cols_fragment = "_".join(attr_keys) name = f"ix_{table_class.__tablename__.lower()}_{cols_fragment}" @@ -1667,28 +1823,24 @@ def _sqlalchemy_type_inference(self, default_value: Any) -> TypeEngine: else: raise ValueError(f"Unsupported default value type: {type(default_value)}") - def _add_new_column( + def _add_physical_column( self, table_class: type[DeclarativeBase], - schema: AttrSchema, + col_name: str, + sa_type: Any, + default_value: Any, ) -> None: - # Convert polars dtype to SQLAlchemy type - sa_type = polars_dtype_to_sqlalchemy_type(schema.dtype) - - # Handle special cases for default value encoding - default_value = schema.default_value + """Create a single physical SQL column and register it on the ORM class.""" if isinstance(sa_type, sa.PickleType) and default_value is not None: - # Pickle complex types for database storage default_value = blob_default(self._engine, cloudpickle.dumps(default_value)) - sa_column = sa.Column(schema.key, sa_type, default=default_value) + sa_column = sa.Column(col_name, sa_type, default=default_value) str_dialect_type = sa_column.type.compile(dialect=self._engine.dialect) identifier_preparer = self._engine.dialect.identifier_preparer quoted_table_name = identifier_preparer.format_table(table_class.__table__) quoted_column_name = identifier_preparer.quote(sa_column.name) - # Properly quote default values based on type if isinstance(default_value, str): quoted_default = f"'{default_value}'" elif default_value is None: @@ -1703,15 +1855,38 @@ def _add_new_column( ) LOG.info("add %s column statement:\n'%s'", table_class.__table__, add_column_stmt) - # create the new column in the database with Session(self._engine) as session: session.execute(add_column_stmt) session.commit() - # register the new column in the Node class - setattr(table_class, schema.key, sa_column) + setattr(table_class, col_name, sa_column) table_class.__table__.append_column(sa_column) + def _add_new_column( + self, + table_class: type[DeclarativeBase], + schema: AttrSchema, + ) -> None: + """Add a new attribute column (or flat leaf columns for structs) to *table_class*.""" + if isinstance(schema.dtype, pl.Struct): + # Expand struct into one physical column per leaf field. 
+ flat_defaults = flatten_struct_value(schema.key, schema.default_value or {}, schema.dtype) + for flat_col, leaf_dtype in flatten_struct_dtype(schema.key, schema.dtype): + self._add_physical_column( + table_class, + flat_col, + polars_dtype_to_sqlalchemy_type(leaf_dtype), + flat_defaults.get(flat_col), + ) + return + + self._add_physical_column( + table_class, + schema.key, + polars_dtype_to_sqlalchemy_type(schema.dtype), + schema.default_value, + ) + def _drop_column(self, table_class: type[DeclarativeBase], key: str) -> None: identifier_preparer = self._engine.dialect.identifier_preparer quoted_table_name = identifier_preparer.format_table(table_class.__table__) @@ -1749,7 +1924,12 @@ def remove_node_attr_key(self, key: str) -> None: raise ValueError(f"Cannot remove required node attribute key {key}") node_schemas = self.__node_attr_schemas - self._drop_column(self.Node, key) + schema = node_schemas.get(key) + if schema and isinstance(schema.dtype, pl.Struct): + for flat_col, _ in flatten_struct_dtype(key, schema.dtype): + self._drop_column(self.Node, flat_col) + else: + self._drop_column(self.Node, key) node_schemas.pop(key, None) self.__node_attr_schemas = node_schemas @@ -1773,7 +1953,12 @@ def remove_edge_attr_key(self, key: str) -> None: raise ValueError(f"Edge attribute key {key} does not exist") edge_schemas = self.__edge_attr_schemas - self._drop_column(self.Edge, key) + schema = edge_schemas.get(key) + if schema and isinstance(schema.dtype, pl.Struct): + for flat_col, _ in flatten_struct_dtype(key, schema.dtype): + self._drop_column(self.Edge, flat_col) + else: + self._drop_column(self.Edge, key) edge_schemas.pop(key, None) self.__edge_attr_schemas = edge_schemas @@ -1812,6 +1997,8 @@ def _update_table( # Handle array values with bulk_update_mappings attrs = attrs.copy() _data_numpy_to_native(attrs) + schemas = self._attr_schemas_for_table(table_class) + attrs = self._flatten_attrs_for_write(attrs, schemas) # specialized case for scalar values - use simple bulk update if all(np.isscalar(v) for v in attrs.values()): diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 8630db6e..d5cd74e3 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -225,6 +225,20 @@ def test_filter_nodes_by_membership(graph_backend: BaseGraph) -> None: assert set(np_members) == {node_b} +def test_filter_nodes_by_struct_field(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("measurements", pl.Struct({"score": pl.Int64, "name": pl.String})) + + node_a = graph_backend.add_node({"t": 0, "measurements": {"score": 1, "name": "A"}}) + node_b = graph_backend.add_node({"t": 1, "measurements": {"score": 2, "name": "B"}}) + node_c = graph_backend.add_node({"t": 2, "measurements": {"score": 1, "name": "C"}}) + + score_nodes = graph_backend.filter(NodeAttr("measurements").struct.field("score") == 1).node_ids() + assert set(score_nodes) == {node_a, node_c} + + name_nodes = graph_backend.filter(NodeAttr("measurements").struct.field("name") == "B").node_ids() + assert set(name_nodes) == {node_b} + + def test_time_points(graph_backend: BaseGraph) -> None: """Test retrieving time points.""" graph_backend.add_node({"t": 0}) @@ -1706,6 +1720,24 @@ def test_sql_graph_mask_update_survives_reload(tmp_path: Path) -> None: np.testing.assert_array_equal(stored_mask.mask, mask_data) +def test_sql_graph_struct_dtype_survives_reload(tmp_path: Path) -> None: + db_path = tmp_path / 
"struct_graph.db" + graph = SQLGraph("sqlite", str(db_path)) + graph.add_node_attr_key("measurements", pl.Struct({"score": pl.Int64, "label": pl.String})) + + node_id = graph.add_node({"t": 0, "measurements": {"score": 7, "label": "A"}}) + graph._engine.dispose() + + reloaded = SQLGraph("sqlite", str(db_path)) + + df = reloaded.node_attrs(attr_keys=["measurements"]) + assert df.schema["measurements"] == pl.Struct({"score": pl.Int64, "label": pl.String}) + assert df["measurements"].to_list() == [{"score": 7, "label": "A"}] + + ids = reloaded.filter(NodeAttr("measurements").struct.field("score") == 7).node_ids() + assert ids == [node_id] + + def test_sql_graph_max_id_restored_per_timepoint(tmp_path: Path) -> None: """Reloading a SQLGraph should respect existing max IDs per time point.""" db_path = tmp_path / "id_restore.db" diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 05338fa5..152b3d50 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -395,6 +395,71 @@ def infer_default_value_from_dtype(dtype: pl.DataType) -> Any: } +STRUCT_FIELD_SEP = "__" + + +def flatten_struct_dtype( + key: str, + dtype: pl.Struct, + sep: str = STRUCT_FIELD_SEP, +) -> list[tuple[str, pl.DataType]]: + """Recursively return ``(flat_column_name, leaf_dtype)`` for all leaves of a struct. + + Parameters + ---------- + key : str + The root column name (or already-accumulated flat prefix for nested calls). + dtype : pl.Struct + The struct dtype to flatten. + sep : str + Separator between path components. Defaults to ``STRUCT_FIELD_SEP``. + + Examples + -------- + >>> flatten_struct_dtype("m", pl.Struct({"score": pl.Int64, "label": pl.String})) + [("m__score", Int64), ("m__label", String)] + """ + results: list[tuple[str, pl.DataType]] = [] + for field_name, field_dtype in dtype.to_schema().items(): + flat_key = f"{key}{sep}{field_name}" + if isinstance(field_dtype, pl.Struct): + results.extend(flatten_struct_dtype(flat_key, field_dtype, sep)) + else: + results.append((flat_key, field_dtype)) + return results + + +def flatten_struct_value( + key: str, + value: dict, + dtype: pl.Struct, + sep: str = STRUCT_FIELD_SEP, +) -> dict: + """Flatten a struct dict value into ``{flat_col: scalar}`` pairs. + + Parameters + ---------- + key : str + The root column name. + value : dict + The struct value to flatten (may be ``None`` or empty). + dtype : pl.Struct + The struct dtype describing the expected fields. + sep : str + Separator. Defaults to ``STRUCT_FIELD_SEP``. + """ + result: dict = {} + value = value or {} + for field_name, field_dtype in dtype.to_schema().items(): + flat_key = f"{key}{sep}{field_name}" + field_val = value.get(field_name) + if isinstance(field_dtype, pl.Struct): + result.update(flatten_struct_value(flat_key, field_val or {}, field_dtype, sep)) + else: + result[flat_key] = field_val + return result + + def polars_dtype_to_sqlalchemy_type(dtype: pl.DataType) -> TypeEngine: """ Convert a polars dtype to SQLAlchemy type.