diff --git a/pyproject.toml b/pyproject.toml index 4b16011..c20c7c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "pandas", "polars", "pyarrow", + "rtree", "scanpy", "scipy", "shapely", @@ -29,6 +30,29 @@ dependencies = [ "tifffile" ] +[project.optional-dependencies] +# CPU Leiden parity for fragment mode (Stage B). Optional: when absent, fragment +# mode falls back to the numpy `_threshold_cut` splitter. +cluster = [ + "leidenalg>=0.10", + "python-igraph>=0.11", +] +# SpatialData / SOPA export (`segger export --format spatialdata`). Optional: +# the base install only needs the Xenium / AnnData / merged exporters. +spatialdata = [ + "spatialdata>=0.7.2", + "spatialdata-io>=0.6.0", +] +sopa = [ + "sopa>=2.0.0", + "spatialdata>=0.7.2", +] +spatialdata-all = [ + "spatialdata>=0.7.2", + "spatialdata-io>=0.6.0", + "sopa>=2.0.0", +] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/src/segger/cli/export.py b/src/segger/cli/export.py new file mode 100644 index 0000000..52010a7 --- /dev/null +++ b/src/segger/cli/export.py @@ -0,0 +1,247 @@ +import os +import logging +from pathlib import Path +from typing import Annotated, Literal + +from cyclopts import App, Parameter, Group, validators + +from ..utils import setup_logging + +# Parameter groups +group_io = Group( + name="I/O", + help="Related to file inputs/outputs.", + sort_key=0, +) +group_export = Group( + name="Export", + help="Export format and similarity-threshold options.", + sort_key=1, +) +group_boundary = Group( + name="Boundary", + help="Cell boundary generation options.", + sort_key=2, +) + +app_export = App(name="export", help="Export a segger segmentation to downstream formats.") + + +@app_export.command(name="export") +def export( + segmentation_path: Annotated[Path, Parameter( + help="Path to segmentation result (segger_segmentation.parquet, or a .csv/.tsv).", + alias="-s", + group=group_io, + validator=validators.Path(exists=True), + )], + source_path: Annotated[Path, Parameter( + help="Raw input data directory (Xenium/MERSCOPE/CosMX). Used as the --xenium-bundle " + "and to recover transcript_id, coordinates, and (optionally) input boundaries.", + alias="-i", + group=group_io, + validator=validators.Path(exists=True, dir_okay=True), + )], + output_directory: Annotated[Path, Parameter( + help="Output directory for exported files.", + alias="-o", + group=group_io, + )], + format: Annotated[ + Literal["xenium", "merged", "anndata", "spatialdata"], + Parameter(help="Export format.", group=group_export), + ] = "xenium", + xenium_mode: Annotated[ + Literal["transcript_assignment", "geojson", "both"], + Parameter( + help="For --format xenium: which import-segmentation inputs to write. " + "'transcript_assignment' = segmentation.csv + viz polygons; " + "'geojson' = cell polygon.geojson; 'both' = all of them.", + group=group_export, + ), + ] = "both", + cell_id_column: Annotated[str, Parameter( + help="Cell-ID column in the segmentation file (auto-falls back to common aliases).", + group=group_export, + )] = "segger_cell_id", + run_id: Annotated[str, Parameter( + help="--id passed to the printed `xeniumranger import-segmentation` command.", + group=group_export, + )] = "segger_import", + min_similarity: Annotated[float | None, Parameter( + help="Fixed similarity threshold (0-1) for the keep test, overriding the per-gene " + "Li+Yen threshold from segmentation.", + validator=validators.Number(gte=0, lte=1), + group=group_export, + )] = None, + min_similarity_shift: Annotated[float, Parameter( + help="Subtractive relaxation applied to per-gene similarity thresholds (more " + "permissive). Only effective when --min-similarity is not set.", + validator=validators.Number(gte=0, lte=1), + group=group_export, + )] = 0.0, + boundary_method: Annotated[ + Literal["delaunay", "input"], + Parameter( + help="Cell boundary source. 'delaunay' = generate from assigned transcripts " + "(our multi-core method); 'input' = use the source's boundaries.", + group=group_boundary, + ), + ] = "delaunay", + units: Annotated[ + Literal["microns", "pixels"], + Parameter(help="Coordinate units passed to import-segmentation (segger uses microns).", group=group_boundary), + ] = "microns", + num_workers: Annotated[int, Parameter( + help="Worker threads for boundary generation.", + alias="-n", + validator=validators.Number(gte=0), + group=group_boundary, + )] = 1, +): + """Export a segger segmentation to Xenium Explorer / scverse formats. + + The default ``xenium`` format writes inputs for 10x's + ``xeniumranger import-segmentation`` pipeline (whose output opens in Xenium + Explorer) and prints the command to run. Other formats: ``merged`` (transcripts + joined with assignments), ``anndata`` (cell x gene matrix), ``spatialdata`` (Zarr + for the scverse/SOPA ecosystem; needs ``pip install segger[spatialdata]``). + """ + setup_logging(level=os.environ.get("LOG_LEVEL", "WARNING")) + logger = logging.getLogger(__name__) + + import polars as pl + + # Load the segmentation table + if segmentation_path.suffix == ".parquet": + seg_df = pl.read_parquet(segmentation_path) + elif segmentation_path.suffix in {".csv", ".tsv"}: + seg_df = pl.read_csv( + segmentation_path, + separator="\t" if segmentation_path.suffix == ".tsv" else ",", + ) + else: + raise ValueError( + f"Unsupported segmentation format: {segmentation_path.suffix}. " + "Expected .parquet, .csv, or .tsv." + ) + + # Resolve the cell-ID column, normalizing to 'segger_cell_id' + effective_cell_id = cell_id_column + if effective_cell_id not in seg_df.columns: + for alias in ("segger_cell_id", "seg_cell_id", "cell_id", "segmentation_cell_id"): + if alias in seg_df.columns: + logger.warning(f"'{cell_id_column}' not found; using '{alias}'.") + effective_cell_id = alias + break + else: + raise ValueError( + "Segmentation file is missing a valid cell-ID column. Set --cell-id-column." + ) + + # Recompute the keep column from export-time threshold params + if min_similarity is not None and "segger_similarity" in seg_df.columns: + seg_df = seg_df.with_columns( + (pl.col(effective_cell_id).is_not_null() & (pl.col("segger_similarity") >= min_similarity)).alias("keep") + ) + elif min_similarity_shift > 0 and {"segger_similarity", "similarity_threshold"} <= set(seg_df.columns): + seg_df = seg_df.with_columns( + ( + pl.col(effective_cell_id).is_not_null() + & (pl.col("segger_similarity") >= (pl.col("similarity_threshold") - min_similarity_shift).clip(-1.0, 1.0)) + ).alias("keep") + ) + + if effective_cell_id != "segger_cell_id": + seg_df = seg_df.rename({effective_cell_id: "segger_cell_id"}) + + def _load_boundaries(): + from ..io import get_preprocessor + try: + return get_preprocessor(source_path).boundaries + except Exception as exc: # pragma: no cover - source-dependent + logger.warning(f"Could not load input boundaries ({exc}); generating instead.") + return None + + # Xenium Explorer via import-segmentation + if format == "xenium": + from ..export import export_xenium_import + + boundaries = _load_boundaries() if boundary_method == "input" else None + written = export_xenium_import( + seg_df, + source_path, + output_directory, + mode=xenium_mode, + cell_id_column="segger_cell_id", + boundaries=boundaries, + boundary_method=boundary_method, + units=units, + n_jobs=max(num_workers, 1), + run_id=run_id, + ) + logger.info(f"Wrote Xenium import-segmentation inputs to: {output_directory}") + for key, path in written.items(): + if key != "_commands": + logger.info(f" {key}: {path}") + print("\nNext, run Xenium Ranger (output opens in Xenium Explorer):") + for cmd in written["_commands"]: + print("\n" + cmd) + return + + # Other formats need the source transcripts + from ..io import get_preprocessor + + tx = get_preprocessor(source_path).transcripts + if isinstance(tx, pl.LazyFrame): + tx = tx.collect() + boundaries = _load_boundaries() if boundary_method == "input" else None + + if format == "merged": + from ..export import MergedTranscriptsWriter + + out = MergedTranscriptsWriter().write( + predictions=seg_df, + output_dir=output_directory, + transcripts=tx, + output_name="transcripts_segmented.parquet", + ) + logger.info(f"Wrote merged transcripts: {out}") + return + + if format == "anndata": + from ..export import AnnDataWriter + + out = AnnDataWriter().write( + predictions=seg_df, + output_dir=output_directory, + transcripts=tx, + output_name="segger_segmentation.h5ad", + ) + logger.info(f"Wrote AnnData: {out}") + return + + if format == "spatialdata": + from ..export import SpatialDataWriter + + try: + writer = SpatialDataWriter( + include_boundaries=True, + boundary_method=boundary_method, + boundary_n_jobs=max(num_workers, 1), + ) + except ImportError: + logger.error("spatialdata is not installed. Install with: pip install segger[spatialdata]") + return + + out = writer.write( + predictions=seg_df, + output_dir=output_directory, + transcripts=tx, + boundaries=boundaries, + output_name="segmentation.zarr", + ) + logger.info(f"Wrote SpatialData: {out}") + return + + raise ValueError(f"Unsupported export format: {format}") diff --git a/src/segger/cli/main.py b/src/segger/cli/main.py index 9764208..c7e4ea1 100644 --- a/src/segger/cli/main.py +++ b/src/segger/cli/main.py @@ -1,5 +1,6 @@ from cyclopts import App from .segment import segment +from .export import export from .debug import debug # CLI App @@ -8,5 +9,8 @@ # Main segmentation app.command(segment) +# Export a segmentation to Xenium Explorer / scverse formats +app.command(export) + # Debugging utilities app.command(debug) diff --git a/src/segger/export/__init__.py b/src/segger/export/__init__.py new file mode 100644 index 0000000..4e0eb88 --- /dev/null +++ b/src/segger/export/__init__.py @@ -0,0 +1,126 @@ +"""Export segmentation results to downstream formats. + +- **Xenium Explorer** via 10x's ``xeniumranger import-segmentation`` workflow + (Baysor-style transcript assignment + viz polygons, or cell/nucleus GeoJSON). +- **Merged transcripts** (original transcripts joined with segger assignments). +- **AnnData** (cell x gene matrix). +- **SpatialData** Zarr (scverse / SOPA ecosystem; optional, requires ``spatialdata``). + +Heavy / optional modules are imported lazily so the base install stays light. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +__all__ = [ + # Boundaries + "BoundaryIdentification", + "generate_boundary", + "generate_boundaries", + "extract_largest_polygon", + # Xenium import-segmentation + "export_xenium_import", + "write_baysor_csv", + "write_viz_polygons", + "write_cell_geojson", + "build_import_command", + # Output-format registry + "OutputFormat", + "OutputWriter", + "get_writer", + "register_writer", + "write_all_formats", + # Writers + "MergedTranscriptsWriter", + "SeggerRawWriter", + "merge_predictions_with_transcripts", + "AnnDataWriter", + "build_anndata_table", + # SpatialData / SOPA (optional) + "SpatialDataWriter", + "write_spatialdata", + "validate_sopa_compatibility", + "export_for_sopa", + "sopa_to_segger_input", + "check_sopa_installation", +] + +if TYPE_CHECKING: # pragma: no cover + from .boundary import ( + BoundaryIdentification, + generate_boundary, + generate_boundaries, + extract_largest_polygon, + ) + from .xenium_import import ( + export_xenium_import, + write_baysor_csv, + write_viz_polygons, + write_cell_geojson, + build_import_command, + ) + from .output_formats import ( + OutputFormat, + OutputWriter, + get_writer, + register_writer, + write_all_formats, + ) + from .merged_writer import ( + MergedTranscriptsWriter, + SeggerRawWriter, + merge_predictions_with_transcripts, + ) + from .anndata_writer import AnnDataWriter, build_anndata_table + from .spatialdata_writer import SpatialDataWriter, write_spatialdata + from .sopa_compat import ( + validate_sopa_compatibility, + export_for_sopa, + sopa_to_segger_input, + check_sopa_installation, + ) + + +_LAZY = { + "BoundaryIdentification": "boundary", + "generate_boundary": "boundary", + "generate_boundaries": "boundary", + "extract_largest_polygon": "boundary", + "export_xenium_import": "xenium_import", + "write_baysor_csv": "xenium_import", + "write_viz_polygons": "xenium_import", + "write_cell_geojson": "xenium_import", + "build_import_command": "xenium_import", + "OutputFormat": "output_formats", + "OutputWriter": "output_formats", + "get_writer": "output_formats", + "register_writer": "output_formats", + "write_all_formats": "output_formats", + "MergedTranscriptsWriter": "merged_writer", + "SeggerRawWriter": "merged_writer", + "merge_predictions_with_transcripts": "merged_writer", + "AnnDataWriter": "anndata_writer", + "build_anndata_table": "anndata_writer", + # Optional (require ``spatialdata`` / ``sopa``) + "SpatialDataWriter": "spatialdata_writer", + "write_spatialdata": "spatialdata_writer", + "validate_sopa_compatibility": "sopa_compat", + "export_for_sopa": "sopa_compat", + "sopa_to_segger_input": "sopa_compat", + "check_sopa_installation": "sopa_compat", +} + + +def __getattr__(name: str): + module_name = _LAZY.get(name) + if module_name is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + import importlib + + module = importlib.import_module(f".{module_name}", __name__) + return getattr(module, name) + + +def __dir__(): + return sorted(__all__) diff --git a/src/segger/export/anndata_writer.py b/src/segger/export/anndata_writer.py new file mode 100644 index 0000000..32a0b1d --- /dev/null +++ b/src/segger/export/anndata_writer.py @@ -0,0 +1,362 @@ +"""Write segmentation results as AnnData (.h5ad). + +This writer builds a cell x gene count matrix from transcript assignments +and saves it as an AnnData object. The output can also be embedded as a +table in SpatialData. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Literal, Optional, Union + +import numpy as np +import pandas as pd +import polars as pl +from anndata import AnnData +from scipy import sparse as sp + +from segger.export.output_formats import OutputFormat, register_writer +from segger.export.merged_writer import merge_predictions_with_transcripts +from segger.utils.fragment_outputs import ( + FRAGMENT_FLAG_COLUMN, + OBJECT_GROUP_COLUMN, + OBJECT_TYPE_CELL, + OBJECT_TYPE_COLUMN, + OBJECT_TYPE_FRAGMENT, + split_h5ad_output_paths, + split_transcripts_by_object_type, + with_fragment_annotations, +) + + +def build_anndata_table( + transcripts: pl.DataFrame, + var_transcripts: Optional[pl.DataFrame] = None, + cell_id_column: str = "segger_cell_id", + feature_column: str = "feature_name", + x_column: Optional[str] = "x", + y_column: Optional[str] = "y", + z_column: Optional[str] = "z", + unassigned_value: Union[int, str, None] = -1, + region: Optional[str] = None, + region_key: Optional[str] = None, + obs_index_as_str: bool = False, + obs_metadata: Literal["full", "object_type", "none"] = "full", +) -> AnnData: + """Build AnnData from assigned transcripts. + + Parameters + ---------- + transcripts + Transcript DataFrame with segmentation assignments. + cell_id_column + Column with assigned cell IDs. + feature_column + Column with gene/feature names. + x_column, y_column, z_column + Coordinate columns (optional). If present, centroids are stored in + ``obsm["X_spatial"]``. + unassigned_value + Marker for unassigned transcripts (filtered out). + region, region_key + SpatialData table linkage metadata. + obs_index_as_str + If True, cast cell IDs to string for ``obs`` index. + obs_metadata + Controls assignment metadata columns written to ``obs``: + - ``"full"``: ``segger_object_type``, ``segger_object_group``, + ``segger_is_fragment`` + - ``"object_type"``: only ``segger_object_type`` + - ``"none"``: no assignment metadata columns + """ + if obs_metadata not in {"full", "object_type", "none"}: + raise ValueError( + f"Unsupported obs_metadata mode: {obs_metadata!r}. " + "Use one of: 'full', 'object_type', 'none'." + ) + + if cell_id_column not in transcripts.columns: + raise ValueError(f"Missing cell_id column: {cell_id_column}") + if feature_column not in transcripts.columns: + raise ValueError(f"Missing feature column: {feature_column}") + + transcripts = with_fragment_annotations( + transcripts, + cell_id_column=cell_id_column, + unassigned_value=unassigned_value, + ) + var_source = var_transcripts if var_transcripts is not None else transcripts + if feature_column not in var_source.columns: + raise ValueError(f"Missing feature column: {feature_column}") + + assigned = ( + transcripts + .filter(pl.col(cell_id_column).is_not_null()) + .filter(pl.col(OBJECT_TYPE_COLUMN) != "unassigned") + ) + # Filter by keep column when present (GMM threshold) + if "keep" in assigned.columns: + assigned = assigned.filter(pl.col("keep")) + + # Gene list from all transcripts (even if no assignments) + var_idx = ( + var_source + .select(feature_column) + .unique() + .sort(feature_column) + .get_column(feature_column) + .to_list() + ) + + if assigned.height == 0: + obs_index = pd.Index([], name=cell_id_column) + if obs_index_as_str: + var_index = pd.Index([str(v) for v in var_idx], name=feature_column) + else: + var_index = pd.Index(var_idx, name=feature_column) + X = sp.csr_matrix((0, len(var_index))) + adata = AnnData(X=X, obs=pd.DataFrame(index=obs_index), var=pd.DataFrame(index=var_index)) + if obs_metadata in {"full", "object_type"}: + adata.obs[OBJECT_TYPE_COLUMN] = pd.Series([], dtype="object") + if obs_metadata == "full": + adata.obs[OBJECT_GROUP_COLUMN] = pd.Series([], dtype="object") + adata.obs[FRAGMENT_FLAG_COLUMN] = pd.Series([], dtype=bool) + if region is not None: + adata.obs["region"] = region + if region_key is not None: + adata.obs["region_key"] = region_key + return adata + + feature_idx = ( + assigned + .select(feature_column) + .unique() + .sort(feature_column) + .with_row_index(name="_fid") + ) + cell_idx = ( + assigned + .select(cell_id_column) + .unique() + .sort(cell_id_column) + .with_row_index(name="_cid") + ) + + mapped = ( + assigned + .join(feature_idx, on=feature_column) + .join(cell_idx, on=cell_id_column) + ) + counts = ( + mapped + .group_by(["_cid", "_fid"]) + .agg(pl.len().alias("_count")) + ) + ijv = counts.select(["_cid", "_fid", "_count"]).to_numpy().T + rows = ijv[0].astype(np.int64, copy=False) + cols = ijv[1].astype(np.int64, copy=False) + data = ijv[2].astype(np.int64, copy=False) + + n_cells = cell_idx.height + n_genes = feature_idx.height + X = sp.coo_matrix((data, (rows, cols)), shape=(n_cells, n_genes)).tocsr() + + obs_ids = cell_idx.get_column(cell_id_column).to_list() + var_ids = feature_idx.get_column(feature_column).to_list() + if obs_index_as_str: + obs_ids = [str(v) for v in obs_ids] + var_ids = [str(v) for v in var_ids] + + adata = AnnData( + X=X, + obs=pd.DataFrame(index=pd.Index(obs_ids, name=cell_id_column)), + var=pd.DataFrame(index=pd.Index(var_ids, name=feature_column)), + ) + + if obs_metadata != "none": + agg_exprs = [ + pl.col(OBJECT_TYPE_COLUMN).first().alias(OBJECT_TYPE_COLUMN), + ] + if obs_metadata == "full": + agg_exprs.extend( + [ + pl.col(OBJECT_GROUP_COLUMN).first().alias(OBJECT_GROUP_COLUMN), + pl.col(FRAGMENT_FLAG_COLUMN).max().alias(FRAGMENT_FLAG_COLUMN), + ] + ) + obs_meta = ( + assigned + .group_by(cell_id_column) + .agg(agg_exprs) + .to_pandas() + .set_index(cell_id_column) + .reindex(adata.obs.index) + ) + adata.obs[OBJECT_TYPE_COLUMN] = obs_meta[OBJECT_TYPE_COLUMN].fillna( + OBJECT_TYPE_CELL + ) + if obs_metadata == "full": + adata.obs[OBJECT_GROUP_COLUMN] = obs_meta[OBJECT_GROUP_COLUMN].fillna("cells") + adata.obs[FRAGMENT_FLAG_COLUMN] = ( + obs_meta[FRAGMENT_FLAG_COLUMN].fillna(False).astype(bool) + ) + + # Add centroid coordinates if present + if x_column in assigned.columns and y_column in assigned.columns: + coords_cols = [x_column, y_column] + if z_column and z_column in assigned.columns: + coords_cols.append(z_column) + centroids = ( + assigned + .group_by(cell_id_column) + .agg([pl.col(c).mean().alias(c) for c in coords_cols]) + ) + centroids_pd = ( + centroids + .to_pandas() + .set_index(cell_id_column) + .reindex(adata.obs.index) + ) + adata.obsm["X_spatial"] = centroids_pd[coords_cols].to_numpy() + + if region is not None: + adata.obs["region"] = region + if region_key is not None: + adata.obs["region_key"] = region_key + + return adata + + +@register_writer(OutputFormat.ANNDATA) +class AnnDataWriter: + """Write segmentation results as AnnData (.h5ad).""" + + def __init__( + self, + unassigned_marker: Union[int, str, None] = -1, + compression: Optional[str] = "gzip", + compression_opts: Optional[int] = 4, + ): + self.unassigned_marker = unassigned_marker + self.compression = compression + self.compression_opts = compression_opts + + def write( + self, + predictions: pl.DataFrame, + output_dir: Path, + transcripts: Optional[pl.DataFrame] = None, + output_name: str = "segger_segmentation.h5ad", + row_index_column: str = "row_index", + cell_id_column: str = "segger_cell_id", + similarity_column: str = "segger_similarity", + feature_column: str = "feature_name", + x_column: Optional[str] = "x", + y_column: Optional[str] = "y", + z_column: Optional[str] = "z", + overwrite: bool = False, + **kwargs, + ) -> Path: + """Write segmentation results to AnnData (.h5ad). + + Parameters + ---------- + predictions + Segmentation predictions. + output_dir + Output directory. + transcripts + Original transcripts DataFrame (required). + output_name + Output filename. Default "segger_segmentation.h5ad". + """ + if transcripts is None: + raise ValueError("AnnData output requires transcripts DataFrame.") + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / output_name + output_paths = split_h5ad_output_paths(output_path) + + merged = merge_predictions_with_transcripts( + predictions=predictions, + transcripts=transcripts, + row_index_column=row_index_column, + cell_id_column=cell_id_column, + similarity_column=similarity_column, + unassigned_marker=self.unassigned_marker, + ) + split_frames = split_transcripts_by_object_type( + merged, + cell_id_column=cell_id_column, + unassigned_value=self.unassigned_marker, + ) + has_fragments = split_frames[OBJECT_TYPE_FRAGMENT].height > 0 + paths_to_write = [ + output_paths["combined"], + output_paths[OBJECT_TYPE_CELL], + ] + if has_fragments: + paths_to_write.append(output_paths[OBJECT_TYPE_FRAGMENT]) + + if not overwrite: + for path in paths_to_write: + if path.exists(): + raise FileExistsError( + f"Output path exists: {path}. " + "Use overwrite=True to replace." + ) + if (not has_fragments) and output_paths[OBJECT_TYPE_FRAGMENT].exists(): + raise FileExistsError( + f"Output path exists: {output_paths[OBJECT_TYPE_FRAGMENT]}. " + "Remove stale fragment output or use overwrite=True." + ) + elif (not has_fragments) and output_paths[OBJECT_TYPE_FRAGMENT].exists(): + output_paths[OBJECT_TYPE_FRAGMENT].unlink() + + adata = build_anndata_table( + transcripts=split_frames["all"], + var_transcripts=merged, + cell_id_column=cell_id_column, + feature_column=feature_column, + x_column=x_column, + y_column=y_column, + z_column=z_column, + unassigned_value=self.unassigned_marker, + obs_metadata="object_type", + ) + cells_adata = build_anndata_table( + transcripts=split_frames[OBJECT_TYPE_CELL], + var_transcripts=merged, + cell_id_column=cell_id_column, + feature_column=feature_column, + x_column=x_column, + y_column=y_column, + z_column=z_column, + unassigned_value=self.unassigned_marker, + obs_metadata="none", + ) + + write_kwargs = {} + if self.compression is not None: + write_kwargs["compression"] = self.compression + if self.compression_opts is not None: + write_kwargs["compression_opts"] = self.compression_opts + + adata.write_h5ad(output_paths["combined"], **write_kwargs) + cells_adata.write_h5ad(output_paths[OBJECT_TYPE_CELL], **write_kwargs) + if has_fragments: + fragments_adata = build_anndata_table( + transcripts=split_frames[OBJECT_TYPE_FRAGMENT], + var_transcripts=merged, + cell_id_column=cell_id_column, + feature_column=feature_column, + x_column=x_column, + y_column=y_column, + z_column=z_column, + unassigned_value=self.unassigned_marker, + obs_metadata="none", + ) + fragments_adata.write_h5ad(output_paths[OBJECT_TYPE_FRAGMENT], **write_kwargs) + return output_path diff --git a/src/segger/export/boundary.py b/src/segger/export/boundary.py new file mode 100644 index 0000000..2f5bcba --- /dev/null +++ b/src/segger/export/boundary.py @@ -0,0 +1,525 @@ +"""Delaunay triangulation-based cell boundary generation. + +This module provides sophisticated boundary extraction using Delaunay triangulation +with iterative edge refinement and cycle detection. This produces more accurate +cell boundaries than simple convex hulls. +""" + +from typing import Iterable, Tuple, Union +from concurrent.futures import ThreadPoolExecutor +import geopandas as gpd +import numpy as np +import pandas as pd +import polars as pl +import rtree.index +from scipy.spatial import Delaunay +from shapely.geometry import MultiPolygon, Polygon +from tqdm import tqdm + + +def vector_angle(v1: np.ndarray, v2: np.ndarray) -> float: + """Calculate angle between two vectors in degrees. + + Parameters + ---------- + v1 : np.ndarray + First vector. + v2 : np.ndarray + Second vector. + + Returns + ------- + float + Angle in degrees. + """ + dot_product = np.dot(v1, v2) + magnitude_v1 = np.linalg.norm(v1) + magnitude_v2 = np.linalg.norm(v2) + cos_angle = np.clip(dot_product / (magnitude_v1 * magnitude_v2 + 1e-8), -1.0, 1.0) + return np.degrees(np.arccos(cos_angle)) + + +def triangle_angles_from_points( + points: np.ndarray, + triangles: np.ndarray, +) -> np.ndarray: + """Calculate angles for all triangles in a Delaunay triangulation. + + Parameters + ---------- + points : np.ndarray + Point coordinates, shape (N, 2). + triangles : np.ndarray + Triangle vertex indices, shape (M, 3). + + Returns + ------- + np.ndarray + Angles for each triangle vertex, shape (M, 3). + """ + # Vectorized angle computation for all triangles + p1 = points[triangles[:, 0]] + p2 = points[triangles[:, 1]] + p3 = points[triangles[:, 2]] + + v1 = p2 - p1 + v2 = p3 - p1 + v3 = p3 - p2 + + def _angles(u: np.ndarray, v: np.ndarray) -> np.ndarray: + dot = (u * v).sum(axis=1) + denom = (np.linalg.norm(u, axis=1) * np.linalg.norm(v, axis=1)) + 1e-8 + cos = np.clip(dot / denom, -1.0, 1.0) + return np.degrees(np.arccos(cos)) + + a = _angles(v1, v2) + b = _angles(-v1, v3) + c = _angles(-v2, -v3) + return np.stack([a, b, c], axis=1) + + +def dfs(v: int, graph: dict, path: list, colors: dict) -> None: + """Depth-first search for cycle detection. + + Parameters + ---------- + v : int + Current vertex. + graph : dict + Adjacency list representation of graph. + path : list + Current path being built. + colors : dict + Vertex visit status (0=unvisited, 1=visited). + """ + colors[v] = 1 + path.append(v) + for d in graph[v]: + if colors[d] == 0: + dfs(d, graph, path, colors) + + +class BoundaryIdentification: + """Delaunay triangulation-based polygon boundary extraction. + + This class implements a two-phase iterative algorithm for extracting + cell boundaries from transcript point clouds: + + 1. Phase 1: Remove long boundary edges (> 2 * d_max) + 2. Phase 2: Remove boundary edges with extreme angles + + Parameters + ---------- + data : np.ndarray + 2D point coordinates, shape (N, 2). + """ + + def __init__(self, data: np.ndarray): + self.graph = None + self.edges = {} + self.d = Delaunay(data) + self.d_max = self.calculate_d_max(self.d.points) + self.generate_edges() + + def generate_edges(self) -> None: + """Generate edge dictionary from Delaunay triangulation.""" + d = self.d + edges = {} + angles = triangle_angles_from_points(d.points, d.simplices) + + for index, simplex in enumerate(d.simplices): + for p in range(3): + edge = tuple(sorted((simplex[p], simplex[(p + 1) % 3]))) + if edge not in edges: + edges[edge] = {"simplices": {}} + edges[edge]["simplices"][index] = angles[index][(p + 2) % 3] + + edges_coordinates = d.points[np.array(list(edges.keys()))] + edges_length = np.sqrt( + (edges_coordinates[:, 1, 0] - edges_coordinates[:, 0, 0]) ** 2 + + (edges_coordinates[:, 1, 1] - edges_coordinates[:, 0, 1]) ** 2 + ) + + for edge, coords, length in zip(edges, edges_coordinates, edges_length): + edges[edge]["coords"] = coords + edges[edge]["length"] = length + + self.edges = edges + + def calculate_part_1(self, plot: bool = False) -> None: + """Phase 1: Remove long boundary edges iteratively. + + Removes edges longer than 2 * d_max from the boundary. + + Parameters + ---------- + plot : bool + Whether to generate visualization (not implemented). + """ + edges = self.edges + d = self.d + d_max = self.d_max + + boundary_edges = [edge for edge in edges if len(edges[edge]["simplices"]) < 2] + + flag = True + while flag: + flag = False + next_boundary_edges = [] + + for current_edge in boundary_edges: + if current_edge not in edges: + continue + + if edges[current_edge]["length"] > 2 * d_max: + if len(edges[current_edge]["simplices"].keys()) == 0: + del edges[current_edge] + continue + + simplex_id = list(edges[current_edge]["simplices"].keys())[0] + simplex = d.simplices[simplex_id] + + for edge in self.get_edges_from_simplex(simplex): + if edge != current_edge: + edges[edge]["simplices"].pop(simplex_id) + next_boundary_edges.append(edge) + + del edges[current_edge] + flag = True + else: + next_boundary_edges.append(current_edge) + + boundary_edges = next_boundary_edges + + def calculate_part_2(self, plot: bool = False) -> None: + """Phase 2: Remove boundary edges with extreme angles. + + Removes edges where the opposite angle is too large, indicating + a concave region that should be excluded. + + Parameters + ---------- + plot : bool + Whether to generate visualization (not implemented). + """ + edges = self.edges + d = self.d + d_max = self.d_max + + boundary_edges = [edge for edge in edges if len(edges[edge]["simplices"]) < 2] + boundary_edges_length = len(boundary_edges) + next_boundary_edges = [] + + while len(next_boundary_edges) != boundary_edges_length: + next_boundary_edges = [] + + for current_edge in boundary_edges: + if current_edge not in edges: + continue + + if len(edges[current_edge]["simplices"].keys()) == 0: + del edges[current_edge] + continue + + simplex_id = list(edges[current_edge]["simplices"].keys())[0] + simplex = d.simplices[simplex_id] + + # Remove if edge is long with large angle, or if angle is very obtuse + if ( + edges[current_edge]["length"] > 1.5 * d_max + and edges[current_edge]["simplices"][simplex_id] > 90 + ) or edges[current_edge]["simplices"][simplex_id] > 180 - 180 / 16: + + for edge in self.get_edges_from_simplex(simplex): + if edge != current_edge: + edges[edge]["simplices"].pop(simplex_id) + next_boundary_edges.append(edge) + + del edges[current_edge] + else: + next_boundary_edges.append(current_edge) + + boundary_edges_length = len(boundary_edges) + boundary_edges = next_boundary_edges + + def find_cycles(self) -> Union[Polygon, MultiPolygon, None]: + """Find boundary cycles and convert to Shapely geometry. + + Returns + ------- + Union[Polygon, MultiPolygon, None] + Polygon if single cycle, MultiPolygon if multiple, None on error. + """ + e = self.edges + boundary_edges = [edge for edge in e if len(e[edge]["simplices"]) < 2] + self.graph = self.generate_graph(boundary_edges) + cycles = self.get_cycles(self.graph) + + try: + if len(cycles) == 1: + geom = Polygon(self.d.points[cycles[0]]) + else: + geom = MultiPolygon( + [Polygon(self.d.points[c]) for c in cycles if len(c) >= 3] + ) + except Exception: + return None + + return geom + + @staticmethod + def calculate_d_max(points: np.ndarray) -> float: + """Calculate maximum nearest-neighbor distance. + + Parameters + ---------- + points : np.ndarray + Point coordinates, shape (N, 2). + + Returns + ------- + float + Maximum nearest-neighbor distance. + """ + index = rtree.index.Index() + for i, p in enumerate(points): + index.insert(i, p[[0, 1, 0, 1]]) + + short_edges = [] + for i, p in enumerate(points): + res = list(index.nearest(p[[0, 1, 0, 1]], 2))[-1] + short_edges.append([i, res]) + + nearest_points = points[short_edges] + nearest_dists = np.sqrt( + (nearest_points[:, 0, 0] - nearest_points[:, 1, 0]) ** 2 + + (nearest_points[:, 0, 1] - nearest_points[:, 1, 1]) ** 2 + ) + return nearest_dists.max() + + @staticmethod + def get_edges_from_simplex(simplex: np.ndarray) -> list: + """Extract edge tuples from a triangle simplex. + + Parameters + ---------- + simplex : np.ndarray + Triangle vertex indices, shape (3,). + + Returns + ------- + list + List of edge tuples. + """ + edges = [] + for p in range(3): + edges.append(tuple(sorted((simplex[p], simplex[(p + 1) % 3])))) + return edges + + @staticmethod + def generate_graph(edges: list) -> dict: + """Generate adjacency list from edge list. + + Parameters + ---------- + edges : list + List of edge tuples. + + Returns + ------- + dict + Adjacency list representation. + """ + vertices = set() + for edge in edges: + vertices.add(edge[0]) + vertices.add(edge[1]) + + vertices = sorted(list(vertices)) + graph = {v: [] for v in vertices} + + for e in edges: + graph[e[0]].append(e[1]) + graph[e[1]].append(e[0]) + + return graph + + @staticmethod + def get_cycles(graph: dict) -> list: + """Find all connected components (cycles) in boundary graph. + + Parameters + ---------- + graph : dict + Adjacency list representation. + + Returns + ------- + list + List of cycles (each cycle is a list of vertex indices). + """ + colors = {v: 0 for v in graph} + cycles = [] + + for v in graph.keys(): + if colors[v] == 0: + cycle = [] + dfs(v, graph, cycle, colors) + cycles.append(cycle) + + return cycles + + +def generate_boundary( + df: Union[pd.DataFrame, pl.DataFrame], + x: str = "x", + y: str = "y", +) -> Union[Polygon, MultiPolygon, None]: + """Generate boundary polygon for a single cell's transcripts. + + Uses Delaunay triangulation with iterative edge refinement to produce + more accurate boundaries than simple convex hulls. + + Parameters + ---------- + df : Union[pd.DataFrame, pl.DataFrame] + Transcript data with x, y coordinates. + x : str + Column name for x coordinate. + y : str + Column name for y coordinate. + + Returns + ------- + Union[Polygon, MultiPolygon, None] + Cell boundary geometry, or None if insufficient points. + """ + # Convert Polars to pandas if needed + if isinstance(df, pl.DataFrame): + df = df.to_pandas() + + if len(df) < 3: + return None + + bi = BoundaryIdentification(df[[x, y]].values) + bi.calculate_part_1(plot=False) + bi.calculate_part_2(plot=False) + return bi.find_cycles() + + +def generate_boundaries( + df: Union[pd.DataFrame, pl.DataFrame], + x: str = "x", + y: str = "y", + cell_id: str = "seg_cell_id", + n_jobs: int = 1, + chunksize: int = 8, + progress: bool = True, +) -> gpd.GeoDataFrame: + """Generate boundaries for all cells in a segmentation result. + + Parameters + ---------- + df : Union[pd.DataFrame, pl.DataFrame] + Transcript data with cell assignments. + x : str + Column name for x coordinate. + y : str + Column name for y coordinate. + cell_id : str + Column name for cell ID. + + Returns + ------- + gpd.GeoDataFrame + GeoDataFrame with cell_id, length, and geometry columns. + """ + def iter_groups() -> Tuple[Iterable[Tuple[object, np.ndarray]], int]: + if isinstance(df, pl.DataFrame): + grouped = df.group_by(cell_id).agg( + [ + pl.col(x).list().alias("_x"), + pl.col(y).list().alias("_y"), + ] + ) + total = grouped.height + + def _gen(): + for cid, xs, ys in grouped.iter_rows(): + yield cid, np.column_stack((xs, ys)) + + return _gen(), total + + group_df = df.groupby(cell_id) + total = group_df.ngroups + + def _gen(): + for cid, t in group_df: + yield cid, t[[x, y]].to_numpy() + + return _gen(), total + + def _compute_one(item: Tuple[object, np.ndarray]) -> Tuple[object, int, Union[Polygon, MultiPolygon, None]]: + cid, points = item + n_unique_points = np.unique(points, axis=0).shape[0] + if n_unique_points < 3: + return cid, n_unique_points, None + try: + bi = BoundaryIdentification(points) + bi.calculate_part_1(plot=False) + bi.calculate_part_2(plot=False) + geom = bi.find_cycles() + except Exception: + geom = None + return cid, n_unique_points, geom + + group_iter, total = iter_groups() + res = [] + + if n_jobs and n_jobs > 1: + with ThreadPoolExecutor(max_workers=n_jobs) as ex: + iterator = ex.map(_compute_one, group_iter, chunksize=chunksize) + if progress: + iterator = tqdm(iterator, total=total, desc="Generating boundaries") + for cid, length, geom in iterator: + res.append({"cell_id": cid, "length": length, "geom": geom}) + else: + iterator = group_iter + if progress: + iterator = tqdm(iterator, total=total, desc="Generating boundaries") + for item in iterator: + cid, length, geom = _compute_one(item) + res.append({"cell_id": cid, "length": length, "geom": geom}) + + return gpd.GeoDataFrame( + data=[[b["cell_id"], b["length"]] for b in res], + geometry=[b["geom"] for b in res], + columns=["cell_id", "length"], + ) + + +def extract_largest_polygon( + geom: Union[Polygon, MultiPolygon, None], +) -> Union[Polygon, None]: + """Extract the largest polygon from a geometry. + + Parameters + ---------- + geom : Union[Polygon, MultiPolygon, None] + Input geometry. + + Returns + ------- + Union[Polygon, None] + Largest polygon, or None if input is None. + """ + if geom is None: + return None + if getattr(geom, "is_empty", False): + return None + if isinstance(geom, MultiPolygon): + candidates = [p for p in geom.geoms if p is not None and not p.is_empty] + if not candidates: + return None + return max(candidates, key=lambda p: p.area) + return geom diff --git a/src/segger/export/merged_writer.py b/src/segger/export/merged_writer.py new file mode 100644 index 0000000..8c0c2cb --- /dev/null +++ b/src/segger/export/merged_writer.py @@ -0,0 +1,311 @@ +"""Write segmentation results merged back to original transcripts. + +This writer joins segmentation predictions with the original transcript data, +producing a single output file that contains all original columns plus +the segmentation results (segger_cell_id, segger_similarity). + +Usage +----- +>>> from segger.export.merged_writer import MergedTranscriptsWriter +>>> writer = MergedTranscriptsWriter( +... original_transcripts_path=Path("data/transcripts.parquet") +... ) +>>> output_path = writer.write(predictions, Path("output/")) + +The output file contains: +- All original transcript columns +- segger_cell_id: Assigned cell ID (-1 for unassigned) +- segger_similarity: Assignment confidence score (0.0 for unassigned) +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Literal, Optional, Union + +import polars as pl + +from segger.export.output_formats import OutputFormat, register_writer +from segger.utils.fragment_outputs import ( + FRAGMENT_FLAG_COLUMN, + OBJECT_GROUP_COLUMN, + OBJECT_TYPE_COLUMN, + with_fragment_annotations, +) + +if TYPE_CHECKING: + pass + + +@register_writer(OutputFormat.SEGGER_RAW) +class SeggerRawWriter: + """Write raw Segger prediction output (default format). + + This writer outputs just the predictions DataFrame without merging + with original transcripts. This is the default Segger output format. + + Output columns: + - row_index: Original transcript row index + - segger_cell_id: Assigned cell ID + - segger_similarity: Assignment confidence score + """ + + def __init__( + self, + compression: Literal["snappy", "gzip", "lz4", "zstd", "none"] = "snappy", + ): + """Initialize the raw writer. + + Parameters + ---------- + compression + Parquet compression algorithm. Default is 'snappy'. + """ + self.compression = compression if compression != "none" else None + + def write( + self, + predictions: pl.DataFrame, + output_dir: Path, + output_name: str = "predictions.parquet", + **kwargs, + ) -> Path: + """Write predictions to Parquet file. + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. + output_dir + Output directory. + output_name + Output filename. Default is 'predictions.parquet'. + + Returns + ------- + Path + Path to the written Parquet file. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + output_path = output_dir / output_name + predictions.write_parquet(output_path, compression=self.compression) + + return output_path + + +@register_writer(OutputFormat.MERGED_TRANSCRIPTS) +class MergedTranscriptsWriter: + """Write segmentation results merged with original transcripts. + + This writer joins predictions with original transcript data, producing + a complete output file with all original columns plus segmentation results. + + Output columns: + - All original transcript columns + - segger_cell_id: Assigned cell ID (configurable marker for unassigned) + - segger_similarity: Assignment confidence score + + Parameters + ---------- + original_transcripts_path + Path to the original transcripts file (Parquet or CSV). + If not provided, must be passed to write() via kwargs. + unassigned_marker + Value to use for unassigned transcripts. Default is -1. + Can be int, str, or None. + include_similarity + Whether to include the similarity score column. Default True. + compression + Parquet compression algorithm. Default is 'snappy'. + """ + + def __init__( + self, + original_transcripts_path: Optional[Path] = None, + unassigned_marker: Union[int, str, None] = -1, + include_similarity: bool = True, + compression: Literal["snappy", "gzip", "lz4", "zstd", "none"] = "snappy", + ): + self.original_transcripts_path = ( + Path(original_transcripts_path) if original_transcripts_path else None + ) + self.unassigned_marker = unassigned_marker + self.include_similarity = include_similarity + self.compression = compression if compression != "none" else None + + def write( + self, + predictions: pl.DataFrame, + output_dir: Path, + output_name: str = "transcripts_segmented.parquet", + transcripts: Optional[pl.DataFrame] = None, + original_transcripts_path: Optional[Path] = None, + row_index_column: str = "row_index", + cell_id_column: str = "segger_cell_id", + similarity_column: str = "segger_similarity", + **kwargs, + ) -> Path: + """Merge predictions with original transcripts and write to file. + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. Must contain: + - row_index: Original transcript row index + - segger_cell_id: Assigned cell ID + - segger_similarity: Assignment confidence score (optional) + output_dir + Output directory. + output_name + Output filename. Default is 'transcripts_segmented.parquet'. + transcripts + Original transcripts DataFrame. If provided, used instead of + loading from original_transcripts_path. + original_transcripts_path + Path to original transcripts. Overrides constructor parameter. + row_index_column + Column name for row index in predictions. Default 'row_index'. + cell_id_column + Column name for cell ID in predictions. Default 'segger_cell_id'. + similarity_column + Column name for similarity in predictions. Default 'segger_similarity'. + + Returns + ------- + Path + Path to the written Parquet file. + + Raises + ------ + ValueError + If no transcripts source is provided. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Get original transcripts + if transcripts is not None: + original = transcripts + else: + path = original_transcripts_path or self.original_transcripts_path + if path is None: + raise ValueError( + "No original transcripts provided. Either pass 'transcripts' " + "DataFrame or specify 'original_transcripts_path'." + ) + original = self._load_transcripts(path) + + merged = merge_predictions_with_transcripts( + predictions=predictions, + transcripts=original, + row_index_column=row_index_column, + cell_id_column=cell_id_column, + similarity_column=similarity_column if self.include_similarity else "", + unassigned_marker=self.unassigned_marker, + ) + + # Write output + output_path = output_dir / output_name + merged.write_parquet(output_path, compression=self.compression) + + return output_path + + def _load_transcripts(self, path: Path) -> pl.DataFrame: + """Load transcripts from file. + + Parameters + ---------- + path + Path to transcripts file (Parquet or CSV). + + Returns + ------- + pl.DataFrame + Loaded transcripts. + """ + path = Path(path) + suffix = path.suffix.lower() + + if suffix == ".parquet": + return pl.read_parquet(path) + elif suffix in (".csv", ".tsv"): + separator = "\t" if suffix == ".tsv" else "," + return pl.read_csv(path, separator=separator) + else: + # Try Parquet first, then CSV + try: + return pl.read_parquet(path) + except Exception: + return pl.read_csv(path) + + +def merge_predictions_with_transcripts( + predictions: pl.DataFrame, + transcripts: pl.DataFrame, + row_index_column: str = "row_index", + cell_id_column: str = "segger_cell_id", + similarity_column: str = "segger_similarity", + unassigned_marker: Union[int, str, None] = -1, +) -> pl.DataFrame: + """Merge predictions with transcripts (functional interface). + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. + transcripts + Original transcripts DataFrame. + row_index_column + Column name for row index. + cell_id_column + Column name for cell ID in predictions. + similarity_column + Column name for similarity in predictions. + unassigned_marker + Value for unassigned transcripts. + + Returns + ------- + pl.DataFrame + Merged DataFrame with all original columns plus predictions. + + Examples + -------- + >>> merged = merge_predictions_with_transcripts(predictions, transcripts) + >>> print(merged.columns) + ['row_index', 'x', 'y', 'feature_name', 'segger_cell_id', 'segger_similarity'] + """ + # Prepare predictions + pred_cols = [row_index_column, cell_id_column] + if similarity_column and similarity_column in predictions.columns: + pred_cols.append(similarity_column) + for column in ("keep", "similarity_threshold", FRAGMENT_FLAG_COLUMN, OBJECT_TYPE_COLUMN, OBJECT_GROUP_COLUMN): + if column in predictions.columns: + pred_cols.append(column) + + pred_subset = predictions.select(pred_cols) + + # Add row_index if missing + if row_index_column not in transcripts.columns: + transcripts = transcripts.with_row_index(name=row_index_column) + + # Join + merged = transcripts.join(pred_subset, on=row_index_column, how="left") + + # Fill unassigned + if unassigned_marker is not None: + merged = merged.with_columns( + pl.col(cell_id_column).fill_null(unassigned_marker) + ) + if similarity_column and similarity_column in merged.columns: + merged = merged.with_columns( + pl.col(similarity_column).fill_null(0.0) + ) + + return with_fragment_annotations( + merged, + cell_id_column=cell_id_column, + unassigned_value=unassigned_marker, + ) diff --git a/src/segger/export/output_formats.py b/src/segger/export/output_formats.py new file mode 100644 index 0000000..d08a990 --- /dev/null +++ b/src/segger/export/output_formats.py @@ -0,0 +1,309 @@ +"""Output format definitions and writer registry for segmentation results. + +This module provides: +- OutputFormat enum for available output formats +- OutputWriter protocol for implementing format-specific writers +- Factory function to get the appropriate writer for a format + +Available formats: +- SEGGER_RAW: Default Segger output (predictions parquet) +- MERGED_TRANSCRIPTS: Original transcripts merged with assignments +- SPATIALDATA: SpatialData Zarr format for scverse ecosystem +- ANNDATA: AnnData (.h5ad) cell x gene matrix + +Usage +----- +>>> from segger.export.output_formats import OutputFormat, get_writer +>>> writer = get_writer(OutputFormat.MERGED_TRANSCRIPTS) +>>> writer.write(predictions, transcripts, output_dir) +""" + +from __future__ import annotations + +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + import geopandas as gpd + import polars as pl + + +class OutputFormat(str, Enum): + """Available output formats for segmentation results. + + Attributes + ---------- + SEGGER_RAW : str + Default Segger output format. Writes predictions as Parquet file + with columns: row_index, segger_cell_id, segger_similarity. + + MERGED_TRANSCRIPTS : str + Merged transcripts format. Original transcript data with segmentation + results joined (segger_cell_id, segger_similarity columns added). + + SPATIALDATA : str + SpatialData Zarr format. Creates a .zarr store compatible with + the scverse ecosystem, containing transcripts and optional boundaries. + + ANNDATA : str + AnnData format. Creates a .h5ad file with a cell x gene matrix + derived from transcript assignments. + """ + + SEGGER_RAW = "segger_raw" + MERGED_TRANSCRIPTS = "merged" + SPATIALDATA = "spatialdata" + ANNDATA = "anndata" + + @classmethod + def from_string(cls, value: str) -> "OutputFormat": + """Parse OutputFormat from string, case-insensitive. + + Parameters + ---------- + value + Format name ('segger_raw', 'merged', 'spatialdata', 'anndata', or 'all'). + + Returns + ------- + OutputFormat + Corresponding enum value. + + Raises + ------ + ValueError + If value is not a valid format name. + """ + value_lower = value.lower().strip() + + # Handle aliases + aliases = { + "raw": cls.SEGGER_RAW, + "segger": cls.SEGGER_RAW, + "default": cls.SEGGER_RAW, + "merge": cls.MERGED_TRANSCRIPTS, + "merged": cls.MERGED_TRANSCRIPTS, + "transcripts": cls.MERGED_TRANSCRIPTS, + "sdata": cls.SPATIALDATA, + "zarr": cls.SPATIALDATA, + "h5ad": cls.ANNDATA, + "ann": cls.ANNDATA, + "anndata": cls.ANNDATA, + } + + if value_lower in aliases: + return aliases[value_lower] + + # Try direct match + for fmt in cls: + if fmt.value == value_lower: + return fmt + + valid = [f.value for f in cls] + list(aliases.keys()) + raise ValueError( + f"Unknown output format: '{value}'. " + f"Valid formats: {sorted(set(valid))}" + ) + + +@runtime_checkable +class OutputWriter(Protocol): + """Protocol for output format writers. + + Implementations must provide a `write` method that writes segmentation + results to the specified output directory. + """ + + def write( + self, + predictions: "pl.DataFrame", + output_dir: Path, + **kwargs: Any, + ) -> Path: + """Write segmentation results to output format. + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. Must contain: + - row_index: Original transcript row index + - segger_cell_id: Assigned cell ID (or -1/None for unassigned) + - segger_similarity: Assignment confidence score + + output_dir + Directory to write output files. + + **kwargs + Format-specific options (e.g., transcripts, boundaries). + + Returns + ------- + Path + Path to the primary output file/directory. + """ + ... + + +# Registry of output writers by format +_OUTPUT_WRITERS: dict[OutputFormat, type] = {} + + +def register_writer(fmt: OutputFormat): + """Decorator to register an output writer class. + + Parameters + ---------- + fmt + Output format this writer handles. + + Returns + ------- + decorator + Class decorator that registers the writer. + + Examples + -------- + >>> @register_writer(OutputFormat.MERGED_TRANSCRIPTS) + ... class MergedTranscriptsWriter: + ... def write(self, predictions, output_dir, **kwargs): + ... ... + """ + def decorator(cls): + _OUTPUT_WRITERS[fmt] = cls + return cls + return decorator + + +def get_writer(fmt: OutputFormat | str, **init_kwargs: Any) -> OutputWriter: + """Get an output writer for the specified format. + + Parameters + ---------- + fmt + Output format (enum or string). + **init_kwargs + Keyword arguments passed to the writer constructor. + + Returns + ------- + OutputWriter + Writer instance for the specified format. + + Raises + ------ + ValueError + If format is not recognized or writer not registered. + + Examples + -------- + >>> writer = get_writer(OutputFormat.MERGED_TRANSCRIPTS, unassigned_marker=-1) + >>> writer.write(predictions, Path("output/")) + """ + if isinstance(fmt, str): + fmt = OutputFormat.from_string(fmt) + + if fmt not in _OUTPUT_WRITERS: + raise ValueError( + f"No writer registered for format: {fmt.value}. " + f"Available formats: {[f.value for f in _OUTPUT_WRITERS.keys()]}" + ) + + writer_cls = _OUTPUT_WRITERS[fmt] + return writer_cls(**init_kwargs) + + +def get_all_writers(**init_kwargs: Any) -> dict[OutputFormat, OutputWriter]: + """Get writers for all registered formats. + + Parameters + ---------- + **init_kwargs + Keyword arguments passed to each writer constructor. + + Returns + ------- + dict[OutputFormat, OutputWriter] + Dictionary mapping formats to writer instances. + """ + return {fmt: get_writer(fmt, **init_kwargs) for fmt in _OUTPUT_WRITERS} + + +def write_all_formats( + predictions: "pl.DataFrame", + output_dir: Path, + **kwargs: Any, +) -> dict[OutputFormat, Path]: + """Write segmentation results in all available formats. + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. + output_dir + Base output directory. Subdirectories may be created for each format. + **kwargs + Additional arguments passed to each writer (transcripts, boundaries, etc.). + + Returns + ------- + dict[OutputFormat, Path] + Dictionary mapping formats to output paths. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + results = {} + for fmt, writer in get_all_writers().items(): + try: + path = writer.write(predictions, output_dir, **kwargs) + results[fmt] = path + except Exception as e: + # Log error but continue with other formats + import warnings + warnings.warn( + f"Failed to write {fmt.value} format: {e}", + UserWarning, + stacklevel=2, + ) + + return results + + +# Import writers to register them (done at end to avoid circular imports) +def _register_builtin_writers(): + """Register built-in output writers. + + Called lazily to avoid import errors if optional dependencies are missing. + """ + # Import here to register writers via decorators + from segger.export import merged_writer # noqa: F401 + from segger.export import anndata_writer # noqa: F401 + + # SpatialData writer is optional + try: + from segger.export import spatialdata_writer # noqa: F401 + except ImportError: + pass + + +# Lazy registration on first use +_writers_registered = False + + +def _ensure_writers_registered(): + """Ensure built-in writers are registered.""" + global _writers_registered + if not _writers_registered: + _register_builtin_writers() + _writers_registered = True + + +# Override get_writer to ensure registration +_original_get_writer = get_writer + + +def get_writer(fmt: OutputFormat | str, **init_kwargs: Any) -> OutputWriter: + """Get an output writer for the specified format.""" + _ensure_writers_registered() + return _original_get_writer(fmt, **init_kwargs) diff --git a/src/segger/export/xenium_import.py b/src/segger/export/xenium_import.py new file mode 100644 index 0000000..a86dcef --- /dev/null +++ b/src/segger/export/xenium_import.py @@ -0,0 +1,408 @@ +"""Export segger results for 10x's ``xeniumranger import-segmentation`` workflow. + +segger is a transcript-assignment segmentation method, so the natural hand-off to +Xenium Explorer is 10x Genomics' ``import-segmentation`` pipeline, which re-quantifies +transcripts against an imported segmentation and regenerates a bundle that Xenium +Explorer can open. This module turns a ``segger_segmentation.parquet`` (per-transcript +``row_index``/``segger_cell_id``) into ``import-segmentation``-ready inputs: + +- **Transcript assignment (Baysor-style):** ``segmentation.csv`` with the required + ``transcript_id``, ``cell``, ``is_noise`` columns, plus ``segmentation_polygons.json`` + (cell polygons for visualization). Imported with + ``--transcript-assignment``/``--viz-polygons``. +- **Cell/nucleus polygons:** ``polygon.geojson`` (``objectType="cell"`` features), + imported with ``--cells``/``--nuclei``. + +The segmentation parquet carries no coordinates, so transcript ``transcript_id``, micron +coordinates and gene are recovered from the source Xenium bundle's ``transcripts.parquet`` +by joining on ``row_index`` (segger assigns ``row_index`` over the raw, unfiltered +transcripts, so the positional join is exact). + +References +---------- +https://www.10xgenomics.com/support/software/xenium-ranger/latest/analysis/running-pipelines/XR-import-segmentation +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Literal, Optional, Union + +import polars as pl + +logger = logging.getLogger(__name__) + +XeniumMode = Literal["transcript_assignment", "geojson", "both"] +Units = Literal["microns", "pixels"] + +#: Baysor convention: cell ``0`` denotes noise / unassigned transcripts. +NOISE_CELL = 0 + +# Default column names in a raw Xenium ``transcripts.parquet`` (see XeniumTranscriptFields). +TRANSCRIPT_ID_COLUMN = "transcript_id" +X_COLUMN = "x_location" +Y_COLUMN = "y_location" +Z_COLUMN = "z_location" +FEATURE_COLUMN = "feature_name" + + +def _load_source_transcripts( + source_path: Path, + *, + transcript_id_column: str, + x_column: str, + y_column: str, + z_column: str, + feature_column: str, +) -> pl.DataFrame: + """Read the source Xenium ``transcripts.parquet`` with a positional ``row_index``. + + Parameters + ---------- + source_path : Path + Path to the raw Xenium bundle (directory containing ``transcripts.parquet``). + transcript_id_column, x_column, y_column, z_column, feature_column : str + Column names in the raw transcripts file. + + Returns + ------- + pl.DataFrame + ``row_index`` plus whichever of the requested columns are present. + """ + source_path = Path(source_path) + tx_path = source_path if source_path.suffix == ".parquet" else source_path / "transcripts.parquet" + if not tx_path.exists(): + raise FileNotFoundError( + f"Could not find a Xenium 'transcripts.parquet' at {tx_path}. The Xenium " + "import path needs the raw bundle to recover transcript_id and coordinates." + ) + raw = pl.read_parquet(tx_path).with_row_index(name="row_index") + + wanted = [transcript_id_column, x_column, y_column, z_column, feature_column] + present = [c for c in wanted if c in raw.columns] + if transcript_id_column not in raw.columns: + logger.warning( + "Column '%s' not found in %s; falling back to row_index as transcript_id.", + transcript_id_column, + tx_path.name, + ) + return raw.select(["row_index", *present]) + + +def _build_assignment( + seg_df: pl.DataFrame, + raw: pl.DataFrame, + *, + cell_id_column: str, +) -> pl.DataFrame: + """Join predictions onto the source transcripts and derive ``cell``/``is_noise``. + + A transcript is *assigned* when it has a non-negative ``segger_cell_id`` and passes + the ``keep`` test (recomputed upstream from the similarity threshold). Everything else + -- including transcripts filtered out before segmentation -- is marked ``is_noise``. + """ + keep_present = "keep" in seg_df.columns + select_cols = ["row_index", cell_id_column] + (["keep"] if keep_present else []) + seg = seg_df.select([c for c in select_cols if c in seg_df.columns]) + + merged = raw.join(seg, on="row_index", how="left") + assigned = pl.col(cell_id_column).is_not_null() & (pl.col(cell_id_column) >= 0) + if keep_present: + assigned = assigned & pl.col("keep").fill_null(False) + + return merged.with_columns( + pl.when(assigned).then(pl.col(cell_id_column).cast(pl.Int64)).otherwise(NOISE_CELL).alias("cell"), + (~assigned).alias("is_noise"), + ) + + +def write_baysor_csv( + assignment: pl.DataFrame, + output_dir: Path, + *, + transcript_id_column: str = TRANSCRIPT_ID_COLUMN, + x_column: str, + y_column: str, + z_column: Optional[str] = None, + feature_column: Optional[str] = None, + filename: str = "segmentation.csv", +) -> Path: + """Write the Baysor-style transcript-assignment CSV (``transcript_id``, ``cell``, ``is_noise``).""" + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + tid = transcript_id_column if transcript_id_column in assignment.columns else "row_index" + exprs = [ + pl.col(tid).alias("transcript_id"), + pl.col("cell"), + pl.col("is_noise"), + ] + for src, name in ((x_column, "x"), (y_column, "y"), (z_column, "z"), (feature_column, "gene")): + if src and src in assignment.columns: + exprs.append(pl.col(src).alias(name)) + + path = output_dir / filename + assignment.select(exprs).write_csv(path) + logger.info("Wrote transcript assignment: %s", path) + return path + + +def _polygon_feature(geom, cell_id: int, *, object_type: Optional[str] = None) -> Optional[dict]: + """Build a GeoJSON Polygon feature, or ``None`` if the geometry is degenerate.""" + from .boundary import extract_largest_polygon + + poly = extract_largest_polygon(geom) + if poly is None or poly.is_empty: + return None + coords = [[float(x), float(y)] for x, y in poly.exterior.coords] + # A valid closed ring needs >= 4 coordinates (>= 3 distinct vertices + closure); + # polygons with < 3 vertices crash Xenium Explorer v3.0. + if len(coords) < 4: + return None + feature = { + "type": "Feature", + "id": int(cell_id), + "geometry": {"type": "Polygon", "coordinates": [coords]}, + "properties": {"cell": int(cell_id)}, + } + if object_type is not None: + feature["properties"]["objectType"] = object_type + return feature + + +def _cell_polygons( + assignment: pl.DataFrame, + *, + x_column: str, + y_column: str, + n_jobs: int, + progress: bool, +): + """Generate one boundary polygon per assigned cell (our multi-core Delaunay method).""" + from .boundary import generate_boundaries + + assigned = assignment.filter(~pl.col("is_noise")) + if assigned.height == 0: + return [] + # Pass pandas to generate_boundaries (its pandas groupby path is polars-version safe). + gdf = generate_boundaries( + assigned.select(["cell", x_column, y_column]).to_pandas(), + x=x_column, + y=y_column, + cell_id="cell", + n_jobs=n_jobs, + progress=progress, + ) + return list(zip(gdf["cell_id"].tolist(), gdf.geometry.tolist())) + + +def write_viz_polygons( + assignment: pl.DataFrame, + output_dir: Path, + *, + x_column: str, + y_column: str, + n_jobs: int = 1, + progress: bool = True, + filename: str = "segmentation_polygons.json", +) -> Path: + """Write ``segmentation_polygons.json`` (FeatureCollection) for ``--viz-polygons``. + + Polygons are generated from each cell's assigned transcripts, so every emitted cell + has >=1 transcript (Xenium Ranger errors otherwise). Degenerate (<3 vertex) cells are + dropped. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + features = [] + for cid, geom in _cell_polygons(assignment, x_column=x_column, y_column=y_column, n_jobs=n_jobs, progress=progress): + feature = _polygon_feature(geom, int(cid)) + if feature is not None: + features.append(feature) + + path = output_dir / filename + path.write_text(json.dumps({"type": "FeatureCollection", "features": features})) + logger.info("Wrote %d viz polygons: %s", len(features), path) + return path + + +def write_cell_geojson( + assignment: pl.DataFrame, + output_dir: Path, + *, + x_column: str, + y_column: str, + boundaries=None, + boundary_method: Literal["input", "delaunay"] = "delaunay", + n_jobs: int = 1, + progress: bool = True, + filename: str = "polygon.geojson", +) -> Path: + """Write a cell-polygon ``FeatureCollection`` (``objectType="cell"``) for ``--cells``. + + Uses the source's input boundaries when ``boundary_method="input"`` and they are + available, otherwise generates them from assigned transcripts. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + features = [] + if boundary_method == "input" and boundaries is not None and len(boundaries) > 0: + id_col = "cell_id" if "cell_id" in boundaries.columns else boundaries.columns[0] + for cid, geom in zip(boundaries[id_col].tolist(), boundaries.geometry.tolist()): + try: + cid_int = int(cid) + except (TypeError, ValueError): + continue + feature = _polygon_feature(geom, cid_int, object_type="cell") + if feature is not None: + features.append(feature) + else: + for cid, geom in _cell_polygons(assignment, x_column=x_column, y_column=y_column, n_jobs=n_jobs, progress=progress): + feature = _polygon_feature(geom, int(cid), object_type="cell") + if feature is not None: + features.append(feature) + + path = output_dir / filename + path.write_text(json.dumps({"type": "FeatureCollection", "features": features})) + logger.info("Wrote %d cell polygons: %s", len(features), path) + return path + + +def build_import_command( + *, + mode: Literal["transcript_assignment", "geojson"], + run_id: str, + source_path: Path, + files: dict, + units: Units = "microns", + localcores: int = 16, + localmem: int = 128, +) -> str: + """Build a copy-pasteable ``xeniumranger import-segmentation`` command.""" + parts = [ + "xeniumranger import-segmentation", + f"--id={run_id}", + f"--xenium-bundle={source_path}", + ] + if mode == "transcript_assignment": + parts.append(f"--transcript-assignment={files['csv'].name}") + parts.append(f"--viz-polygons={files['viz'].name}") + else: + parts.append(f"--cells={files['cells'].name}") + if files.get("nuclei") is not None: + parts.append(f"--nuclei={files['nuclei'].name}") + parts.append(f"--units={units}") + parts.append(f"--localcores={localcores}") + parts.append(f"--localmem={localmem}") + return " \\\n ".join(parts) + + +def export_xenium_import( + seg_df: pl.DataFrame, + source_path: Union[str, Path], + output_dir: Union[str, Path], + *, + mode: XeniumMode = "both", + cell_id_column: str = "segger_cell_id", + transcript_id_column: str = TRANSCRIPT_ID_COLUMN, + x_column: str = X_COLUMN, + y_column: str = Y_COLUMN, + z_column: str = Z_COLUMN, + feature_column: str = FEATURE_COLUMN, + boundaries=None, + boundary_method: Literal["input", "delaunay"] = "delaunay", + units: Units = "microns", + n_jobs: int = 1, + run_id: str = "segger_import", + progress: bool = True, +) -> dict: + """Write 10x ``import-segmentation`` inputs from a segger segmentation table. + + Parameters + ---------- + seg_df : pl.DataFrame + Segmentation table (``row_index``, ``segger_cell_id``, optional ``keep``). + source_path : str or Path + Raw Xenium bundle (used as ``--xenium-bundle`` and to recover transcript_id/coords). + output_dir : str or Path + Where to write the import files. + mode : {"transcript_assignment", "geojson", "both"} + Which import inputs to produce. + + Returns + ------- + dict + Mapping of artifact name -> written :class:`~pathlib.Path`. + """ + source_path = Path(source_path) + output_dir = Path(output_dir) + + raw = _load_source_transcripts( + source_path, + transcript_id_column=transcript_id_column, + x_column=x_column, + y_column=y_column, + z_column=z_column, + feature_column=feature_column, + ) + assignment = _build_assignment(seg_df, raw, cell_id_column=cell_id_column) + + written: dict = {} + commands: list[str] = [] + + if mode in ("transcript_assignment", "both"): + csv_path = write_baysor_csv( + assignment, + output_dir, + transcript_id_column=transcript_id_column, + x_column=x_column, + y_column=y_column, + z_column=z_column, + feature_column=feature_column, + ) + viz_path = write_viz_polygons( + assignment, output_dir, x_column=x_column, y_column=y_column, n_jobs=n_jobs, progress=progress + ) + written["segmentation_csv"] = csv_path + written["viz_polygons"] = viz_path + commands.append( + build_import_command( + mode="transcript_assignment", + run_id=run_id, + source_path=source_path, + files={"csv": csv_path, "viz": viz_path}, + units=units, + ) + ) + + if mode in ("geojson", "both"): + cells_path = write_cell_geojson( + assignment, + output_dir, + x_column=x_column, + y_column=y_column, + boundaries=boundaries, + boundary_method=boundary_method, + n_jobs=n_jobs, + progress=progress, + ) + written["cell_geojson"] = cells_path + commands.append( + build_import_command( + mode="geojson", + run_id=run_id, + source_path=source_path, + files={"cells": cells_path}, + units=units, + ) + ) + + logger.info( + "Run Xenium Ranger to import into Xenium Explorer:\n\n%s\n", "\n\nor\n\n".join(commands) + ) + written["_commands"] = commands + return written diff --git a/src/segger/utils/__init__.py b/src/segger/utils/__init__.py new file mode 100644 index 0000000..3ae6718 --- /dev/null +++ b/src/segger/utils/__init__.py @@ -0,0 +1,56 @@ +"""Utility modules for Segger.""" + +from segger.utils._logging import setup_logging, MemFilter +from segger.utils.optional_deps import ( + # Availability flags + SPATIALDATA_AVAILABLE, + SPATIALDATA_IO_AVAILABLE, + SOPA_AVAILABLE, + # Import functions (raise ImportError if missing) + require_spatialdata, + require_spatialdata_io, + require_sopa, + # Decorators for functions requiring optional deps + requires_spatialdata, + requires_spatialdata_io, + requires_sopa, + # Warning functions for soft failures + warn_spatialdata_unavailable, + warn_spatialdata_io_unavailable, + warn_sopa_unavailable, + warn_rapids_unavailable, + # RAPIDS helpers + require_rapids, + # Version utilities + get_spatialdata_version, + get_sopa_version, + check_spatialdata_version, +) + +__all__ = [ + # Logging + "setup_logging", + "MemFilter", + # Availability flags + "SPATIALDATA_AVAILABLE", + "SPATIALDATA_IO_AVAILABLE", + "SOPA_AVAILABLE", + # Import functions + "require_spatialdata", + "require_spatialdata_io", + "require_sopa", + # Decorators + "requires_spatialdata", + "requires_spatialdata_io", + "requires_sopa", + # Warning functions + "warn_spatialdata_unavailable", + "warn_spatialdata_io_unavailable", + "warn_sopa_unavailable", + "warn_rapids_unavailable", + "require_rapids", + # Version utilities + "get_spatialdata_version", + "get_sopa_version", + "check_spatialdata_version", +] diff --git a/src/segger/utils.py b/src/segger/utils/_logging.py similarity index 100% rename from src/segger/utils.py rename to src/segger/utils/_logging.py diff --git a/src/segger/utils/fragment_outputs.py b/src/segger/utils/fragment_outputs.py new file mode 100644 index 0000000..77fe2b7 --- /dev/null +++ b/src/segger/utils/fragment_outputs.py @@ -0,0 +1,133 @@ +"""Helpers for classifying fragment assignments across export formats.""" + +from __future__ import annotations + +from pathlib import Path + +import polars as pl + +FRAGMENT_PREFIX = "fragment-" + +FRAGMENT_FLAG_COLUMN = "segger_is_fragment" +OBJECT_TYPE_COLUMN = "segger_object_type" +OBJECT_GROUP_COLUMN = "segger_object_group" + +OBJECT_TYPE_CELL = "cell" +OBJECT_TYPE_FRAGMENT = "fragment" +OBJECT_TYPE_UNASSIGNED = "unassigned" + + +def object_group_label(object_type: str) -> str: + """Return the human-facing grouping label for an assignment type.""" + if object_type == OBJECT_TYPE_CELL: + return "cells" + if object_type == OBJECT_TYPE_FRAGMENT: + return "fragments" + return OBJECT_TYPE_UNASSIGNED + + +def with_fragment_annotations( + frame: pl.DataFrame, + cell_id_column: str = "segger_cell_id", + unassigned_value: int | str | None = None, +) -> pl.DataFrame: + """Annotate transcript assignments with fragment metadata columns.""" + if cell_id_column not in frame.columns: + raise ValueError(f"Missing cell_id column: {cell_id_column}") + + cell_id_text = pl.col(cell_id_column).cast(pl.Utf8) + is_unassigned = pl.col(cell_id_column).is_null() + if unassigned_value is not None: + is_unassigned = is_unassigned | (cell_id_text == str(unassigned_value)) + + is_fragment = (~is_unassigned) & cell_id_text.fill_null("").str.starts_with(FRAGMENT_PREFIX) + object_type = ( + pl.when(is_unassigned) + .then(pl.lit(OBJECT_TYPE_UNASSIGNED)) + .when(is_fragment) + .then(pl.lit(OBJECT_TYPE_FRAGMENT)) + .otherwise(pl.lit(OBJECT_TYPE_CELL)) + ) + + return frame.with_columns( + [ + is_fragment.alias(FRAGMENT_FLAG_COLUMN), + object_type.alias(OBJECT_TYPE_COLUMN), + ( + pl.when(object_type == OBJECT_TYPE_CELL) + .then(pl.lit(object_group_label(OBJECT_TYPE_CELL))) + .when(object_type == OBJECT_TYPE_FRAGMENT) + .then(pl.lit(object_group_label(OBJECT_TYPE_FRAGMENT))) + .otherwise(pl.lit(OBJECT_TYPE_UNASSIGNED)) + ).alias(OBJECT_GROUP_COLUMN), + ] + ) + + +def split_transcripts_by_object_type( + transcripts: pl.DataFrame, + cell_id_column: str = "segger_cell_id", + unassigned_value: int | str | None = -1, +) -> dict[str, pl.DataFrame]: + """Split transcript assignments into cells and fragments.""" + annotated = with_fragment_annotations( + transcripts, + cell_id_column=cell_id_column, + unassigned_value=unassigned_value, + ) + assigned = annotated.filter(pl.col(OBJECT_TYPE_COLUMN) != OBJECT_TYPE_UNASSIGNED) + # Filter by keep column when present (GMM threshold) + if "keep" in assigned.columns: + assigned = assigned.filter(pl.col("keep").fill_null(False)) + return { + "all": assigned, + OBJECT_TYPE_CELL: assigned.filter(pl.col(OBJECT_TYPE_COLUMN) == OBJECT_TYPE_CELL), + OBJECT_TYPE_FRAGMENT: assigned.filter( + pl.col(OBJECT_TYPE_COLUMN) == OBJECT_TYPE_FRAGMENT + ), + } + + +def annotate_pandas_object_types( + frame, + cell_id_column: str = "segger_cell_id", + unassigned_value: int | str | None = -1, +): + """Annotate a pandas frame with fragment metadata columns.""" + import pandas as pd + + if cell_id_column not in frame.columns: + raise ValueError(f"Missing cell_id column: {cell_id_column}") + + result = frame.copy() + cell_id_text = result[cell_id_column].astype("string") + is_unassigned = result[cell_id_column].isna() + if unassigned_value is not None: + is_unassigned = is_unassigned | cell_id_text.eq(str(unassigned_value)) + + is_fragment = (~is_unassigned) & cell_id_text.fillna("").str.startswith(FRAGMENT_PREFIX) + object_type = pd.Series(OBJECT_TYPE_CELL, index=result.index, dtype="object") + object_type.loc[is_fragment] = OBJECT_TYPE_FRAGMENT + object_type.loc[is_unassigned] = OBJECT_TYPE_UNASSIGNED + + result[FRAGMENT_FLAG_COLUMN] = is_fragment.astype(bool) + result[OBJECT_TYPE_COLUMN] = object_type + result[OBJECT_GROUP_COLUMN] = result[OBJECT_TYPE_COLUMN].map(object_group_label) + return result + + +def split_h5ad_output_paths(output_path: Path) -> dict[str, Path]: + """Return the combined and split AnnData output paths.""" + output_path = Path(output_path) + stem = output_path.stem + base = stem.removesuffix("_segmentation") + if not base: + base = stem + + return { + "combined": output_path, + OBJECT_TYPE_CELL: output_path.with_name(f"{base}_cells{output_path.suffix}"), + OBJECT_TYPE_FRAGMENT: output_path.with_name( + f"{base}_fragments{output_path.suffix}" + ), + } diff --git a/src/segger/utils/optional_deps.py b/src/segger/utils/optional_deps.py new file mode 100644 index 0000000..ae5a2b7 --- /dev/null +++ b/src/segger/utils/optional_deps.py @@ -0,0 +1,461 @@ +"""Optional dependency handling with informative warnings. + +This module provides lazy import wrappers for optional dependencies +(spatialdata, spatialdata-io, sopa) with clear installation instructions +when the dependencies are not available. + +Usage +----- +Check availability: + >>> from segger.utils.optional_deps import SPATIALDATA_AVAILABLE + >>> if SPATIALDATA_AVAILABLE: + ... import spatialdata + +Require and get import (raises ImportError with instructions if missing): + >>> from segger.utils.optional_deps import require_spatialdata + >>> spatialdata = require_spatialdata() + +Decorator for functions requiring optional deps: + >>> from segger.utils.optional_deps import requires_spatialdata + >>> @requires_spatialdata + ... def my_function(): + ... import spatialdata + ... return spatialdata.SpatialData() +""" + +from __future__ import annotations + +import functools +import importlib +import importlib.util +import warnings +from typing import TYPE_CHECKING, Any, Callable, TypeVar + +if TYPE_CHECKING: + import types + +# Type variable for decorator +F = TypeVar("F", bound=Callable[..., Any]) + + +# ----------------------------------------------------------------------------- +# Availability flags +# ----------------------------------------------------------------------------- + +def _check_spatialdata() -> bool: + """Check if spatialdata is available.""" + try: + return importlib.util.find_spec("spatialdata") is not None + except Exception: + return False + + +def _check_spatialdata_io() -> bool: + """Check if spatialdata-io is available.""" + try: + return importlib.util.find_spec("spatialdata_io") is not None + except Exception: + return False + + +def _check_sopa() -> bool: + """Check if sopa is available.""" + try: + return importlib.util.find_spec("sopa") is not None + except Exception: + return False + + +def _check_cellxgene_census() -> bool: + """Check if cellxgene-census is available.""" + try: + return importlib.util.find_spec("cellxgene_census") is not None + except Exception: + return False + + +# Availability flags (evaluated once at import time) +SPATIALDATA_AVAILABLE: bool = _check_spatialdata() +SPATIALDATA_IO_AVAILABLE: bool = _check_spatialdata_io() +SOPA_AVAILABLE: bool = _check_sopa() +CELLXGENE_CENSUS_AVAILABLE: bool = _check_cellxgene_census() + + +# ----------------------------------------------------------------------------- +# Installation instructions +# ----------------------------------------------------------------------------- + +SPATIALDATA_INSTALL_MSG = """ +spatialdata is not installed. This package is required for SpatialData I/O support. + +To install spatialdata support: + pip install segger[spatialdata] + +Or install spatialdata directly: + pip install spatialdata>=0.7.2 +""" + +SPATIALDATA_IO_INSTALL_MSG = """ +spatialdata-io is not installed. This package is required for reading platform-specific +SpatialData formats (Xenium, MERSCOPE, CosMX). + +To install spatialdata-io support: + pip install segger[spatialdata-io] + +For full SpatialData support: + pip install segger[spatialdata] + +Or install spatialdata-io directly: + pip install spatialdata-io>=0.6.0 +""" + +SOPA_INSTALL_MSG = """ +sopa is not installed. This package is required for SOPA compatibility features. + +To install SOPA support: + pip install segger[sopa] + +Or install sopa directly: + pip install sopa>=2.0.0 + +For all SpatialData features including SOPA: + pip install segger[spatialdata-all] +""" + +CELLXGENE_CENSUS_INSTALL_MSG = """ +cellxgene-census is not installed. This package is required for automatic +scRNA-seq reference fetching from CellxGENE Census. + +To install Census support: + pip install segger[census] + +Or install cellxgene-census directly: + pip install "cellxgene-census>=1.2.1" +""" + +RAPIDS_INSTALL_MSG = """ +RAPIDS GPU packages are not installed. Segger requires CuPy/cuDF/cuML/cuGraph/cuSpatial and a CUDA-enabled GPU. + +See docs/INSTALLATION.md for RAPIDS/CUDA setup. +""" + + +# ----------------------------------------------------------------------------- +# Import functions with error messages +# ----------------------------------------------------------------------------- + +def require_spatialdata() -> "types.ModuleType": + """Import and return spatialdata, raising ImportError if not available. + + Returns + ------- + types.ModuleType + The spatialdata module. + + Raises + ------ + ImportError + If spatialdata is not installed, with installation instructions. + """ + if not SPATIALDATA_AVAILABLE: + raise ImportError(SPATIALDATA_INSTALL_MSG) + import spatialdata + return spatialdata + + +def require_spatialdata_io() -> "types.ModuleType": + """Import and return spatialdata_io, raising ImportError if not available. + + Returns + ------- + types.ModuleType + The spatialdata_io module. + + Raises + ------ + ImportError + If spatialdata-io is not installed, with installation instructions. + """ + if not SPATIALDATA_IO_AVAILABLE: + raise ImportError(SPATIALDATA_IO_INSTALL_MSG) + import spatialdata_io + return spatialdata_io + + +def require_sopa() -> "types.ModuleType": + """Import and return sopa, raising ImportError if not available. + + Returns + ------- + types.ModuleType + The sopa module. + + Raises + ------ + ImportError + If sopa is not installed, with installation instructions. + """ + if not SOPA_AVAILABLE: + raise ImportError(SOPA_INSTALL_MSG) + import sopa + return sopa + + +def require_cellxgene_census() -> "types.ModuleType": + """Import and return cellxgene_census, raising ImportError if not available. + + Returns + ------- + types.ModuleType + The cellxgene_census module. + + Raises + ------ + ImportError + If cellxgene-census is not installed, with installation instructions. + """ + if not CELLXGENE_CENSUS_AVAILABLE: + raise ImportError(CELLXGENE_CENSUS_INSTALL_MSG) + import cellxgene_census + return cellxgene_census + + +# ----------------------------------------------------------------------------- +# Decorators for requiring optional dependencies +# ----------------------------------------------------------------------------- + +def requires_spatialdata(func: F) -> F: + """Decorator that raises ImportError if spatialdata is not available. + + Parameters + ---------- + func + Function that requires spatialdata. + + Returns + ------- + F + Wrapped function that checks for spatialdata before execution. + + Examples + -------- + >>> @requires_spatialdata + ... def load_from_zarr(path): + ... import spatialdata + ... return spatialdata.read_zarr(path) + """ + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + require_spatialdata() + return func(*args, **kwargs) + return wrapper # type: ignore[return-value] + + +def requires_spatialdata_io(func: F) -> F: + """Decorator that raises ImportError if spatialdata-io is not available. + + Parameters + ---------- + func + Function that requires spatialdata-io. + + Returns + ------- + F + Wrapped function that checks for spatialdata-io before execution. + """ + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + require_spatialdata_io() + return func(*args, **kwargs) + return wrapper # type: ignore[return-value] + + +def requires_sopa(func: F) -> F: + """Decorator that raises ImportError if sopa is not available. + + Parameters + ---------- + func + Function that requires sopa. + + Returns + ------- + F + Wrapped function that checks for sopa before execution. + """ + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + require_sopa() + return func(*args, **kwargs) + return wrapper # type: ignore[return-value] + + +def requires_cellxgene_census(func: F) -> F: + """Decorator that raises ImportError if cellxgene-census is not available. + + Parameters + ---------- + func + Function that requires cellxgene-census. + + Returns + ------- + F + Wrapped function that checks for cellxgene-census before execution. + """ + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + require_cellxgene_census() + return func(*args, **kwargs) + return wrapper # type: ignore[return-value] + + +# ----------------------------------------------------------------------------- +# Warning functions for soft failures +# ----------------------------------------------------------------------------- + +def warn_spatialdata_unavailable(feature: str = "SpatialData support") -> None: + """Emit a warning that spatialdata is not available. + + Parameters + ---------- + feature + Description of the feature requiring spatialdata. + """ + warnings.warn( + f"{feature} requires spatialdata. " + "Install with: pip install segger[spatialdata]", + UserWarning, + stacklevel=2, + ) + + +def warn_spatialdata_io_unavailable(feature: str = "Platform-specific SpatialData readers") -> None: + """Emit a warning that spatialdata-io is not available. + + Parameters + ---------- + feature + Description of the feature requiring spatialdata-io. + """ + warnings.warn( + f"{feature} requires spatialdata-io. " + "Install with: pip install segger[spatialdata-io]", + UserWarning, + stacklevel=2, + ) + + +def warn_sopa_unavailable(feature: str = "SOPA compatibility") -> None: + """Emit a warning that sopa is not available. + + Parameters + ---------- + feature + Description of the feature requiring sopa. + """ + warnings.warn( + f"{feature} requires sopa. " + "Install with: pip install segger[sopa]", + UserWarning, + stacklevel=2, + ) + + +def _import_optional_packages(packages: list[str]) -> tuple[dict[str, "types.ModuleType"], list[str]]: + """Import optional packages and return (modules, missing).""" + modules: dict[str, "types.ModuleType"] = {} + missing: list[str] = [] + for package in packages: + try: + modules[package] = importlib.import_module(package) + except Exception: + missing.append(package) + return modules, missing + + +def require_rapids( + packages: list[str] | None = None, + feature: str = "Segger", +) -> dict[str, "types.ModuleType"]: + """Import RAPIDS-related packages or raise with installation instructions.""" + package_list = packages or ["cupy", "cudf", "cuml", "cugraph", "cuspatial"] + modules, missing = _import_optional_packages(package_list) + if missing: + missing_list = ", ".join(missing) + raise ImportError( + f"{feature} requires RAPIDS GPU packages: {missing_list}. " + + RAPIDS_INSTALL_MSG.strip() + ) + return modules + + +def warn_rapids_unavailable( + feature: str = "Segger", + packages: list[str] | None = None, +) -> bool: + """Warn if RAPIDS-related packages are unavailable. Returns True if present.""" + package_list = packages or ["cupy", "cudf", "cuml", "cugraph", "cuspatial"] + _, missing = _import_optional_packages(package_list) + if not missing: + return True + missing_list = ", ".join(missing) + warnings.warn( + f"{feature} requires RAPIDS GPU packages ({missing_list}). " + + RAPIDS_INSTALL_MSG.strip(), + UserWarning, + stacklevel=2, + ) + return False + + +# ----------------------------------------------------------------------------- +# Version checking +# ----------------------------------------------------------------------------- + +def get_spatialdata_version() -> str | None: + """Get the installed spatialdata version, or None if not installed.""" + if not SPATIALDATA_AVAILABLE: + return None + try: + import spatialdata + return getattr(spatialdata, "__version__", "unknown") + except Exception: + return None + + +def get_sopa_version() -> str | None: + """Get the installed sopa version, or None if not installed.""" + if not SOPA_AVAILABLE: + return None + try: + import sopa + return getattr(sopa, "__version__", "unknown") + except Exception: + return None + + +def check_spatialdata_version(min_version: str = "0.7.2") -> bool: + """Check if spatialdata version meets minimum requirement. + + Parameters + ---------- + min_version + Minimum required version string. + + Returns + ------- + bool + True if version is sufficient, False otherwise. + """ + version = get_spatialdata_version() + if version is None or version == "unknown": + return False + + try: + from packaging.version import Version + return Version(version) >= Version(min_version) + except ImportError: + # Fallback to simple string comparison + return version >= min_version diff --git a/tests/test_export_cli.py b/tests/test_export_cli.py new file mode 100644 index 0000000..b7e0728 --- /dev/null +++ b/tests/test_export_cli.py @@ -0,0 +1,89 @@ +"""End-to-end tests for the ``segger export`` CLI command.""" + +from pathlib import Path + +import polars as pl +import pytest + + +def _write_source(dir_: Path, n: int = 6) -> Path: + dir_.mkdir(parents=True, exist_ok=True) + pl.DataFrame( + { + "transcript_id": [f"t{i}" for i in range(n)], + "x_location": [float(i) for i in range(n)], + "y_location": [0.0] * n, + "z_location": [0.0] * n, + "feature_name": (["g1", "g2"] * n)[:n], + } + ).write_parquet(dir_ / "transcripts.parquet") + return dir_ + + +def _write_segmentation(path: Path, cell_id_column: str = "segger_cell_id") -> Path: + pl.DataFrame( + { + "row_index": [0, 1, 2, 3, 4, 5], + cell_id_column: [1, 1, 2, 2, None, None], + "segger_similarity": [0.9, 0.8, 0.7, 0.6, 0.0, 0.0], + "similarity_threshold": [0.5] * 6, + }, + schema_overrides={cell_id_column: pl.Int64}, + ).write_parquet(path) + return path + + +def test_export_cli_xenium_transcript_assignment(tmp_path): + pytest.importorskip("geopandas") + pytest.importorskip("shapely") + pytest.importorskip("rtree") + from segger.cli.export import export + + src = _write_source(tmp_path / "src") + seg = _write_segmentation(tmp_path / "seg.parquet") + out = tmp_path / "out" + + export( + segmentation_path=seg, + source_path=src, + output_directory=out, + format="xenium", + xenium_mode="transcript_assignment", + num_workers=1, + ) + + csv = pl.read_csv(out / "segmentation.csv") + assert {"transcript_id", "cell", "is_noise"} <= set(csv.columns) + assert csv.height == 6 # one row per source transcript + assert (out / "segmentation_polygons.json").exists() + + +def test_export_cli_resolves_cell_id_alias(tmp_path): + pytest.importorskip("geopandas") + pytest.importorskip("shapely") + pytest.importorskip("rtree") + from segger.cli.export import export + + src = _write_source(tmp_path / "src") + # Segmentation stores the assignment under a non-standard column name. + seg = _write_segmentation(tmp_path / "seg.parquet", cell_id_column="seg_cell_id") + out = tmp_path / "out" + + export( + segmentation_path=seg, + source_path=src, + output_directory=out, + format="xenium", + xenium_mode="transcript_assignment", + ) + assert (out / "segmentation.csv").exists() + + +def test_export_cli_rejects_unknown_suffix(tmp_path): + from segger.cli.export import export + + src = _write_source(tmp_path / "src") + bad = tmp_path / "seg.txt" + bad.write_text("not a segmentation") + with pytest.raises(ValueError): + export(segmentation_path=bad, source_path=src, output_directory=tmp_path / "out") diff --git a/tests/test_export_xenium_import.py b/tests/test_export_xenium_import.py new file mode 100644 index 0000000..271d3aa --- /dev/null +++ b/tests/test_export_xenium_import.py @@ -0,0 +1,190 @@ +"""Smoke tests for the 10x ``import-segmentation`` export path.""" + +import json +from pathlib import Path + +import polars as pl +import pytest + +from segger.export.xenium_import import ( + _build_assignment, + _load_source_transcripts, + build_import_command, + export_xenium_import, + write_baysor_csv, +) + + +def _write_source(tmp_path: Path, df: pl.DataFrame) -> Path: + df.write_parquet(tmp_path / "transcripts.parquet") + return tmp_path + + +# --- Lightweight CSV path (polars only) -------------------------------------- + + +def test_baysor_csv_assignment_and_noise(tmp_path): + # 6 raw transcripts; segmentation only covers row_index 0..4 (row 5 is filtered out). + src = _write_source( + tmp_path / "src", + pl.DataFrame( + { + "transcript_id": [f"t{i}" for i in range(6)], + "x_location": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], + "y_location": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + "z_location": [0.0] * 6, + "feature_name": ["g1", "g1", "g2", "g2", "g3", "g3"], + } + ), + ) + seg = pl.DataFrame( + { + "row_index": [0, 1, 2, 3, 4], + "segger_cell_id": [1, 1, 2, None, 2], + "segger_similarity": [0.9] * 5, + }, + schema_overrides={"segger_cell_id": pl.Int64}, + ) + + raw = _load_source_transcripts( + src, + transcript_id_column="transcript_id", + x_column="x_location", + y_column="y_location", + z_column="z_location", + feature_column="feature_name", + ) + assignment = _build_assignment(seg, raw, cell_id_column="segger_cell_id") + csv_path = write_baysor_csv( + assignment, + tmp_path / "out", + transcript_id_column="transcript_id", + x_column="x_location", + y_column="y_location", + z_column="z_location", + feature_column="feature_name", + ) + + df = pl.read_csv(csv_path) + # Required Baysor columns + assert {"transcript_id", "cell", "is_noise"} <= set(df.columns) + assert df.height == 6 # all raw transcripts represented + # row 3 (null assignment) and row 5 (not in segmentation) are noise + noise = df.filter(pl.col("is_noise")) + assert set(noise["transcript_id"].to_list()) == {"t3", "t5"} + # assigned cells are exactly {1, 2} + assigned = df.filter(~pl.col("is_noise")) + assert set(assigned["cell"].to_list()) == {1, 2} + + +def test_min_similarity_marks_low_confidence_as_noise(tmp_path): + src = _write_source( + tmp_path / "src", + pl.DataFrame( + { + "transcript_id": ["a", "b"], + "x_location": [0.0, 1.0], + "y_location": [0.0, 0.0], + "z_location": [0.0, 0.0], + "feature_name": ["g", "g"], + } + ), + ) + # 'keep' precomputed by the CLI: second transcript fails the threshold. + seg = pl.DataFrame( + { + "row_index": [0, 1], + "segger_cell_id": [1, 1], + "segger_similarity": [0.9, 0.1], + "keep": [True, False], + }, + schema_overrides={"segger_cell_id": pl.Int64}, + ) + raw = _load_source_transcripts( + src, + transcript_id_column="transcript_id", + x_column="x_location", + y_column="y_location", + z_column="z_location", + feature_column="feature_name", + ) + assignment = _build_assignment(seg, raw, cell_id_column="segger_cell_id") + rows = {r["transcript_id"]: r["is_noise"] for r in assignment.to_dicts()} + assert rows["a"] is False + assert rows["b"] is True + + +def test_build_import_command_forms(): + ta = build_import_command( + mode="transcript_assignment", + run_id="demo", + source_path=Path("/data/xenium"), + files={"csv": Path("/o/segmentation.csv"), "viz": Path("/o/segmentation_polygons.json")}, + units="microns", + ) + assert "xeniumranger import-segmentation" in ta + assert "--transcript-assignment=segmentation.csv" in ta + assert "--viz-polygons=segmentation_polygons.json" in ta + assert "--units=microns" in ta + + geo = build_import_command( + mode="geojson", + run_id="demo", + source_path=Path("/data/xenium"), + files={"cells": Path("/o/polygon.geojson")}, + units="microns", + ) + assert "--cells=polygon.geojson" in geo + + +# --- Full path incl. polygon generation (needs the geo stack) ---------------- + + +def test_export_both_writes_valid_polygons(tmp_path): + pytest.importorskip("geopandas") + pytest.importorskip("shapely") + pytest.importorskip("rtree") + + # Three cells, each a 4-point square cluster (>=3 non-collinear points). + def square(cx, cy): + return [(cx, cy), (cx + 1, cy), (cx + 1, cy + 1), (cx, cy + 1)] + + pts = square(0, 0) + square(20, 0) + square(0, 20) + cells = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] + src = _write_source( + tmp_path / "src", + pl.DataFrame( + { + "transcript_id": [f"t{i}" for i in range(12)], + "x_location": [float(x) for x, _ in pts], + "y_location": [float(y) for _, y in pts], + "z_location": [0.0] * 12, + "feature_name": ["g"] * 12, + } + ), + ) + seg = pl.DataFrame( + {"row_index": list(range(12)), "segger_cell_id": cells, "segger_similarity": [0.9] * 12}, + schema_overrides={"segger_cell_id": pl.Int64}, + ) + + out = tmp_path / "out" + written = export_xenium_import( + seg, src, out, mode="both", n_jobs=1, run_id="demo", progress=False + ) + + assert written["segmentation_csv"].exists() + assert written["viz_polygons"].exists() + assert written["cell_geojson"].exists() + + csv = pl.read_csv(written["segmentation_csv"]) + assigned_cells = set(csv.filter(~pl.col("is_noise"))["cell"].to_list()) + + fc = json.loads(written["viz_polygons"].read_text()) + assert fc["type"] == "FeatureCollection" + assert len(fc["features"]) >= 1 + for feat in fc["features"]: + # Every visualized polygon must correspond to a cell with transcripts. + assert feat["properties"]["cell"] in assigned_cells + ring = feat["geometry"]["coordinates"][0] + assert len(ring) >= 4 # >=3 vertices + closure (Explorer crashes on fewer)