Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions java/lance-jni/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions java/lance-jni/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ default = []

[dependencies]
lance = { path = "../../rust/lance", features = ["substrait"] }
lance-arrow = { path = "../../rust/lance-arrow" }
lance-datafusion = { path = "../../rust/lance-datafusion" }
lance-encoding = { path = "../../rust/lance-encoding" }
lance-linalg = { path = "../../rust/lance-linalg" }
Expand Down
91 changes: 76 additions & 15 deletions java/lance-jni/src/vector_trainer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ use jni::objects::{JClass, JFloatArray, JObject, JString};
use jni::sys::jfloatArray;
use lance::index::NoopIndexBuildProgress;
use lance::index::vector::utils::get_vector_dim;
use lance_arrow::FixedSizeListArrayExt;
use lance_index::vector::ivf::builder::IvfBuildParams as RustIvfBuildParams;
use lance_index::vector::ivf::storage::IvfModel;
use lance_index::vector::pq::builder::PQBuildParams as RustPQBuildParams;
use lance_linalg::distance::MetricType;

Expand Down Expand Up @@ -76,18 +78,58 @@ fn build_pq_params_from_java(
})
}

/// Convert an optional Java `List<Integer>` of fragment IDs into `Option<Vec<u32>>`.
///
/// Returns `Ok(None)` when `obj` is a null reference. Otherwise the list is
/// read through `env.get_integers` and every element is converted with an
/// `as u32` cast.
///
/// NOTE(review): assuming the elements are signed 32-bit integers (as the
/// `List<Integer>` signature suggests), a negative value is not rejected here;
/// the `as` cast reinterprets it as a large `u32`. Confirm callers never pass
/// negative fragment IDs.
///
/// # Errors
///
/// Propagates any error returned by `env.get_integers`.
fn get_nullable_fragment_ids(env: &mut JNIEnv, obj: &JObject) -> Result<Option<Vec<u32>>> {
    if obj.is_null() {
        Ok(None)
    } else {
        let raw_ids = env.get_integers(obj)?;
        let fragment_ids: Vec<u32> = raw_ids.into_iter().map(|id| id as u32).collect();
        Ok(Some(fragment_ids))
    }
}

/// Copy an optional Java `float[]` of IVF centroids into `Option<Vec<f32>>`.
///
/// Returns `Ok(None)` when `centroids_obj` is a null reference. Otherwise a
/// zero-initialised buffer sized to the Java array's length is allocated and
/// filled with the array contents, starting at index 0.
///
/// # Errors
///
/// Propagates any error from querying the array length or copying the
/// array region.
fn get_nullable_ivf_centroids(
    env: &mut JNIEnv,
    centroids_obj: &JObject,
) -> Result<Option<Vec<f32>>> {
    if centroids_obj.is_null() {
        return Ok(None);
    }

    // SAFETY: `from_raw` requires the handle to refer to a Java `float[]`.
    // That is not verified here, so callers must only pass such references
    // (the null case is handled above). The wrapper reuses the raw handle of
    // `centroids_obj` rather than creating a new reference.
    let float_array = unsafe { JFloatArray::from_raw(centroids_obj.as_raw()) };

    let element_count = env.get_array_length(&float_array)? as usize;
    let mut centroids = vec![0.0f32; element_count];
    env.get_float_array_region(&float_array, 0, &mut centroids)?;

    Ok(Some(centroids))
}

/// Wrap a flat slice of centroid values in an `IvfModel`.
///
/// The values are copied into a `Float32Array`, grouped into a
/// `FixedSizeListArray` whose list size is `dim`, and handed to
/// `IvfModel::new` with `None` as its second argument.
///
/// NOTE(review): `dim` is narrowed with an `as i32` cast, so a dimension
/// larger than `i32::MAX` would be truncated rather than rejected.
///
/// # Errors
///
/// Propagates any error from `FixedSizeListArray::try_new_from_values`.
fn build_ivf_model_from_centroids(centroids: &[f32], dim: usize) -> Result<IvfModel> {
    let values = Float32Array::from(centroids.to_owned());
    let list_size = dim as i32;
    let centroid_vectors = FixedSizeListArray::try_new_from_values(values, list_size)?;
    Ok(IvfModel::new(centroid_vectors, None))
}

#[unsafe(no_mangle)]
pub extern "system" fn Java_org_lance_index_vector_VectorTrainer_nativeTrainIvfCentroids<'local>(
mut env: JNIEnv<'local>,
_class: JClass<'local>,
dataset_obj: JObject<'local>, // org.lance.Dataset
column_jstr: JString<'local>, // java.lang.String
ivf_params_obj: JObject<'local>, // org.lance.index.vector.IvfBuildParams
dataset_obj: JObject<'local>, // org.lance.Dataset
column_jstr: JString<'local>, // java.lang.String
ivf_params_obj: JObject<'local>, // org.lance.index.vector.IvfBuildParams
fragment_ids_obj: JObject<'local>, // List<Integer>, nullable
) -> jfloatArray {
ok_or_throw_with_return!(
env,
inner_train_ivf_centroids(&mut env, dataset_obj, column_jstr, ivf_params_obj)
.map(|arr| arr.into_raw()),
inner_train_ivf_centroids(
&mut env,
dataset_obj,
column_jstr,
ivf_params_obj,
fragment_ids_obj
)
.map(|arr| arr.into_raw()),
JFloatArray::default().into_raw()
)
}
Expand All @@ -97,18 +139,18 @@ fn inner_train_ivf_centroids<'local>(
dataset_obj: JObject<'local>,
column_jstr: JString<'local>,
ivf_params_obj: JObject<'local>,
fragment_ids_obj: JObject<'local>,
) -> Result<JFloatArray<'local>> {
let column: String = env.get_string(&column_jstr)?.into();
let ivf_params = build_ivf_params_from_java(env, &ivf_params_obj)?;
let fragment_ids = get_nullable_fragment_ids(env, &fragment_ids_obj)?;

let flattened: Vec<f32> = {
let dataset_guard =
unsafe { env.get_rust_field::<_, _, BlockingDataset>(dataset_obj, NATIVE_DATASET) }?;
let dataset = &dataset_guard.inner;

let dim = get_vector_dim(dataset.schema(), &column)?;

// For now we default to L2 metric; tests and Java bindings currently use L2.
let metric_type = MetricType::L2;

let ivf_model = RT.block_on(lance::index::vector::ivf::build_ivf_model(
Expand All @@ -117,7 +159,7 @@ fn inner_train_ivf_centroids<'local>(
dim,
metric_type,
&ivf_params,
None,
fragment_ids.as_deref(),
Arc::new(NoopIndexBuildProgress),
))?;

Expand All @@ -137,14 +179,23 @@ fn inner_train_ivf_centroids<'local>(
pub extern "system" fn Java_org_lance_index_vector_VectorTrainer_nativeTrainPqCodebook<'local>(
mut env: JNIEnv<'local>,
_class: JClass<'local>,
dataset_obj: JObject<'local>, // org.lance.Dataset
column_jstr: JString<'local>, // java.lang.String
pq_params_obj: JObject<'local>, // org.lance.index.vector.PQBuildParams
dataset_obj: JObject<'local>, // org.lance.Dataset
column_jstr: JString<'local>, // java.lang.String
pq_params_obj: JObject<'local>, // org.lance.index.vector.PQBuildParams
ivf_centroids_obj: JObject<'local>, // float[], nullable
fragment_ids_obj: JObject<'local>, // List<Integer>, nullable
) -> jfloatArray {
ok_or_throw_with_return!(
env,
inner_train_pq_codebook(&mut env, dataset_obj, column_jstr, pq_params_obj)
.map(|arr| arr.into_raw()),
inner_train_pq_codebook(
&mut env,
dataset_obj,
column_jstr,
pq_params_obj,
ivf_centroids_obj,
fragment_ids_obj
)
.map(|arr| arr.into_raw()),
JFloatArray::default().into_raw()
)
}
Expand All @@ -154,9 +205,13 @@ fn inner_train_pq_codebook<'local>(
dataset_obj: JObject<'local>,
column_jstr: JString<'local>,
pq_params_obj: JObject<'local>,
ivf_centroids_obj: JObject<'local>,
fragment_ids_obj: JObject<'local>,
) -> Result<JFloatArray<'local>> {
let column: String = env.get_string(&column_jstr)?.into();
let pq_params = build_pq_params_from_java(env, &pq_params_obj)?;
let fragment_ids = get_nullable_fragment_ids(env, &fragment_ids_obj)?;
let ivf_centroids_data = get_nullable_ivf_centroids(env, &ivf_centroids_obj)?;

let flattened: Vec<f32> = {
let dataset_guard =
Expand All @@ -166,13 +221,19 @@ fn inner_train_pq_codebook<'local>(
let dim = get_vector_dim(dataset.schema(), &column)?;
let metric_type = MetricType::L2;

let pq = RT.block_on(lance::index::vector::pq::build_pq_model(
let ivf_model = ivf_centroids_data
.as_deref()
.map(|data| build_ivf_model_from_centroids(data, dim))
.transpose()?;

let pq = RT.block_on(lance::index::vector::pq::build_pq_model_in_fragments(
dataset,
&column,
dim,
metric_type,
&pq_params,
None,
ivf_model.as_ref(),
fragment_ids.as_deref(),
))?;

flatten_fixed_size_list_to_f32(&pq.codebook)?
Expand Down
60 changes: 56 additions & 4 deletions java/src/main/java/org/lance/index/vector/VectorTrainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import org.apache.arrow.util.Preconditions;

import java.util.List;

/**
* Training utilities for vector indexes.
*
Expand All @@ -36,38 +38,88 @@ private VectorTrainer() {}
/**
* Train IVF centroids for the given dataset column.
*
* <p>Training samples from the entire dataset.
*
* <p>This overload delegates to {@link #trainIvfCentroids(Dataset, String, IvfBuildParams, List)}
* with {@code fragmentIds} set to {@code null}.
*
* @param dataset the dataset to sample training data from
* @param column the vector column name
* @param params IVF build parameters (numPartitions, sampleRate, etc.)
* @return a flattened array of centroids laid out as [numPartitions][dimension]
*/
public static float[] trainIvfCentroids(Dataset dataset, String column, IvfBuildParams params) {
return trainIvfCentroids(dataset, column, params, null);
}

/**
* Train IVF centroids for the given dataset column, optionally restricted to specific fragments.
*
* <p>When {@code fragmentIds} is non-null, only the listed fragments are sampled for training.
* This is useful for per-fragment (non-shared centroid) distributed index builds.
*
* @param dataset the dataset to sample training data from
* @param column the vector column name
* @param params IVF build parameters (numPartitions, sampleRate, etc.)
* @param fragmentIds fragment IDs to restrict training to, or {@code null} for the full dataset
* @return a flattened array of centroids laid out as [numPartitions][dimension]
*/
public static float[] trainIvfCentroids(
Dataset dataset, String column, IvfBuildParams params, List<Integer> fragmentIds) {
Preconditions.checkArgument(dataset != null, "dataset cannot be null");
Preconditions.checkArgument(
column != null && !column.isEmpty(), "column cannot be null or empty");
Preconditions.checkArgument(params != null, "params cannot be null");
return nativeTrainIvfCentroids(dataset, column, params);
return nativeTrainIvfCentroids(dataset, column, params, fragmentIds);
}

/**
* Train a PQ codebook for the given dataset column.
*
* <p>Training samples from the entire dataset without IVF residual computation.
*
* <p>This overload delegates to {@link #trainPqCodebook(Dataset, String, PQBuildParams, float[],
* List)} with both {@code ivfCentroids} and {@code fragmentIds} set to {@code null}.
*
* @param dataset the dataset to sample training data from
* @param column the vector column name
* @param params PQ build parameters (numSubVectors, numBits, sampleRate, etc.)
* @return a flattened array of codebook entries laid out as [num_centroids][dimension]
*/
public static float[] trainPqCodebook(Dataset dataset, String column, PQBuildParams params) {
return trainPqCodebook(dataset, column, params, null, null);
}

/**
* Train a PQ codebook for the given dataset column, optionally using pre-trained IVF centroids
* for residual-based training and restricting to specific fragments.
*
* <p>When {@code ivfCentroids} is non-null, PQ training is performed on the residual vectors
* after IVF assignment (matching the Python {@code train_pq} behavior). When {@code fragmentIds}
* is non-null, only the listed fragments are sampled.
*
* @param dataset the dataset to sample training data from
* @param column the vector column name
* @param params PQ build parameters (numSubVectors, numBits, sampleRate, etc.)
* @param ivfCentroids flattened IVF centroids for residual PQ training, or {@code null} to skip
* residual computation
* @param fragmentIds fragment IDs to restrict training to, or {@code null} for the full dataset
* @return a flattened array of codebook entries laid out as [num_centroids][dimension]
*/
public static float[] trainPqCodebook(
Dataset dataset,
String column,
PQBuildParams params,
float[] ivfCentroids,
List<Integer> fragmentIds) {
Preconditions.checkArgument(dataset != null, "dataset cannot be null");
Preconditions.checkArgument(
column != null && !column.isEmpty(), "column cannot be null or empty");
Preconditions.checkArgument(params != null, "params cannot be null");
return nativeTrainPqCodebook(dataset, column, params);
return nativeTrainPqCodebook(dataset, column, params, ivfCentroids, fragmentIds);
}

private static native float[] nativeTrainIvfCentroids(
Dataset dataset, String column, IvfBuildParams params);
Dataset dataset, String column, IvfBuildParams params, List<Integer> fragmentIds);

private static native float[] nativeTrainPqCodebook(
Dataset dataset, String column, PQBuildParams params);
Dataset dataset,
String column,
PQBuildParams params,
float[] ivfCentroids,
List<Integer> fragmentIds);
}
Loading
Loading