diff --git a/java/lance-jni/Cargo.lock b/java/lance-jni/Cargo.lock
index 2f98efa51e2..10d054e863a 100644
--- a/java/lance-jni/Cargo.lock
+++ b/java/lance-jni/Cargo.lock
@@ -3771,6 +3771,7 @@ dependencies = [
  "env_logger",
  "jni",
  "lance",
+ "lance-arrow",
  "lance-core",
  "lance-datafusion",
  "lance-encoding",
diff --git a/java/lance-jni/Cargo.toml b/java/lance-jni/Cargo.toml
index 5a1e7f3f655..5b0609ce1ea 100644
--- a/java/lance-jni/Cargo.toml
+++ b/java/lance-jni/Cargo.toml
@@ -17,6 +17,7 @@ default = []
 [dependencies]
 lance = { path = "../../rust/lance", features = ["substrait"] }
+lance-arrow = { path = "../../rust/lance-arrow" }
 lance-datafusion = { path = "../../rust/lance-datafusion" }
 lance-encoding = { path = "../../rust/lance-encoding" }
 lance-linalg = { path = "../../rust/lance-linalg" }
diff --git a/java/lance-jni/src/vector_trainer.rs b/java/lance-jni/src/vector_trainer.rs
index 9ea164d3586..e449b769474 100755
--- a/java/lance-jni/src/vector_trainer.rs
+++ b/java/lance-jni/src/vector_trainer.rs
@@ -14,7 +14,9 @@ use jni::objects::{JClass, JFloatArray, JObject, JString};
 use jni::sys::jfloatArray;
 use lance::index::NoopIndexBuildProgress;
 use lance::index::vector::utils::get_vector_dim;
+use lance_arrow::FixedSizeListArrayExt;
 use lance_index::vector::ivf::builder::IvfBuildParams as RustIvfBuildParams;
+use lance_index::vector::ivf::storage::IvfModel;
 use lance_index::vector::pq::builder::PQBuildParams as RustPQBuildParams;
 use lance_linalg::distance::MetricType;
@@ -76,18 +78,58 @@ fn build_pq_params_from_java(
     })
 }
 
+/// Extract a nullable Java `List<Integer>` into `Option<Vec<u32>>`.
+fn get_nullable_fragment_ids(env: &mut JNIEnv, obj: &JObject) -> Result<Option<Vec<u32>>> {
+    if obj.is_null() {
+        return Ok(None);
+    }
+    let ints = env.get_integers(obj)?;
+    Ok(Some(ints.into_iter().map(|i| i as u32).collect()))
+}
+
+/// Extract a nullable Java `float[]` into `Option<Vec<f32>>`.
+fn get_nullable_ivf_centroids(
+    env: &mut JNIEnv,
+    centroids_obj: &JObject,
+) -> Result<Option<Vec<f32>>> {
+    if centroids_obj.is_null() {
+        return Ok(None);
+    }
+    let jarray = unsafe { JFloatArray::from_raw(centroids_obj.as_raw()) };
+    let length = env.get_array_length(&jarray)?;
+    let mut buffer = vec![0.0f32; length as usize];
+    env.get_float_array_region(&jarray, 0, &mut buffer)?;
+    Ok(Some(buffer))
+}
+
+/// Build an `IvfModel` from a flat float array and known dimension.
+fn build_ivf_model_from_centroids(centroids: &[f32], dim: usize) -> Result<IvfModel> {
+    let centroids_array = FixedSizeListArray::try_new_from_values(
+        Float32Array::from(centroids.to_vec()),
+        dim as i32,
+    )?;
+    Ok(IvfModel::new(centroids_array, None))
+}
+
 #[unsafe(no_mangle)]
 pub extern "system" fn Java_org_lance_index_vector_VectorTrainer_nativeTrainIvfCentroids<'local>(
     mut env: JNIEnv<'local>,
     _class: JClass<'local>,
-    dataset_obj: JObject<'local>,    // org.lance.Dataset
-    column_jstr: JString<'local>,    // java.lang.String
-    ivf_params_obj: JObject<'local>, // org.lance.index.vector.IvfBuildParams
+    dataset_obj: JObject<'local>,      // org.lance.Dataset
+    column_jstr: JString<'local>,      // java.lang.String
+    ivf_params_obj: JObject<'local>,   // org.lance.index.vector.IvfBuildParams
+    fragment_ids_obj: JObject<'local>, // List<Integer>, nullable
 ) -> jfloatArray {
     ok_or_throw_with_return!(
         env,
-        inner_train_ivf_centroids(&mut env, dataset_obj, column_jstr, ivf_params_obj)
-            .map(|arr| arr.into_raw()),
+        inner_train_ivf_centroids(
+            &mut env,
+            dataset_obj,
+            column_jstr,
+            ivf_params_obj,
+            fragment_ids_obj
+        )
+        .map(|arr| arr.into_raw()),
         JFloatArray::default().into_raw()
     )
 }
@@ -97,9 +139,11 @@ fn inner_train_ivf_centroids<'local>(
     dataset_obj: JObject<'local>,
     column_jstr: JString<'local>,
     ivf_params_obj: JObject<'local>,
+    fragment_ids_obj: JObject<'local>,
 ) -> Result<JFloatArray<'local>> {
     let column: String = env.get_string(&column_jstr)?.into();
     let ivf_params = build_ivf_params_from_java(env, &ivf_params_obj)?;
+    let fragment_ids = get_nullable_fragment_ids(env, &fragment_ids_obj)?;
 
     let flattened: Vec<f32> = {
         let dataset_guard =
@@ -107,8 +151,6 @@ fn inner_train_ivf_centroids<'local>(
         let dataset = &dataset_guard.inner;
 
         let dim = get_vector_dim(dataset.schema(), &column)?;
-
-        // For now we default to L2 metric; tests and Java bindings currently use L2.
         let metric_type = MetricType::L2;
 
         let ivf_model = RT.block_on(lance::index::vector::ivf::build_ivf_model(
@@ -117,7 +159,7 @@ fn inner_train_ivf_centroids<'local>(
             dim,
             metric_type,
             &ivf_params,
-            None,
+            fragment_ids.as_deref(),
             Arc::new(NoopIndexBuildProgress),
         ))?;
 
@@ -137,14 +179,23 @@
 pub extern "system" fn Java_org_lance_index_vector_VectorTrainer_nativeTrainPqCodebook<'local>(
     mut env: JNIEnv<'local>,
     _class: JClass<'local>,
-    dataset_obj: JObject<'local>,   // org.lance.Dataset
-    column_jstr: JString<'local>,   // java.lang.String
-    pq_params_obj: JObject<'local>, // org.lance.index.vector.PQBuildParams
+    dataset_obj: JObject<'local>,       // org.lance.Dataset
+    column_jstr: JString<'local>,       // java.lang.String
+    pq_params_obj: JObject<'local>,     // org.lance.index.vector.PQBuildParams
+    ivf_centroids_obj: JObject<'local>, // float[], nullable
+    fragment_ids_obj: JObject<'local>,  // List<Integer>, nullable
 ) -> jfloatArray {
     ok_or_throw_with_return!(
         env,
-        inner_train_pq_codebook(&mut env, dataset_obj, column_jstr, pq_params_obj)
-            .map(|arr| arr.into_raw()),
+        inner_train_pq_codebook(
+            &mut env,
+            dataset_obj,
+            column_jstr,
+            pq_params_obj,
+            ivf_centroids_obj,
+            fragment_ids_obj
+        )
+        .map(|arr| arr.into_raw()),
         JFloatArray::default().into_raw()
     )
 }
@@ -154,9 +205,13 @@ fn inner_train_pq_codebook<'local>(
     dataset_obj: JObject<'local>,
     column_jstr: JString<'local>,
     pq_params_obj: JObject<'local>,
+    ivf_centroids_obj: JObject<'local>,
+    fragment_ids_obj: JObject<'local>,
 ) -> Result<JFloatArray<'local>> {
     let column: String = env.get_string(&column_jstr)?.into();
     let pq_params = build_pq_params_from_java(env, &pq_params_obj)?;
+    let fragment_ids = get_nullable_fragment_ids(env, &fragment_ids_obj)?;
+    let ivf_centroids_data = get_nullable_ivf_centroids(env, &ivf_centroids_obj)?;
 
     let flattened: Vec<f32> = {
         let dataset_guard =
@@ -166,13 +221,19 @@ fn inner_train_pq_codebook<'local>(
         let dim = get_vector_dim(dataset.schema(), &column)?;
         let metric_type = MetricType::L2;
 
-        let pq = RT.block_on(lance::index::vector::pq::build_pq_model(
+        let ivf_model = ivf_centroids_data
+            .as_deref()
+            .map(|data| build_ivf_model_from_centroids(data, dim))
+            .transpose()?;
+
+        let pq = RT.block_on(lance::index::vector::pq::build_pq_model_in_fragments(
             dataset,
             &column,
             dim,
             metric_type,
             &pq_params,
-            None,
+            ivf_model.as_ref(),
+            fragment_ids.as_deref(),
         ))?;
 
         flatten_fixed_size_list_to_f32(&pq.codebook)?
diff --git a/java/src/main/java/org/lance/index/vector/VectorTrainer.java b/java/src/main/java/org/lance/index/vector/VectorTrainer.java
index 03081176bf1..4f87d376145 100755
--- a/java/src/main/java/org/lance/index/vector/VectorTrainer.java
+++ b/java/src/main/java/org/lance/index/vector/VectorTrainer.java
@@ -18,6 +18,8 @@
 
 import org.apache.arrow.util.Preconditions;
 
+import java.util.List;
+
 /**
  * Training utilities for vector indexes.
  *
@@ -36,38 +38,88 @@ private VectorTrainer() {}
   /**
    * Train IVF centroids for the given dataset column.
    *
+   * <p>Training samples from the entire dataset.
+   *
    * @param dataset the dataset to sample training data from
    * @param column the vector column name
    * @param params IVF build parameters (numPartitions, sampleRate, etc.)
    * @return a flattened array of centroids laid out as [numPartitions][dimension]
    */
   public static float[] trainIvfCentroids(Dataset dataset, String column, IvfBuildParams params) {
+    return trainIvfCentroids(dataset, column, params, null);
+  }
+
+  /**
+   * Train IVF centroids for the given dataset column, optionally restricted to specific fragments.
+   *
+   * <p>When {@code fragmentIds} is non-null, only the listed fragments are sampled for training.
+   * This is useful for per-fragment (non-shared centroid) distributed index builds.
+   *
+   * @param dataset the dataset to sample training data from
+   * @param column the vector column name
+   * @param params IVF build parameters (numPartitions, sampleRate, etc.)
+   * @param fragmentIds fragment IDs to restrict training to, or {@code null} for the full dataset
+   * @return a flattened array of centroids laid out as [numPartitions][dimension]
+   */
+  public static float[] trainIvfCentroids(
+      Dataset dataset, String column, IvfBuildParams params, List<Integer> fragmentIds) {
     Preconditions.checkArgument(dataset != null, "dataset cannot be null");
     Preconditions.checkArgument(
         column != null && !column.isEmpty(), "column cannot be null or empty");
     Preconditions.checkArgument(params != null, "params cannot be null");
-    return nativeTrainIvfCentroids(dataset, column, params);
+    return nativeTrainIvfCentroids(dataset, column, params, fragmentIds);
   }
 
   /**
    * Train a PQ codebook for the given dataset column.
    *
+   * <p>Training samples from the entire dataset without IVF residual computation.
+   *
    * @param dataset the dataset to sample training data from
    * @param column the vector column name
    * @param params PQ build parameters (numSubVectors, numBits, sampleRate, etc.)
    * @return a flattened array of codebook entries laid out as [num_centroids][dimension]
    */
   public static float[] trainPqCodebook(Dataset dataset, String column, PQBuildParams params) {
+    return trainPqCodebook(dataset, column, params, null, null);
+  }
+
+  /**
+   * Train a PQ codebook for the given dataset column, optionally using pre-trained IVF centroids
+   * for residual-based training and restricting to specific fragments.
+   *
+   * <p>When {@code ivfCentroids} is non-null, PQ training is performed on the residual vectors
+   * after IVF assignment (matching the Python {@code train_pq} behavior). When {@code fragmentIds}
+   * is non-null, only the listed fragments are sampled.
+   *
+   * @param dataset the dataset to sample training data from
+   * @param column the vector column name
+   * @param params PQ build parameters (numSubVectors, numBits, sampleRate, etc.)
+   * @param ivfCentroids flattened IVF centroids for residual PQ training, or {@code null} to skip
+   *     residual computation
+   * @param fragmentIds fragment IDs to restrict training to, or {@code null} for the full dataset
+   * @return a flattened array of codebook entries laid out as [num_centroids][dimension]
+   */
+  public static float[] trainPqCodebook(
+      Dataset dataset,
+      String column,
+      PQBuildParams params,
+      float[] ivfCentroids,
+      List<Integer> fragmentIds) {
     Preconditions.checkArgument(dataset != null, "dataset cannot be null");
     Preconditions.checkArgument(
         column != null && !column.isEmpty(), "column cannot be null or empty");
     Preconditions.checkArgument(params != null, "params cannot be null");
-    return nativeTrainPqCodebook(dataset, column, params);
+    return nativeTrainPqCodebook(dataset, column, params, ivfCentroids, fragmentIds);
   }
 
   private static native float[] nativeTrainIvfCentroids(
-      Dataset dataset, String column, IvfBuildParams params);
+      Dataset dataset, String column, IvfBuildParams params, List<Integer> fragmentIds);
 
   private static native float[] nativeTrainPqCodebook(
-      Dataset dataset, String column, PQBuildParams params);
+      Dataset dataset,
+      String column,
+      PQBuildParams params,
+      float[] ivfCentroids,
+      List<Integer> fragmentIds);
 }
diff --git a/java/src/test/java/org/lance/index/VectorIndexTest.java b/java/src/test/java/org/lance/index/VectorIndexTest.java
index a96b6593d30..9e359ae0e0c 100755
--- a/java/src/test/java/org/lance/index/VectorIndexTest.java
+++ b/java/src/test/java/org/lance/index/VectorIndexTest.java
@@ -15,7 +15,10 @@
 import org.lance.Dataset;
 import org.lance.Fragment;
+import org.lance.FragmentMetadata;
+import org.lance.FragmentOperation;
 import org.lance.TestVectorDataset;
+import org.lance.WriteParams;
 import org.lance.index.vector.IvfBuildParams;
 import org.lance.index.vector.PQBuildParams;
 import org.lance.index.vector.RQBuildParams;
@@ -23,10 +26,21 @@
 import org.lance.index.vector.VectorIndexParams;
 import org.lance.index.vector.VectorTrainer;
 
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.FixedSizeListVector;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.io.TempDir;
 
 import java.nio.file.Path;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
@@ -319,4 +333,151 @@ public void testCreateIvfRqIndex(@TempDir Path tempDir) throws Exception {
       }
     }
   }
+
+  @Test
+  public void testIvfCentroidsWithFragmentIds(@TempDir Path tempDir) throws Exception {
+    int dim = 8;
+    int rowsPerFragment = 32;
+    String column = "vec";
+
+    Schema schema =
+        new Schema(
+            Collections.singletonList(
+                new Field(
+                    column,
+                    FieldType.nullable(new ArrowType.FixedSizeList(dim)),
+                    Collections.singletonList(
+                        new Field(
+                            "item",
+                            FieldType.nullable(
+                                new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)),
+                            null)))));
+
+    try (BufferAllocator allocator = new RootAllocator()) {
+      Path datasetPath = tempDir.resolve("fragment_ivf");
+      WriteParams emptyParams =
+          new WriteParams.Builder().withMaxRowsPerFile(rowsPerFragment).build();
+      Dataset.create(allocator, datasetPath.toString(), schema, emptyParams).close();
+
+      // Fragment 0: all zeros
+      FragmentMetadata frag0;
+      try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+        root.allocateNew();
+        FixedSizeListVector vecVector = (FixedSizeListVector) root.getVector(column);
+        Float4Vector items = (Float4Vector) vecVector.getDataVector();
+        for (int i = 0; i < rowsPerFragment; i++) {
+          for (int j = 0; j < dim; j++) {
+            items.setSafe(i * dim + j, 0.0f);
+          }
+          vecVector.setNotNull(i);
+        }
+        root.setRowCount(rowsPerFragment);
+        frag0 =
+            Fragment.create(
+                    datasetPath.toString(), allocator, root, new WriteParams.Builder().build())
+                .get(0);
+      }
+
+      // Fragment 1: all 10.0
+      FragmentMetadata frag1;
+      try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+        root.allocateNew();
+        FixedSizeListVector vecVector = (FixedSizeListVector) root.getVector(column);
+        Float4Vector items = (Float4Vector) vecVector.getDataVector();
+        for (int i = 0; i < rowsPerFragment; i++) {
+          for (int j = 0; j < dim; j++) {
+            items.setSafe(i * dim + j, 10.0f);
+          }
+          vecVector.setNotNull(i);
+        }
+        root.setRowCount(rowsPerFragment);
+        frag1 =
+            Fragment.create(
+                    datasetPath.toString(), allocator, root, new WriteParams.Builder().build())
+                .get(0);
+      }
+
+      List<FragmentMetadata> fragments = new ArrayList<>();
+      fragments.add(frag0);
+      fragments.add(frag1);
+      FragmentOperation.Append appendOp = new FragmentOperation.Append(fragments);
+      try (Dataset dataset =
+          Dataset.commit(allocator, datasetPath.toString(), appendOp, Optional.of(1L))) {
+
+        List<Fragment> dsFragments = dataset.getFragments();
+        assertEquals(2, dsFragments.size());
+
+        IvfBuildParams ivfParams =
+            new IvfBuildParams.Builder()
+                .setNumPartitions(1)
+                .setMaxIters(1)
+                .setSampleRate(2)
+                .build();
+
+        // Train IVF on fragment 0 (zeros) only
+        float[] firstCentroids =
+            VectorTrainer.trainIvfCentroids(
+                dataset, column, ivfParams, Collections.singletonList(dsFragments.get(0).getId()));
+
+        // Train IVF on fragment 1 (10.0s) only
+        float[] secondCentroids =
+            VectorTrainer.trainIvfCentroids(
+                dataset, column, ivfParams, Collections.singletonList(dsFragments.get(1).getId()));
+
+        assertEquals(dim, firstCentroids.length);
+        assertEquals(dim, secondCentroids.length);
+
+        for (int j = 0; j < dim; j++) {
+          assertEquals(0.0f, firstCentroids[j], 1e-4f, "first centroid[" + j + "] should be ~0.0");
+          assertEquals(
+              10.0f, secondCentroids[j], 1e-4f, "second centroid[" + j + "] should be ~10.0");
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testPqCodebookWithFragmentIds(@TempDir Path tempDir) throws Exception {
+    try (TestVectorDataset testVectorDataset =
+        new TestVectorDataset(tempDir.resolve("pq_fragment_ids"))) {
+      try (Dataset dataset = testVectorDataset.create()) {
+        List<Fragment> fragments = dataset.getFragments();
+        assertTrue(fragments.size() >= 4, "Expected at least four fragments");
+        // Use 4 fragments (320 rows total) to meet PQ sample requirements
+        List<Integer> fragmentIds =
+            List.of(
+                fragments.get(0).getId(),
+                fragments.get(1).getId(),
+                fragments.get(2).getId(),
+                fragments.get(3).getId());
+
+        IvfBuildParams ivfParams =
+            new IvfBuildParams.Builder()
+                .setNumPartitions(4)
+                .setMaxIters(1)
+                .setSampleRate(16)
+                .build();
+
+        float[] centroids =
+            VectorTrainer.trainIvfCentroids(
+                dataset, TestVectorDataset.vectorColumnName, ivfParams, fragmentIds);
+        assertNotNull(centroids);
+        assertTrue(centroids.length > 0, "IVF centroids should not be empty");
+
+        PQBuildParams pqParams =
+            new PQBuildParams.Builder()
+                .setNumSubVectors(2)
+                .setNumBits(8)
+                .setMaxIters(2)
+                .setSampleRate(1)
+                .build();
+
+        float[] codebook =
+            VectorTrainer.trainPqCodebook(
+                dataset, TestVectorDataset.vectorColumnName, pqParams, centroids, fragmentIds);
+        assertNotNull(codebook);
+        assertTrue(codebook.length > 0, "PQ codebook should not be empty");
+      }
+    }
+  }
 }