diff --git a/python/python/lance/cuvs.py b/python/python/lance/cuvs.py new file mode 100644 index 00000000000..6bc8dbd5312 --- /dev/null +++ b/python/python/lance/cuvs.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +from __future__ import annotations + +import os +import tempfile +from importlib import import_module +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from pathlib import Path + + +def is_cuvs_accelerator(accelerator: object) -> bool: + return isinstance(accelerator, str) and accelerator.lower() == "cuvs" + + +def _require_lance_cuvs(): + try: + return import_module("lance_cuvs") + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "accelerator='cuvs' requires the external 'lance-cuvs' package " + "to be installed." + ) from exc + + +def build_vector_index_on_cuvs( + dataset, + column: str, + metric_type: str, + accelerator: str, + num_partitions: int, + num_sub_vectors: int, + dst_dataset_uri: str | Path | None = None, + storage_options: Optional[dict[str, str]] = None, + *, + sample_rate: int = 256, + max_iters: int = 50, + num_bits: int = 8, + batch_size: int = 1024 * 128, + filter_nan: bool = True, +): + if not is_cuvs_accelerator(accelerator): + raise ValueError("build_vector_index_on_cuvs requires accelerator='cuvs'") + + backend = _require_lance_cuvs() + artifact_uri = ( + os.fspath(dst_dataset_uri) + if dst_dataset_uri is not None + else tempfile.mkdtemp(prefix="lance-cuvs-artifact-") + ) + training = backend.train_ivf_pq( + dataset.uri, + column, + metric_type=metric_type, + num_partitions=num_partitions, + num_sub_vectors=num_sub_vectors, + sample_rate=sample_rate, + max_iters=max_iters, + num_bits=num_bits, + filter_nan=filter_nan, + storage_options=storage_options, + ) + artifact = backend.build_ivf_pq_artifact( + dataset.uri, + column, + training=training, + artifact_uri=artifact_uri, + batch_size=batch_size, + filter_nan=filter_nan, + 
storage_options=storage_options, + ) + return ( + artifact.artifact_uri, + artifact.files, + training.ivf_centroids(), + training.pq_codebook(), + ) + + +def prepare_global_ivf_pq_on_cuvs( + dataset, + column: str, + num_partitions: int, + num_sub_vectors: int, + *, + distance_type: str = "l2", + accelerator: str = "cuvs", + sample_rate: int = 256, + max_iters: int = 50, + num_bits: int = 8, + filter_nan: bool = True, +): + if not is_cuvs_accelerator(accelerator): + raise ValueError("prepare_global_ivf_pq_on_cuvs requires accelerator='cuvs'") + + backend = _require_lance_cuvs() + training = backend.train_ivf_pq( + dataset.uri, + column, + metric_type=distance_type, + num_partitions=num_partitions, + num_sub_vectors=num_sub_vectors, + sample_rate=sample_rate, + max_iters=max_iters, + num_bits=num_bits, + filter_nan=filter_nan, + ) + return { + "ivf_centroids": training.ivf_centroids(), + "pq_codebook": training.pq_codebook(), + } diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 377ac546c3a..c786b8f7cce 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -39,6 +39,7 @@ from lance.log import LOGGER from .blob import BlobFile +from .cuvs import is_cuvs_accelerator from .dependencies import ( _check_for_numpy, _check_for_torch, @@ -2918,12 +2919,14 @@ def _create_index_impl( # Handle timing for various parts of accelerated builds timers = {} + use_cuvs = is_cuvs_accelerator(accelerator) if accelerator is not None and index_type != "IVF_PQ": LOGGER.warning( "Index type %s does not support GPU acceleration; falling back to CPU", index_type, ) accelerator = None + use_cuvs = False # IMPORTANT: Distributed indexing is CPU-only. Enforce single-node when # accelerator or torch-related paths are detected. 
@@ -2967,57 +2970,85 @@ def _create_index_impl( index_uuid = None if accelerator is not None: - from .vector import ( - one_pass_assign_ivf_pq_on_accelerator, - one_pass_train_ivf_pq_on_accelerator, - ) - - LOGGER.info("Doing one-pass ivfpq accelerated computations") if num_partitions is None: num_rows = self.count_rows() num_partitions = _target_partition_size_to_num_partitions( num_rows, target_partition_size ) - timers["ivf+pq_train:start"] = time.time() - ( - ivf_centroids, - ivf_kmeans, - pq_codebook, - pq_kmeans_list, - ) = one_pass_train_ivf_pq_on_accelerator( - self, - column[0], - num_partitions, - metric, - accelerator, - num_sub_vectors=num_sub_vectors, - batch_size=20480, - filter_nan=filter_nan, - ) - timers["ivf+pq_train:end"] = time.time() - ivfpq_train_time = timers["ivf+pq_train:end"] - timers["ivf+pq_train:start"] - LOGGER.info("ivf+pq training time: %ss", ivfpq_train_time) - timers["ivf+pq_assign:start"] = time.time() - shuffle_output_dir, shuffle_buffers = one_pass_assign_ivf_pq_on_accelerator( - self, - column[0], - metric, - accelerator, - ivf_kmeans, - pq_kmeans_list, - batch_size=20480, - filter_nan=filter_nan, - ) - timers["ivf+pq_assign:end"] = time.time() - ivfpq_assign_time = ( - timers["ivf+pq_assign:end"] - timers["ivf+pq_assign:start"] - ) - LOGGER.info("ivf+pq transform time: %ss", ivfpq_assign_time) + if use_cuvs: + from .cuvs import build_vector_index_on_cuvs + + LOGGER.info("Doing cuVS vector backend build") + timers["ivf+pq_build:start"] = time.time() + artifact_root, _, ivf_centroids, pq_codebook = build_vector_index_on_cuvs( + self, + column[0], + metric, + accelerator, + num_partitions, + num_sub_vectors, + storage_options=storage_options, + sample_rate=kwargs.get("sample_rate", 256), + max_iters=kwargs.get("max_iters", 50), + num_bits=kwargs.get("num_bits", 8), + batch_size=1024 * 128, + filter_nan=filter_nan, + ) + kwargs["precomputed_partition_artifact_uri"] = artifact_root + timers["ivf+pq_build:end"] = time.time() + 
ivfpq_build_time = ( + timers["ivf+pq_build:end"] - timers["ivf+pq_build:start"] + ) + LOGGER.info("cuVS ivf+pq build time: %ss", ivfpq_build_time) + else: + from .vector import ( + one_pass_assign_ivf_pq_on_accelerator, + one_pass_train_ivf_pq_on_accelerator, + ) - kwargs["precomputed_shuffle_buffers"] = shuffle_buffers - kwargs["precomputed_shuffle_buffers_path"] = os.path.join( - shuffle_output_dir, "data" - ) + LOGGER.info("Doing one-pass ivfpq accelerated computations") + timers["ivf+pq_train:start"] = time.time() + ( + ivf_centroids, + ivf_kmeans, + pq_codebook, + pq_kmeans_list, + ) = one_pass_train_ivf_pq_on_accelerator( + self, + column[0], + num_partitions, + metric, + accelerator, + num_sub_vectors=num_sub_vectors, + batch_size=20480, + filter_nan=filter_nan, + ) + timers["ivf+pq_train:end"] = time.time() + ivfpq_train_time = ( + timers["ivf+pq_train:end"] - timers["ivf+pq_train:start"] + ) + LOGGER.info("ivf+pq training time: %ss", ivfpq_train_time) + timers["ivf+pq_assign:start"] = time.time() + shuffle_output_dir, shuffle_buffers = one_pass_assign_ivf_pq_on_accelerator( + self, + column[0], + metric, + accelerator, + ivf_kmeans, + pq_kmeans_list, + batch_size=20480, + filter_nan=filter_nan, + ) + timers["ivf+pq_assign:end"] = time.time() + ivfpq_assign_time = ( + timers["ivf+pq_assign:end"] - timers["ivf+pq_assign:start"] + ) + LOGGER.info("ivf+pq transform time: %ss", ivfpq_assign_time) + + kwargs["precomputed_shuffle_buffers"] = shuffle_buffers + kwargs["precomputed_shuffle_buffers_path"] = os.path.join( + shuffle_output_dir, "data" + ) if index_type.startswith("IVF"): if (ivf_centroids is not None) and (ivf_centroids_file is not None): raise ValueError( @@ -3173,6 +3204,13 @@ def _create_index_impl( "Temporary shuffle buffers stored at %s, you may want to delete it.", kwargs["precomputed_shuffle_buffers_path"], ) + if "precomputed_partition_artifact_uri" in kwargs.keys() and os.path.exists( + kwargs["precomputed_partition_artifact_uri"] + ): + 
LOGGER.info( + "Temporary precomputed partition artifact stored at %s, you may want to delete it.", + kwargs["precomputed_partition_artifact_uri"], + ) return index def create_index( @@ -3249,7 +3287,12 @@ def create_index( The number of sub-vectors for PQ (Product Quantization). accelerator : str or ``torch.Device``, optional If set, use an accelerator to speed up the training process. - Accepted accelerator: "cuda" (Nvidia GPU) and "mps" (Apple Silicon GPU). + Accepted accelerator: + + - "cuda" (Nvidia GPU) + - "mps" (Apple Silicon GPU) + - "cuvs" for the external `lance-cuvs` backend + If not set, use the CPU. index_cache_size : int, optional The size of the index cache in number of entries. Default value is 256. @@ -3318,6 +3361,11 @@ def create_index( Only 4, 8 are supported. - index_file_version The version of the index file. Default is "V3". + - precomputed_partition_artifact_uri + An advanced input produced by an external backend such as + `lance-cuvs`. When set, Lance skips its own partition assignment + and consumes the precomputed partition-local artifact during + finalization. Requires `ivf_centroids` and `pq_codebook`. Optional parameters for `IVF_RQ`: @@ -3361,8 +3409,9 @@ def create_index( Experimental Accelerator (GPU) support: - *accelerate*: use GPU to train IVF partitions. - Only supports CUDA (Nvidia) or MPS (Apple) currently. - Requires PyTorch being installed. + Supports CUDA (Nvidia) and MPS (Apple) via the built-in torch path. + `accelerator="cuvs"` delegates IVF_PQ build preparation to the + external `lance-cuvs` package. .. 
code-block:: python diff --git a/python/python/lance/indices/builder.py b/python/python/lance/indices/builder.py index c31ea0a7a0c..00591ead934 100644 --- a/python/python/lance/indices/builder.py +++ b/python/python/lance/indices/builder.py @@ -9,6 +9,7 @@ import numpy as np import pyarrow as pa +from lance.cuvs import is_cuvs_accelerator, prepare_global_ivf_pq_on_cuvs from lance.indices.ivf import IvfModel from lance.indices.pq import PqModel @@ -115,6 +116,11 @@ def train_ivf( self._verify_ivf_sample_rate(sample_rate, num_partitions, num_rows) distance_type = self._normalize_distance_type(distance_type) self._verify_ivf_params(num_partitions) + if is_cuvs_accelerator(accelerator): + raise NotImplementedError( + "IndicesBuilder.train_ivf does not support accelerator='cuvs'; " + "use prepare_global_ivf_pq instead" + ) if accelerator is None: from lance.lance import indices @@ -250,6 +256,25 @@ def prepare_global_ivf_pq( `IndicesBuilder.train_pq` (indices.train_pq_model). No public method names elsewhere are changed. 
""" + if is_cuvs_accelerator(accelerator): + if fragment_ids is not None: + raise NotImplementedError( + "fragment_ids is not supported with accelerator='cuvs'" + ) + num_rows = self._count_rows() + num_partitions = self._determine_num_partitions(num_partitions, num_rows) + num_subvectors = self._normalize_pq_params(num_subvectors, self.dimension) + return prepare_global_ivf_pq_on_cuvs( + self.dataset, + self.column[0], + num_partitions, + num_subvectors, + distance_type=distance_type, + accelerator=accelerator, + sample_rate=sample_rate, + max_iters=max_iters, + ) + # Global IVF training ivf_model = self.train_ivf( num_partitions, diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 5371aa1f2a7..f6470a5bd8e 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -13,6 +13,7 @@ from typing import Optional import lance +import lance.cuvs as lance_cuvs import numpy as np import pyarrow as pa import pyarrow.compute as pc @@ -562,6 +563,15 @@ def test_create_index_unsupported_accelerator(tmp_path): accelerator="cuda:abc", ) + with pytest.raises(ValueError): + dataset.create_index( + "vector", + index_type="IVF_PQ", + num_partitions=4, + num_sub_vectors=16, + accelerator="cuvs:0", + ) + def test_create_index_accelerator_fallback(tmp_path, caplog): tbl = create_table() @@ -583,6 +593,185 @@ def test_create_index_accelerator_fallback(tmp_path, caplog): ) +def test_create_index_requires_external_cuvs_backend(tmp_path, monkeypatch): + tbl = create_table() + dataset = lance.write_dataset(tbl, tmp_path) + original_import_module = lance_cuvs.import_module + + def _raise_missing(name): + if name == "lance_cuvs": + raise ModuleNotFoundError("No module named 'lance_cuvs'") + return original_import_module(name) + + monkeypatch.setattr(lance_cuvs, "import_module", _raise_missing) + with pytest.raises( + ModuleNotFoundError, match="requires the external 'lance-cuvs' package" + ): + 
dataset.create_index( + "vector", + index_type="IVF_PQ", + num_partitions=4, + num_sub_vectors=16, + accelerator="cuvs", + ) + + +class _FakeCuvsTraining: + def __init__(self, ivf_centroids, pq_codebook): + self._ivf_centroids = ivf_centroids + self._pq_codebook = pq_codebook + + def ivf_centroids(self): + return self._ivf_centroids + + def pq_codebook(self): + return self._pq_codebook + + +class _FakeCuvsArtifact: + def __init__(self, artifact_uri, files): + self.artifact_uri = artifact_uri + self.files = files + + +def _make_fake_cuvs_training(num_partitions: int = 4, dimension: int = 128): + centroids = pa.FixedSizeListArray.from_arrays( + pa.array(np.arange(num_partitions * dimension, dtype=np.float32)), + dimension, + ) + codebook = pa.FixedSizeListArray.from_arrays( + pa.array(np.arange(16 * 256 * 8, dtype=np.float32)), + 8, + ) + return _FakeCuvsTraining(centroids, codebook) + + +def test_build_vector_index_on_cuvs_delegates_to_external_backend(tmp_path, monkeypatch): + ds = _make_sample_dataset_base(tmp_path, "prepare_ivf_pq_cuvs_ds", 512, 128) + calls = {} + training = _make_fake_cuvs_training() + + class _FakeBackend: + def train_ivf_pq(self, dataset_uri, column, **kwargs): + calls["train"] = { + "dataset_uri": dataset_uri, + "column": column, + **kwargs, + } + return training + + def build_ivf_pq_artifact(self, dataset_uri, column, **kwargs): + calls["build"] = { + "dataset_uri": dataset_uri, + "column": column, + **kwargs, + } + return _FakeCuvsArtifact( + artifact_uri=str(tmp_path / "artifact"), + files=[str(tmp_path / "artifact" / "data.lance")], + ) + + monkeypatch.setattr(lance_cuvs, "_require_lance_cuvs", lambda: _FakeBackend()) + + artifact_uri, files, ivf_centroids, pq_codebook = ( + lance_cuvs.build_vector_index_on_cuvs( + ds, + "vector", + "l2", + "cuvs", + 4, + 16, + dst_dataset_uri=tmp_path / "artifact_root", + storage_options={"region": "us-east-1"}, + sample_rate=7, + max_iters=20, + num_bits=4, + batch_size=4096, + filter_nan=False, + ) + 
) + + assert calls["train"] == { + "dataset_uri": ds.uri, + "column": "vector", + "metric_type": "l2", + "num_partitions": 4, + "num_sub_vectors": 16, + "sample_rate": 7, + "max_iters": 20, + "num_bits": 4, + "filter_nan": False, + "storage_options": {"region": "us-east-1"}, + } + assert calls["build"]["dataset_uri"] == ds.uri + assert calls["build"]["column"] == "vector" + assert calls["build"]["training"] is training + assert calls["build"]["artifact_uri"] == str(tmp_path / "artifact_root") + assert calls["build"]["batch_size"] == 4096 + assert calls["build"]["filter_nan"] is False + assert calls["build"]["storage_options"] == {"region": "us-east-1"} + assert artifact_uri == str(tmp_path / "artifact") + assert files == [str(tmp_path / "artifact" / "data.lance")] + assert ivf_centroids.equals(training.ivf_centroids()) + assert pq_codebook.equals(training.pq_codebook()) + + +def test_prepare_global_ivf_pq_delegates_to_external_cuvs_backend(tmp_path, monkeypatch): + ds = _make_sample_dataset_base(tmp_path, "prepare_ivf_pq_cuvs_ds", 512, 128) + builder = IndicesBuilder(ds, "vector") + training = _make_fake_cuvs_training() + calls = {} + + class _FakeBackend: + def train_ivf_pq(self, dataset_uri, column, **kwargs): + calls["train"] = { + "dataset_uri": dataset_uri, + "column": column, + **kwargs, + } + return training + + monkeypatch.setattr(lance_cuvs, "_require_lance_cuvs", lambda: _FakeBackend()) + + prepared = builder.prepare_global_ivf_pq( + num_partitions=4, + num_subvectors=16, + distance_type="l2", + accelerator="cuvs", + sample_rate=7, + max_iters=20, + ) + + assert calls["train"] == { + "dataset_uri": ds.uri, + "column": "vector", + "metric_type": "l2", + "num_partitions": 4, + "num_sub_vectors": 16, + "sample_rate": 7, + "max_iters": 20, + "num_bits": 8, + "filter_nan": True, + } + assert prepared["ivf_centroids"].equals(training.ivf_centroids()) + assert prepared["pq_codebook"].equals(training.pq_codebook()) + + +def 
test_create_index_rejects_missing_precomputed_partition_artifact(tmp_path): + dataset = lance.write_dataset(create_table(nvec=64, ndim=128), tmp_path / "artifact_src") + + with pytest.raises(Exception): + dataset.create_index( + "vector", + index_type="IVF_PQ", + num_partitions=4, + num_sub_vectors=16, + ivf_centroids=np.random.randn(4, 128).astype(np.float32), + pq_codebook=np.random.randn(16, 256, 8).astype(np.float32), + precomputed_partition_artifact_uri=str(tmp_path / "missing_artifact"), + ) + + def test_use_index(dataset, tmp_path): ann_ds = lance.write_dataset(dataset.to_table(), tmp_path / "indexed.lance") ann_ds = ann_ds.create_index( diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 8838b89bc0a..4b058ce8382 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -3612,6 +3612,10 @@ fn prepare_vector_index_params( ivf_params.precomputed_partitions_file = Some(f.to_string()); }; + if let Some(uri) = kwargs.get_item("precomputed_partition_artifact_uri")? 
{ + ivf_params.precomputed_partition_artifact_uri = Some(uri.to_string()); + }; + if let Some(storage_options) = storage_options { ivf_params.storage_options = Some(storage_options); } diff --git a/rust/lance-index/src/vector/ivf/builder.rs b/rust/lance-index/src/vector/ivf/builder.rs index 72e05555441..caccd92d6c4 100644 --- a/rust/lance-index/src/vector/ivf/builder.rs +++ b/rust/lance-index/src/vector/ivf/builder.rs @@ -9,10 +9,9 @@ use std::sync::Arc; use arrow_array::cast::AsArray; use arrow_array::{Array, FixedSizeListArray, UInt32Array, UInt64Array}; use futures::TryStreamExt; -use object_store::path::Path; - use lance_core::error::{Error, Result}; use lance_io::stream::RecordBatchStream; +use object_store::path::Path; /// Parameters to build IVF partitions #[derive(Debug, Clone)] @@ -50,6 +49,10 @@ pub struct IvfBuildParams { /// The input is expected to be (/dir/to/buffers, [buffer1.lance, buffer2.lance, ...]) pub precomputed_shuffle_buffers: Option<(Path, Vec)>, + /// Precomputed partitioned artifact produced by an external backend. + /// Mutually exclusive with other precomputed inputs and requires `centroids` to be set. 
+ pub precomputed_partition_artifact_uri: Option, + pub shuffle_partition_batches: usize, pub shuffle_partition_concurrency: usize, @@ -69,6 +72,7 @@ impl Default for IvfBuildParams { sample_rate: 256, // See faiss precomputed_partitions_file: None, precomputed_shuffle_buffers: None, + precomputed_partition_artifact_uri: None, shuffle_partition_batches: 1024 * 10, shuffle_partition_concurrency: 2, storage_options: None, diff --git a/rust/lance-index/src/vector/v3/shuffler.rs b/rust/lance-index/src/vector/v3/shuffler.rs index 0bf714df237..20bed4cdc23 100644 --- a/rust/lance-index/src/vector/v3/shuffler.rs +++ b/rust/lance-index/src/vector/v3/shuffler.rs @@ -532,7 +532,6 @@ impl Shuffler for TwoFileShuffler { offsets_writer.finish().await?; let num_batches = num_batches.load(std::sync::atomic::Ordering::Relaxed); - let total_loss_val = *total_loss.lock().unwrap(); TwoFileShuffleReader::try_new( diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index e137237b9a0..9f2264c6938 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -9,6 +9,7 @@ use std::{any::Any, collections::HashMap}; pub mod builder; pub mod ivf; +mod partition_artifact; pub mod pq; pub mod utils; @@ -34,6 +35,7 @@ use lance_index::vector::hnsw::HNSW; use lance_index::vector::ivf::builder::recommended_num_partitions; use lance_index::vector::ivf::storage::IvfModel; use object_store::path::Path; +pub use partition_artifact::PartitionArtifactBuilder; use lance_arrow::FixedSizeListArrayExt; use lance_index::vector::pq::ProductQuantizer; @@ -1826,6 +1828,7 @@ fn derive_ivf_params(ivf_model: &IvfModel) -> IvfBuildParams { sample_rate: 256, // Default precomputed_partitions_file: None, precomputed_shuffle_buffers: None, + precomputed_partition_artifact_uri: None, shuffle_partition_batches: 1024 * 10, // Default shuffle_partition_concurrency: 2, // Default storage_options: None, diff --git a/rust/lance/src/index/vector/builder.rs 
b/rust/lance/src/index/vector/builder.rs index 8354f80416c..0edfbea4812 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -67,8 +67,9 @@ use lance_index::{ MIN_PARTITION_SIZE_PERCENT, }; use lance_io::local::to_local_path; +use lance_io::object_store::ObjectStore; use lance_io::stream::RecordBatchStream; -use lance_io::{object_store::ObjectStore, stream::RecordBatchStreamAdapter}; +use lance_io::stream::RecordBatchStreamAdapter; use lance_linalg::distance::{DistanceType, Dot, L2, Normalize}; use lance_linalg::kernels::normalize_fsl; use log::info; @@ -85,6 +86,7 @@ use crate::index::vector::utils::infer_vector_dim; use super::v2::IVFIndex; use super::{ ivf::load_precomputed_partitions_if_available, + partition_artifact::PartitionArtifactShuffleReader, utils::{self, get_vector_type}, }; @@ -141,6 +143,19 @@ type BuildStream = Pin::Storage, S, f64)>>> + Send>>; impl IvfIndexBuilder { + async fn try_open_precomputed_partition_artifact_reader( + &self, + uri: &str, + ) -> Result> { + let storage_options = self + .ivf_params + .as_ref() + .and_then(|params| params.storage_options.as_ref()); + Ok(Arc::new( + PartitionArtifactShuffleReader::try_open(uri, storage_options).await?, + )) + } + #[allow(clippy::too_many_arguments)] pub fn new( dataset: Dataset, @@ -527,6 +542,19 @@ impl IvfIndexBuilder return Err(Error::invalid_input("dataset not set before shuffling")); }; + if let Some(uri) = self + .ivf_params + .as_ref() + .and_then(|params| params.precomputed_partition_artifact_uri.as_deref()) + { + log::info!("shuffle with precomputed partition artifact from {}", uri); + self.shuffle_reader = Some( + self.try_open_precomputed_partition_artifact_reader(uri) + .await?, + ); + return Ok(()); + } + let stream = match self .ivf_params .as_ref() @@ -534,8 +562,6 @@ impl IvfIndexBuilder { Some((uri, _)) => { let uri = to_local_path(uri); - // the uri points to data directory, - // so need to trim the "data" suffix for reading the 
dataset let uri = uri.trim_end_matches("data"); log::info!("shuffle with precomputed shuffle buffers from {}", uri); let ds = Dataset::open(uri).await?; diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 9f226a42db5..34ce23f1eac 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -1293,6 +1293,12 @@ fn sanity_check_ivf_params(ivf: &IvfBuildParams) -> Result<()> { )); } + if ivf.precomputed_partition_artifact_uri.is_some() && ivf.centroids.is_none() { + return Err(Error::index( + "precomputed_partition_artifact_uri requires centroids to be set".to_string(), + )); + } + if ivf.precomputed_shuffle_buffers.is_some() && ivf.precomputed_partitions_file.is_some() { return Err(Error::index( "precomputed_shuffle_buffers and precomputed_partitions_file are mutually exclusive" @@ -1300,6 +1306,22 @@ fn sanity_check_ivf_params(ivf: &IvfBuildParams) -> Result<()> { )); } + if ivf.precomputed_partition_artifact_uri.is_some() && ivf.precomputed_partitions_file.is_some() + { + return Err(Error::index( + "precomputed_partition_artifact_uri and precomputed_partitions_file are mutually exclusive" + .to_string(), + )); + } + + if ivf.precomputed_partition_artifact_uri.is_some() && ivf.precomputed_shuffle_buffers.is_some() + { + return Err(Error::index( + "precomputed_partition_artifact_uri and precomputed_shuffle_buffers are mutually exclusive" + .to_string(), + )); + } + Ok(()) } @@ -1311,6 +1333,12 @@ fn sanity_check_params(ivf: &IvfBuildParams, pq: &PQBuildParams) -> Result<()> { )); } + if ivf.precomputed_partition_artifact_uri.is_some() && pq.codebook.is_none() { + return Err(Error::index( + "precomputed_partition_artifact_uri requires codebooks to be set".to_string(), + )); + } + Ok(()) } diff --git a/rust/lance/src/index/vector/partition_artifact.rs b/rust/lance/src/index/vector/partition_artifact.rs new file mode 100644 index 00000000000..a721437358d --- /dev/null +++ 
b/rust/lance/src/index/vector/partition_artifact.rs @@ -0,0 +1,1008 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::collections::HashMap; +use std::mem; +use std::ops::Range; +use std::sync::{Arc, Mutex}; + +use arrow_array::cast::AsArray; +use arrow_array::{FixedSizeListArray, RecordBatch, UInt8Array, UInt64Array}; +use arrow_schema::{DataType, Field, Schema as ArrowSchema}; +use lance_arrow::FixedSizeListArrayExt; +use lance_core::cache::LanceCache; +use lance_core::datatypes::Schema; +use lance_core::{Error, ROW_ID, Result}; +use lance_encoding::decoder::{DecoderPlugins, FilterExpression}; +use lance_file::reader::{FileReader, FileReaderOptions}; +use lance_file::version::LanceFileVersion; +use lance_file::writer::{FileWriter, FileWriterOptions}; +use lance_index::vector::v3::shuffler::ShuffleReader; +use lance_index::vector::{PART_ID_COLUMN, PQ_CODE_COLUMN}; +use lance_io::ReadBatchParams; +use lance_io::object_store::{ObjectStore, ObjectStoreParams, ObjectStoreRegistry}; +use lance_io::scheduler::{ScanScheduler, SchedulerConfig}; +use lance_io::stream::{RecordBatchStream, RecordBatchStreamAdapter}; +use lance_io::traits::Writer; +use lance_io::utils::CachedFileSize; +use object_store::path::Path; +use serde::{Deserialize, Serialize}; +use tokio::io::AsyncWriteExt; + +const PARTITION_ARTIFACT_MANIFEST_VERSION: u32 = 1; +const PARTITION_ARTIFACT_MANIFEST_FILE_NAME: &str = "manifest.json"; +const PARTITION_ARTIFACT_PARTITIONS_DIR: &str = "partitions"; +const PARTITION_ARTIFACT_DEFAULT_BUCKETS: usize = 256; +const PARTITION_ARTIFACT_BUCKET_PREFIX: &str = "bucket-"; +const PARTITION_ARTIFACT_FILE_VERSION: &str = "2.2"; +const PARTITION_ARTIFACT_BUCKET_BUFFER_ROWS: usize = 32 * 1024; + +/// Top-level manifest for a precomputed partition artifact. 
+/// +/// The manifest is intentionally small and JSON-encoded so an external backend +/// can materialize partition data once and Lance can reopen it later without +/// understanding any backend-specific details. +#[derive(Debug, Serialize, Deserialize)] +struct PartitionArtifactManifest { + version: u32, + num_partitions: usize, + #[serde(default)] + metadata_file: Option, + #[serde(default)] + total_loss: Option, + partitions: Vec, +} + +/// Describes where one logical IVF partition lives inside the artifact. +/// +/// Multiple logical partitions can share the same physical file when they hash +/// to the same bucket. `ranges` records the row spans within that file that +/// belong to this partition. +#[derive(Debug, Clone, Serialize, Deserialize)] +struct PartitionArtifactPartition { + #[serde(default)] + path: Option, + #[serde(default)] + num_rows: usize, + #[serde(default)] + ranges: Vec, +} + +/// A contiguous row range for a partition inside one bucket file. +/// +/// The builder sorts each finalized bucket by partition id, so a partition is +/// usually represented by a single range. The type still allows multiple runs +/// so the reader does not depend on that implementation detail. +#[derive(Debug, Clone, Serialize, Deserialize)] +struct PartitionArtifactRange { + offset: u64, + num_rows: u64, +} + +/// In-memory staging buffer for one bucket before it is flushed to disk. +/// +/// Batches arrive grouped arbitrarily by the backend. The builder first +/// appends rows into per-bucket buffers so it can write larger sequential runs +/// to temporary files instead of issuing tiny file writes. +#[derive(Default, Debug)] +struct BucketBuffer { + row_ids: Vec, + partition_ids: Vec, + pq_values: Vec, +} + +impl BucketBuffer { + /// Number of staged rows currently buffered for this bucket. + fn len(&self) -> usize { + self.row_ids.len() + } + + /// Whether the bucket currently has any staged rows. 
+ fn is_empty(&self) -> bool { + self.row_ids.is_empty() + } +} + +/// Writes partition-addressable encoded rows for a later Lance finalization. +/// +/// The builder uses bucket-local buffering to keep append-time memory bounded. +/// Each flush sorts only the current in-memory bucket and appends it directly to +/// the finalized bucket file, while the manifest accumulates per-partition row +/// ranges. This keeps the writer streaming and avoids a full read/sort/rewrite +/// pass at `finish()` time. +pub struct PartitionArtifactBuilder { + object_store: Arc, + root_dir: Path, + num_partitions: usize, + num_buckets: usize, + pq_code_width: usize, + final_schema: Arc, + final_writers: Vec>, + buffers: Vec, + partitions: Vec, + bucket_row_counts: Vec, +} + +impl PartitionArtifactBuilder { + /// Create a builder from a URI and optional storage options. + /// + /// This is the external entry point used by backends that only know an + /// artifact URI. It resolves the object store and then delegates to the + /// store-aware constructor. + pub async fn try_new( + uri: &str, + num_partitions: usize, + pq_code_width: usize, + storage_options: Option<&HashMap>, + ) -> Result { + let registry = Arc::new(ObjectStoreRegistry::default()); + let params = if let Some(storage_options) = storage_options { + ObjectStoreParams { + storage_options_accessor: Some(Arc::new( + lance_io::object_store::StorageOptionsAccessor::with_static_options( + storage_options.clone(), + ), + )), + ..Default::default() + } + } else { + ObjectStoreParams::default() + }; + let (object_store, root_dir) = + ObjectStore::from_uri_and_params(registry, uri, ¶ms).await?; + Self::try_new_with_store(object_store, root_dir, num_partitions, pq_code_width) + } + + /// Create a builder against an already-resolved object store. + /// + /// The builder precomputes the final schema and allocates one staging + /// buffer per bucket. 
Buckets are a write-time sharding scheme: they are + /// not visible to readers, but they keep memory usage bounded and avoid one + /// file per partition. + pub fn try_new_with_store( + object_store: Arc, + root_dir: Path, + num_partitions: usize, + pq_code_width: usize, + ) -> Result { + if num_partitions == 0 { + return Err(Error::invalid_input( + "partition artifact builder requires num_partitions > 0".to_string(), + )); + } + if pq_code_width == 0 { + return Err(Error::invalid_input( + "partition artifact builder requires pq_code_width > 0".to_string(), + )); + } + + let num_buckets = num_partitions + .min(PARTITION_ARTIFACT_DEFAULT_BUCKETS) + .max(1); + let final_schema = Arc::new(ArrowSchema::new(vec![ + Field::new(ROW_ID, DataType::UInt64, false), + Field::new( + PQ_CODE_COLUMN, + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::UInt8, true)), + pq_code_width as i32, + ), + true, + ), + ])); + + Ok(Self { + object_store, + root_dir, + num_partitions, + num_buckets, + pq_code_width, + final_schema, + final_writers: (0..num_buckets).map(|_| None).collect(), + buffers: (0..num_buckets).map(|_| BucketBuffer::default()).collect(), + partitions: vec![ + PartitionArtifactPartition { + path: None, + num_rows: 0, + ranges: Vec::new(), + }; + num_partitions + ], + bucket_row_counts: vec![0; num_buckets], + }) + } + + /// Append one encoded batch into the artifact staging area. + /// + /// Input batches must already contain row ids, partition ids, and PQ codes. + /// Rows are redistributed into bucket-local in-memory buffers and flushed to + /// temporary files once they become large enough. 
+ pub async fn append_batch(&mut self, batch: &RecordBatch) -> Result<()> { + validate_input_batch(batch, self.pq_code_width)?; + + let row_ids = batch[ROW_ID].as_primitive::(); + let part_ids = batch[PART_ID_COLUMN].as_primitive::(); + let pq_codes = batch[PQ_CODE_COLUMN].as_fixed_size_list(); + let pq_values = pq_codes + .values() + .as_primitive::(); + let pq_values = pq_values.values().as_ref(); + + for row_idx in 0..batch.num_rows() { + let partition_id = part_ids.value(row_idx) as usize; + if partition_id >= self.num_partitions { + return Err(Error::invalid_input(format!( + "partition artifact batch contains partition id {} but num_partitions is {}", + partition_id, self.num_partitions + ))); + } + let bucket_id = partition_id % self.num_buckets; + let buffer = &mut self.buffers[bucket_id]; + buffer.row_ids.push(row_ids.value(row_idx)); + buffer.partition_ids.push(partition_id as u32); + let start = row_idx * self.pq_code_width; + let end = start + self.pq_code_width; + buffer.pq_values.extend_from_slice(&pq_values[start..end]); + if buffer.len() >= PARTITION_ARTIFACT_BUCKET_BUFFER_ROWS { + self.flush_bucket(bucket_id).await?; + } + } + Ok(()) + } + + /// Finalize the artifact and return the relative files that were created. + /// + /// Finalization only needs to flush the remaining in-memory buffers and + /// persist the manifest because bucket files are already in their final + /// layout. 
+    pub async fn finish(
+        &mut self,
+        metadata_file: &str,
+        total_loss: Option<f64>,
+    ) -> Result<Vec<String>> {
+        // Drain whatever is still buffered, then close every writer that was
+        // ever opened.
+        for bucket_id in 0..self.num_buckets {
+            self.flush_bucket(bucket_id).await?;
+        }
+        for writer in self.final_writers.iter_mut() {
+            if let Some(writer) = writer.as_mut() {
+                writer.finish().await?;
+            }
+        }
+
+        // Only buckets that actually received rows have a file on disk.
+        let mut artifact_files = Vec::with_capacity(self.num_buckets + 1);
+        for bucket_id in 0..self.num_buckets {
+            if self.final_writers[bucket_id].is_some() {
+                artifact_files.push(self.final_bucket_relative_path(bucket_id));
+            }
+        }
+
+        let manifest = PartitionArtifactManifest {
+            version: PARTITION_ARTIFACT_MANIFEST_VERSION,
+            num_partitions: self.num_partitions,
+            metadata_file: Some(metadata_file.to_string()),
+            total_loss,
+            partitions: self.partitions.clone(),
+        };
+        write_json(
+            self.object_store.as_ref(),
+            &self.root_dir.child(PARTITION_ARTIFACT_MANIFEST_FILE_NAME),
+            &manifest,
+        )
+        .await?;
+
+        // The manifest is listed first so callers can locate it cheaply.
+        let mut files = vec![PARTITION_ARTIFACT_MANIFEST_FILE_NAME.to_string()];
+        files.extend(artifact_files);
+        Ok(files)
+    }
+
+    /// Flush the current in-memory buffer for one bucket into its finalized
+    /// bucket file.
+    ///
+    /// Each flush sorts only the buffered rows for this bucket and appends them
+    /// to the final file while recording new manifest ranges for the affected
+    /// partitions.
+ async fn flush_bucket(&mut self, bucket_id: usize) -> Result<()> { + if self.buffers[bucket_id].is_empty() { + return Ok(()); + } + + let buffer = &mut self.buffers[bucket_id]; + let row_ids = UInt64Array::from(mem::take(&mut buffer.row_ids)); + let part_ids = mem::take(&mut buffer.partition_ids); + let pq_values = UInt8Array::from(mem::take(&mut buffer.pq_values)); + let total_rows = row_ids.len(); + + let mut permutation = (0..total_rows).collect::>(); + permutation.sort_unstable_by_key(|&idx| part_ids[idx]); + + let mut sorted_row_ids = Vec::with_capacity(total_rows); + let mut sorted_partition_ids = Vec::with_capacity(total_rows); + let mut sorted_pq_values = Vec::with_capacity(total_rows * self.pq_code_width); + for idx in permutation { + sorted_row_ids.push(row_ids.value(idx)); + sorted_partition_ids.push(part_ids[idx]); + let start = idx * self.pq_code_width; + let end = start + self.pq_code_width; + sorted_pq_values.extend_from_slice(&pq_values.values()[start..end]); + } + + let file_offset = self.bucket_row_counts[bucket_id]; + let final_relative_path = self.final_bucket_relative_path(bucket_id); + let mut offset = 0usize; + while offset < sorted_partition_ids.len() { + let partition_id = sorted_partition_ids[offset] as usize; + let mut end = offset + 1; + while end < sorted_partition_ids.len() + && sorted_partition_ids[end] == sorted_partition_ids[offset] + { + end += 1; + } + let partition = &mut self.partitions[partition_id]; + match &partition.path { + Some(existing) if existing != &final_relative_path => { + return Err(Error::io(format!( + "partition {} is split across multiple bucket files: '{}' vs '{}'", + partition_id, existing, final_relative_path + ))); + } + None => partition.path = Some(final_relative_path.clone()), + _ => {} + } + partition.num_rows += end - offset; + partition.ranges.push(PartitionArtifactRange { + offset: file_offset + offset as u64, + num_rows: (end - offset) as u64, + }); + offset = end; + } + + let pq_codes = 
FixedSizeListArray::try_new_from_values( + UInt8Array::from(sorted_pq_values), + self.pq_code_width as i32, + )?; + let final_batch = RecordBatch::try_new( + self.final_schema.clone(), + vec![ + Arc::new(UInt64Array::from(sorted_row_ids)), + Arc::new(pq_codes), + ], + )?; + let writer = self.ensure_final_writer(bucket_id).await?; + writer.write_batch(&final_batch).await?; + self.bucket_row_counts[bucket_id] += total_rows as u64; + Ok(()) + } + + /// Lazily create the finalized writer for a bucket. + /// + /// Buckets that never receive rows never create a file, which keeps sparse + /// artifacts compact. + async fn ensure_final_writer(&mut self, bucket_id: usize) -> Result<&mut FileWriter> { + if self.final_writers[bucket_id].is_none() { + let path = self.final_bucket_path(bucket_id); + let writer = FileWriter::try_new( + self.object_store.create(&path).await?, + Schema::try_from(self.final_schema.as_ref())?, + file_writer_options()?, + )?; + self.final_writers[bucket_id] = Some(writer); + } + Ok(self.final_writers[bucket_id] + .as_mut() + .expect("final writer initialized")) + } + + /// Path of the finalized file for one bucket. + fn final_bucket_path(&self, bucket_id: usize) -> Path { + self.root_dir + .child(PARTITION_ARTIFACT_PARTITIONS_DIR) + .child(format!( + "{PARTITION_ARTIFACT_BUCKET_PREFIX}{bucket_id:05}.lance" + )) + } + + /// Relative path recorded in the manifest for one finalized bucket. + fn final_bucket_relative_path(&self, bucket_id: usize) -> String { + format!( + "{PARTITION_ARTIFACT_PARTITIONS_DIR}/{PARTITION_ARTIFACT_BUCKET_PREFIX}{bucket_id:05}.lance" + ) + } +} + +/// Reopens a partition artifact as a `ShuffleReader`. +/// +/// The final Lance builder consumes artifacts through the generic +/// [`ShuffleReader`] interface, so this adapter hides the manifest parsing and +/// file caching needed to expose partition-local record batch streams. 
+#[derive(Debug)] +pub(crate) struct PartitionArtifactShuffleReader { + scheduler: Arc, + root_dir: Path, + partitions: Vec, + total_loss: Option, + file_readers: Mutex>>, +} + +/// Writer options for all files stored inside a partition artifact. +/// +/// The artifact uses a fixed file version so external backends and Lance +/// finalization agree on the on-disk layout. +fn file_writer_options() -> Result { + Ok(FileWriterOptions { + format_version: Some( + PARTITION_ARTIFACT_FILE_VERSION + .parse::() + .map_err(|error| { + Error::invalid_input(format!( + "invalid partition artifact file version '{}': {}", + PARTITION_ARTIFACT_FILE_VERSION, error + )) + })?, + ), + ..Default::default() + }) +} + +/// Validate that a backend-produced batch matches the artifact contract. +/// +/// The builder is intentionally strict here because any schema drift would only +/// surface much later during finalization. +fn validate_input_batch(batch: &RecordBatch, pq_code_width: usize) -> Result<()> { + let Some(row_ids) = batch.column_by_name(ROW_ID) else { + return Err(Error::invalid_input(format!( + "partition artifact batch must contain {ROW_ID}" + ))); + }; + if row_ids.data_type() != &DataType::UInt64 { + return Err(Error::invalid_input(format!( + "partition artifact batch column {ROW_ID} must be uint64, got {}", + row_ids.data_type() + ))); + } + let Some(part_ids) = batch.column_by_name(PART_ID_COLUMN) else { + return Err(Error::invalid_input(format!( + "partition artifact batch must contain {PART_ID_COLUMN}" + ))); + }; + if part_ids.data_type() != &DataType::UInt32 { + return Err(Error::invalid_input(format!( + "partition artifact batch column {PART_ID_COLUMN} must be uint32, got {}", + part_ids.data_type() + ))); + } + let Some(pq_codes) = batch.column_by_name(PQ_CODE_COLUMN) else { + return Err(Error::invalid_input(format!( + "partition artifact batch must contain {PQ_CODE_COLUMN}" + ))); + }; + match pq_codes.data_type() { + DataType::FixedSizeList(_, width) if *width as 
usize == pq_code_width => Ok(()), + other => Err(Error::invalid_input(format!( + "partition artifact batch column {PQ_CODE_COLUMN} must be fixed_size_list[{}], got {}", + pq_code_width, other + ))), + } +} + +/// Serialize a small JSON sidecar directly into the object store. +async fn write_json( + object_store: &ObjectStore, + path: &Path, + value: &T, +) -> Result<()> { + let bytes = serde_json::to_vec(value).map_err(|error| { + Error::invalid_input(format!( + "failed to serialize partition artifact manifest '{}': {}", + path, error + )) + })?; + let mut writer = object_store.create(path).await?; + writer.write_all(&bytes).await?; + Writer::shutdown(writer.as_mut()).await?; + Ok(()) +} + +impl PartitionArtifactShuffleReader { + /// Open an artifact reader from a URI and optional storage options. + pub(crate) async fn try_open( + uri: &str, + storage_options: Option<&HashMap>, + ) -> Result { + let registry = Arc::new(ObjectStoreRegistry::default()); + let params = if let Some(storage_options) = storage_options { + ObjectStoreParams { + storage_options_accessor: Some(Arc::new( + lance_io::object_store::StorageOptionsAccessor::with_static_options( + storage_options.clone(), + ), + )), + ..Default::default() + } + } else { + ObjectStoreParams::default() + }; + let (object_store, root_dir) = + ObjectStore::from_uri_and_params(registry, uri, ¶ms).await?; + Self::try_open_with_store(object_store, root_dir).await + } + + /// Open an artifact reader once the object store has already been resolved. + /// + /// This reads the manifest once, validates it, and initializes the shared + /// scheduler and reader cache used by partition reads. 
+ async fn try_open_with_store(object_store: Arc, root_dir: Path) -> Result { + let manifest_path = root_dir.child("manifest.json"); + let manifest_bytes = object_store.read_one_all(&manifest_path).await?; + let manifest: PartitionArtifactManifest = + serde_json::from_slice(&manifest_bytes).map_err(|error| { + Error::invalid_input(format!( + "failed to parse partition artifact manifest '{}': {}", + manifest_path, error + )) + })?; + if manifest.version != 1 { + return Err(Error::invalid_input(format!( + "unsupported partition artifact manifest version {}", + manifest.version + ))); + } + if manifest.partitions.len() != manifest.num_partitions { + return Err(Error::invalid_input(format!( + "partition artifact manifest has {} partitions but num_partitions is {}", + manifest.partitions.len(), + manifest.num_partitions + ))); + } + + let scheduler = ScanScheduler::new( + object_store.clone(), + SchedulerConfig::max_bandwidth(&object_store), + ); + Ok(Self { + scheduler, + root_dir, + partitions: manifest.partitions, + total_loss: manifest.total_loss, + file_readers: Mutex::new(HashMap::new()), + }) + } + + /// Open and cache a file reader for a finalized bucket file. + /// + /// Multiple logical partitions can point at the same bucket file, so the + /// reader cache prevents redundant file opens during finalization. 
+ async fn open_file_reader(&self, relative_path: &str) -> Result> { + if let Some(reader) = self + .file_readers + .lock() + .expect("partition artifact reader mutex poisoned") + .get(relative_path) + .cloned() + { + return Ok(reader); + } + + let path = join_relative_path(&self.root_dir, relative_path); + let reader = Arc::new( + FileReader::try_open( + self.scheduler + .open_file(&path, &CachedFileSize::unknown()) + .await?, + None, + Arc::::default(), + &LanceCache::no_cache(), + FileReaderOptions::default(), + ) + .await?, + ); + self.file_readers + .lock() + .expect("partition artifact reader mutex poisoned") + .insert(relative_path.to_string(), reader.clone()); + Ok(reader) + } +} + +/// Join a manifest-relative path onto the artifact root. +fn join_relative_path(root_dir: &Path, relative_path: &str) -> Path { + relative_path + .split('/') + .filter(|segment| !segment.is_empty()) + .fold(root_dir.clone(), |path, segment| path.child(segment)) +} + +#[async_trait::async_trait] +impl ShuffleReader for PartitionArtifactShuffleReader { + /// Return a stream over all rows belonging to one logical partition. + /// + /// The manifest already records the precise row ranges for each partition, + /// so the reader can issue targeted range reads without scanning unrelated + /// partitions. 
+    async fn read_partition(
+        &self,
+        partition_id: usize,
+        // NOTE(review): trait-object return type reconstructed from the
+        // `ShuffleReader` trait — confirm against the trait definition.
+    ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
+        // Out-of-range ids and empty partitions both read as "no data".
+        let Some(partition) = self.partitions.get(partition_id) else {
+            return Ok(None);
+        };
+        if partition.num_rows == 0 {
+            return Ok(None);
+        }
+        let path = partition.path.as_ref().ok_or_else(|| {
+            Error::invalid_input(format!(
+                "partition artifact partition {} has {} rows but no path",
+                partition_id, partition.num_rows
+            ))
+        })?;
+        if partition.ranges.is_empty() {
+            return Err(Error::invalid_input(format!(
+                "partition artifact partition {} has {} rows but no ranges",
+                partition_id, partition.num_rows
+            )));
+        }
+
+        let reader = self.open_file_reader(path).await?;
+        let ranges = partition
+            .ranges
+            .iter()
+            .map(|range| Range {
+                start: range.offset,
+                end: range.offset + range.num_rows,
+            })
+            .collect::<Vec<_>>();
+        let schema = Arc::new(reader.schema().as_ref().into());
+        Ok(Some(Box::new(RecordBatchStreamAdapter::new(
+            schema,
+            // u32::MAX batch size lets each requested range arrive as one
+            // batch; 16 is the read-ahead depth.
+            reader.read_stream(
+                ReadBatchParams::Ranges(ranges.into()),
+                u32::MAX,
+                16,
+                FilterExpression::no_filter(),
+            )?,
+        ))))
+    }
+
+    /// Number of encoded rows available for one logical partition.
+    ///
+    /// Unknown partition ids report zero rather than erroring.
+    fn partition_size(&self, partition_id: usize) -> Result<usize> {
+        Ok(self
+            .partitions
+            .get(partition_id)
+            .map(|partition| partition.num_rows)
+            .unwrap_or(0))
+    }
+
+    /// Optional training loss propagated from the backend into the artifact.
+    fn total_loss(&self) -> Option<f64> {
+        self.total_loss
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::fs;
+
+    use arrow_array::cast::AsArray;
+    use arrow_array::{FixedSizeListArray, RecordBatch, UInt8Array, UInt64Array};
+    use futures::TryStreamExt;
+    use lance_arrow::FixedSizeListArrayExt;
+    use lance_core::ROW_ID;
+    use lance_core::datatypes::Schema;
+    use lance_file::writer::{FileWriter, FileWriterOptions};
+    use lance_io::object_store::ObjectStore;
+
+    use crate::Error;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn partition_artifact_builder_compacts_runs_into_single_partition_range() {
+        let tempdir = tempfile::tempdir().unwrap();
+        let root_dir = tempdir.path().join("artifact");
+        fs::create_dir_all(&root_dir).unwrap();
+        let object_store = Arc::new(ObjectStore::local());
+        let root_path = Path::from_filesystem_path(&root_dir).unwrap();
+
+        // 300 partitions over the default bucket cap: partitions 0 and 256
+        // share bucket 0, exercising bucket sharing.
+        let mut builder = PartitionArtifactBuilder::try_new_with_store(
+            object_store.clone(),
+            root_path.clone(),
+            300,
+            2,
+        )
+        .unwrap();
+        let schema = Arc::new(ArrowSchema::new(vec![
+            Field::new(ROW_ID, DataType::UInt64, false),
+            Field::new(PART_ID_COLUMN, DataType::UInt32, false),
+            Field::new(
+                PQ_CODE_COLUMN,
+                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::UInt8, true)), 2),
+                true,
+            ),
+        ]));
+
+        let batch1 = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(UInt64Array::from(vec![10_u64, 11, 12, 13])),
+                Arc::new(UInt32Array::from(vec![0_u32, 256, 0, 256])),
+                Arc::new(
+                    FixedSizeListArray::try_new_from_values(
+                        UInt8Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]),
+                        2,
+                    )
+                    .unwrap(),
+                ),
+            ],
+        )
+        .unwrap();
+        let batch2 = RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(UInt64Array::from(vec![14_u64, 15])),
+                Arc::new(UInt32Array::from(vec![1_u32, 256])),
+                Arc::new(
+                    FixedSizeListArray::try_new_from_values(
+                        UInt8Array::from(vec![9, 10, 11, 12]),
+                        2,
+                    )
+                    .unwrap(),
+                ),
+            ],
+        )
+        .unwrap();
+        builder.append_batch(&batch1).await.unwrap();
+        builder.append_batch(&batch2).await.unwrap();
+        let artifact_files = builder.finish("metadata.lance", Some(2.5)).await.unwrap();
+        assert_eq!(artifact_files[0], "manifest.json");
+        assert!(
+            artifact_files
+                .iter()
+                .any(|path| path.ends_with("bucket-00000.lance"))
+        );
+
+        let manifest: PartitionArtifactManifest =
+            serde_json::from_slice(&fs::read(root_dir.join("manifest.json")).unwrap()).unwrap();
+        assert_eq!(manifest.version, 1);
+        assert_eq!(manifest.metadata_file.as_deref(), Some("metadata.lance"));
+        assert_eq!(manifest.total_loss, Some(2.5));
+        assert_eq!(manifest.partitions[0].num_rows, 2);
+        assert_eq!(manifest.partitions[0].ranges.len(), 1);
+        assert_eq!(manifest.partitions[1].num_rows, 1);
+        assert_eq!(manifest.partitions[1].ranges.len(), 1);
+        assert_eq!(manifest.partitions[256].num_rows, 3);
+        assert_eq!(manifest.partitions[256].ranges.len(), 1);
+        assert_eq!(
+            manifest.partitions[0].path, manifest.partitions[256].path,
+            "partitions sharing a bucket should share one final file"
+        );
+
+        let reader = PartitionArtifactShuffleReader::try_open_with_store(object_store, root_path)
+            .await
+            .unwrap();
+        let partition_0 = reader
+            .read_partition(0)
+            .await
+            .unwrap()
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        let partition_0_row_ids = partition_0
+            .iter()
+            .flat_map(|batch| {
+                batch[ROW_ID]
+                    .as_primitive::<UInt64Type>()
+                    .values()
+                    .iter()
+                    .copied()
+            })
+            .collect::<Vec<_>>();
+        assert_eq!(partition_0_row_ids, vec![10, 12]);
+
+        let partition_256 = reader
+            .read_partition(256)
+            .await
+            .unwrap()
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        let partition_256_row_ids = partition_256
+            .iter()
+            .flat_map(|batch| {
+                batch[ROW_ID]
+                    .as_primitive::<UInt64Type>()
+                    .values()
+                    .iter()
+                    .copied()
+            })
+            .collect::<Vec<_>>();
+        assert_eq!(partition_256_row_ids, vec![11, 13, 15]);
+    }
+
+    #[tokio::test]
+    async fn partition_artifact_reader_reads_partition_ranges() {
+        let tempdir = tempfile::tempdir().unwrap();
+        let root_dir = tempdir.path().join("artifact");
+        fs::create_dir_all(root_dir.join("partitions")).unwrap();
+
+        let object_store = Arc::new(ObjectStore::local());
+        let root_path = Path::from_filesystem_path(&root_dir).unwrap();
+        let partition_path = root_path.child("partitions").child("bucket-00000.lance");
+        let schema = Arc::new(arrow_schema::Schema::new(vec![
+            arrow_schema::Field::new(ROW_ID, arrow_schema::DataType::UInt64, false),
+            arrow_schema::Field::new(
+                lance_index::vector::PQ_CODE_COLUMN,
+                arrow_schema::DataType::FixedSizeList(
+                    Arc::new(arrow_schema::Field::new(
+                        "item",
+                        arrow_schema::DataType::UInt8,
+                        true,
+                    )),
+                    2,
+                ),
+                true,
+            ),
+        ]));
+        let mut writer = FileWriter::try_new(
+            object_store.create(&partition_path).await.unwrap(),
+            Schema::try_from(schema.as_ref()).unwrap(),
+            FileWriterOptions::default(),
+        )
+        .unwrap();
+        let batch1 = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(UInt64Array::from(vec![10_u64, 11, 12])),
+                Arc::new(
+                    FixedSizeListArray::try_new_from_values(
+                        UInt8Array::from(vec![1, 2, 3, 4, 5, 6]),
+                        2,
+                    )
+                    .unwrap(),
+                ),
+            ],
+        )
+        .unwrap();
+        let batch2 = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(UInt64Array::from(vec![13_u64, 14])),
+                Arc::new(
+                    FixedSizeListArray::try_new_from_values(UInt8Array::from(vec![7, 8, 9, 10]), 2)
+                        .unwrap(),
+                ),
+            ],
+        )
+        .unwrap();
+        writer.write_batch(&batch1).await.unwrap();
+        writer.write_batch(&batch2).await.unwrap();
+        writer.finish().await.unwrap();
+
+        // Hand-written manifest: partition 0 owns two disjoint single-row
+        // ranges, partition 1 one two-row range, partition 2 is empty.
+        let manifest = serde_json::json!({
+            "version": 1,
+            "num_partitions": 3,
+            "total_loss": 1.5,
+            "partitions": [
+                {
+                    "path": "partitions/bucket-00000.lance",
+                    "num_rows": 2,
+                    "ranges": [
+                        {"offset": 0, "num_rows": 1},
+                        {"offset": 3, "num_rows": 1},
+                    ],
+                },
+                {
+                    "path": "partitions/bucket-00000.lance",
+                    "num_rows": 2,
+                    "ranges": [
+                        {"offset": 1, "num_rows": 2},
+                    ],
+                },
+                {
+                    "num_rows": 0,
+                    "ranges": [],
+                },
+            ],
+        });
+        fs::write(
+            root_dir.join("manifest.json"),
+            serde_json::to_vec(&manifest).unwrap(),
+        )
+        .unwrap();
+
+        let reader = PartitionArtifactShuffleReader::try_open_with_store(object_store, root_path)
+            .await
+            .unwrap();
+        assert_eq!(reader.partition_size(0).unwrap(), 2);
+        assert_eq!(reader.partition_size(1).unwrap(), 2);
+        assert_eq!(reader.partition_size(2).unwrap(), 0);
+        assert_eq!(reader.total_loss(), Some(1.5));
+
+        let stream = reader.read_partition(0).await.unwrap().unwrap();
+        let batches = stream.try_collect::<Vec<_>>().await.unwrap();
+        let row_ids = batches
+            .iter()
+            .flat_map(|batch| {
+                batch[ROW_ID]
+                    .as_primitive::<UInt64Type>()
+                    .values()
+                    .iter()
+                    .copied()
+            })
+            .collect::<Vec<_>>();
+        assert_eq!(row_ids, vec![10, 13]);
+        assert!(reader.read_partition(2).await.unwrap().is_none());
+    }
+
+    #[tokio::test]
+    async fn partition_artifact_reader_rejects_missing_partition_entry() {
+        let tempdir = tempfile::tempdir().unwrap();
+        let root_dir = tempdir.path().join("artifact");
+        fs::create_dir_all(&root_dir).unwrap();
+        // One entry but num_partitions of two: the reader must reject this.
+        let manifest = serde_json::json!({
+            "version": 1,
+            "num_partitions": 2,
+            "partitions": [{"num_rows": 0, "ranges": []}],
+        });
+        fs::write(
+            root_dir.join("manifest.json"),
+            serde_json::to_vec(&manifest).unwrap(),
+        )
+        .unwrap();
+
+        let error = PartitionArtifactShuffleReader::try_open_with_store(
+            Arc::new(ObjectStore::local()),
+            Path::from_filesystem_path(&root_dir).unwrap(),
+        )
+        .await
+        .unwrap_err();
+        assert!(matches!(error, Error::InvalidInput { .. }));
+    }
+
+    #[tokio::test]
+    async fn partition_artifact_builder_records_multiple_ranges_for_repeated_flushes() {
+        let tempdir = tempfile::tempdir().unwrap();
+        let root_dir = tempdir.path().join("artifact");
+        fs::create_dir_all(&root_dir).unwrap();
+        let object_store = Arc::new(ObjectStore::local());
+        let root_path = Path::from_filesystem_path(&root_dir).unwrap();
+
+        let mut builder =
+            PartitionArtifactBuilder::try_new_with_store(object_store, root_path, 4, 2).unwrap();
+        // One threshold's worth of rows plus 1024 forces two flushes and
+        // therefore two manifest ranges for partition 0.
+        let num_rows = PARTITION_ARTIFACT_BUCKET_BUFFER_ROWS + 1024;
+        let schema = Arc::new(ArrowSchema::new(vec![
+            Field::new(ROW_ID, DataType::UInt64, false),
+            Field::new(PART_ID_COLUMN, DataType::UInt32, false),
+            Field::new(
+                PQ_CODE_COLUMN,
+                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::UInt8, true)), 2),
+                true,
+            ),
+        ]));
+        let row_ids = UInt64Array::from_iter_values(0..num_rows as u64);
+        let part_ids = UInt32Array::from_iter_values((0..num_rows).map(|_| 0_u32));
+        let pq_values = UInt8Array::from_iter_values((0..num_rows * 2).map(|v| (v % 251) as u8));
+        let pq_codes = FixedSizeListArray::try_new_from_values(pq_values, 2).unwrap();
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![Arc::new(row_ids), Arc::new(part_ids), Arc::new(pq_codes)],
+        )
+        .unwrap();
+
+        builder.append_batch(&batch).await.unwrap();
+        builder.finish("metadata.lance", None).await.unwrap();
+
+        let manifest: PartitionArtifactManifest =
+            serde_json::from_slice(&fs::read(root_dir.join("manifest.json")).unwrap()).unwrap();
+        assert_eq!(manifest.partitions[0].num_rows, num_rows);
+        assert_eq!(manifest.partitions[0].ranges.len(), 2);
+        assert_eq!(
+            manifest.partitions[0].ranges[0].num_rows,
+            PARTITION_ARTIFACT_BUCKET_BUFFER_ROWS as u64
+        );
+        assert_eq!(
+            manifest.partitions[0].ranges[1].offset,
+            PARTITION_ARTIFACT_BUCKET_BUFFER_ROWS as u64
+        );
+    }
+}