diff --git a/rust/chroma/examples/embeddings.rs b/rust/chroma/examples/embeddings.rs new file mode 100644 index 00000000000..74afeaff794 --- /dev/null +++ b/rust/chroma/examples/embeddings.rs @@ -0,0 +1,267 @@ +//! Chroma Cloud embedding collection example. +//! +//! This example demonstrates the default Chroma Cloud Qwen dense embedding +//! function and Splade sparse embedding function by writing embedded records +//! to a Chroma collection and searching them back. +//! +//! # Running +//! +//! Source your environment first, then run the example: +//! +//! ```bash +//! source .env +//! cargo run -p chroma --example embeddings +//! ``` +//! +//! Required environment: +//! +//! ```text +//! CHROMA_API_KEY=... +//! ``` +//! +//! Optional environment: +//! +//! ```text +//! CHROMA_EMBED_URL=... +//! ``` + +use std::error::Error; + +use chroma::{ + embed::{ + chroma_cloud::{ChromaCloudQwenEmbeddingFunction, ChromaCloudSpladeEmbeddingFunction}, + EmbeddingFunction, + }, + types::{ + EmbeddingFunctionConfiguration, IncludeList, Key, Metadata, QueryVector, RankExpr, Schema, + SearchPayload, SearchResponse, SparseVectorIndexConfig, + }, + ChromaCollection, ChromaHttpClient, +}; +use serde_json::{to_string_pretty, Error as JsonError}; + +const COLLECTION_NAME: &str = "rust_chroma_cloud_embeddings_example"; +const DENSE_KEY: &str = "#embedding"; +const SPARSE_KEY: &str = "sparse_embedding"; +const QUERY: &str = "How do I create embeddings with the Rust client?"; + +struct ExampleRecord { + id: &'static str, + topic: &'static str, + document: &'static str, +} + +const RECORDS: [ExampleRecord; 4] = [ + ExampleRecord { + id: "rust-client", + topic: "rust", + document: "The Rust client can use Chroma Cloud Qwen embeddings when records are added.", + }, + ExampleRecord { + id: "sparse-search", + topic: "search", + document: "Splade sparse embeddings help lexical matching and hybrid retrieval.", + }, + ExampleRecord { + id: "collection-schema", + topic: "schema", + document: "Collection schemas can describe dense and sparse vector indexes.", + }, + ExampleRecord { + id: "query-flow", + topic: "query", + document: "Query embeddings let applications retrieve similar documents from Chroma.", + }, +]; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = ChromaHttpClient::cloud()?; + + let qwen = ChromaCloudQwenEmbeddingFunction::builder() + .task("nl_to_code") + .build()?; + let splade = ChromaCloudSpladeEmbeddingFunction::builder() + .include_tokens(true) + .build()?; + + let qwen_config = ChromaCloudQwenEmbeddingFunction::configuration() + .task("nl_to_code") + .build(); + let splade_config = ChromaCloudSpladeEmbeddingFunction::configuration() + .include_tokens(true) + .build(); + + let schema = Schema::default_with_embedding_function(qwen_config).create_index( + Some(SPARSE_KEY), + SparseVectorIndexConfig { + embedding_function: Some(splade_config), + source_key: None, + bm25: Some(false), + algorithm: Default::default(), + } + .into(), + )?; + + let _ = client.delete_collection(COLLECTION_NAME).await; + client + .create_collection(COLLECTION_NAME, Some(schema), None) + .await?; + let collection = client.get_collection(COLLECTION_NAME).await?; + print_saved_embedding_functions(&collection)?; + + let documents = RECORDS + .iter() + .map(|record| record.document) + .collect::>(); + let sparse_embeddings = splade.embed_strs(&documents).await?; + let metadatas = RECORDS + .iter() + .zip(sparse_embeddings) + .map(|(record, sparse_embedding)| { + let mut metadata = Metadata::new(); + metadata.insert("topic".into(), record.topic.into()); + metadata.insert(SPARSE_KEY.into(), sparse_embedding.into()); + Some(metadata) + }) + .collect::>(); + + let ids = RECORDS + .iter() + .map(|record| record.id.to_string()) + .collect::>(); + let document_values = RECORDS + .iter() + .map(|record| Some(record.document.to_string())) + .collect::>(); + + collection + .add( + ids.clone(), + None::>>, + Some(document_values), + None, + Some(metadatas), + ) + .await?; + + let count = collection.count().await?; + println!("Inserted {count} records into '{}'.", collection.name()); + + let retrieved = collection + .get( + Some(ids.clone()), + None, + Some(ids.len() as u32), + Some(0), + Some(IncludeList::default_get()), + ) + .await?; + println!("Round-tripped {} records by ID.", retrieved.ids.len()); + + let dense_query = qwen.embed_query_strs(&[QUERY]).await?.remove(0); + let sparse_query = splade.embed_query_strs(&[QUERY]).await?.remove(0); + + let dense_search = SearchPayload::default() + .rank(RankExpr::Knn { + query: QueryVector::Dense(dense_query), + key: Key::Embedding, + limit: 10, + default: None, + return_rank: false, + }) + .limit(Some(3), 0) + .select([Key::Document, Key::Score, Key::field("topic")]); + + let sparse_search = SearchPayload::default() + .rank(RankExpr::Knn { + query: QueryVector::Sparse(sparse_query), + key: Key::field(SPARSE_KEY), + limit: 10, + default: None, + return_rank: false, + }) + .limit(Some(3), 0) + .select([Key::Document, Key::Score, Key::field("topic")]); + + let results = collection.search(vec![dense_search, sparse_search]).await?; + print_results("Qwen dense search", &results, 0); + print_results("Splade sparse search", &results, 1); + + client.delete_collection(COLLECTION_NAME).await?; + println!("Deleted example collection '{}'.", COLLECTION_NAME); + + Ok(()) +} + +fn print_saved_embedding_functions(collection: &ChromaCollection) -> Result<(), JsonError> { + let schema = collection.schema().as_ref(); + let dense_config = schema + .and_then(|schema| schema.keys.get(DENSE_KEY)) + .and_then(|value_types| value_types.float_list.as_ref()) + .and_then(|float_list| float_list.vector_index.as_ref()) + .and_then(|vector_index| vector_index.config.embedding_function.as_ref()) + .or_else(|| { + schema + .and_then(|schema| schema.defaults.float_list.as_ref()) + .and_then(|float_list| float_list.vector_index.as_ref()) + .and_then(|vector_index| vector_index.config.embedding_function.as_ref()) + }); + let sparse_config = schema + .and_then(|schema| schema.keys.get(SPARSE_KEY)) + .and_then(|value_types| value_types.sparse_vector.as_ref()) + .and_then(|sparse_vector| sparse_vector.sparse_vector_index.as_ref()) + .and_then(|sparse_index| sparse_index.config.embedding_function.as_ref()); + + print_embedding_function_config("Saved dense embedding function", dense_config)?; + print_embedding_function_config("Saved sparse embedding function", sparse_config)?; + + Ok(()) +} + +fn print_embedding_function_config( + label: &str, + config: Option<&EmbeddingFunctionConfiguration>, +) -> Result<(), JsonError> { + match config { + Some(config) => { + println!("{label}:"); + println!("{}", to_string_pretty(config)?); + } + None => println!("{label}: "), + } + println!(); + Ok(()) +} + +fn print_results(label: &str, response: &SearchResponse, search_index: usize) { + println!("\n{label}"); + for (rank, id) in response.ids[search_index].iter().enumerate() { + let score = response.scores[search_index] + .as_ref() + .and_then(|scores| scores.get(rank)) + .and_then(|score| *score) + .map(|score| format!("{score:.4}")) + .unwrap_or_else(|| "N/A".to_string()); + let document = response.documents[search_index] + .as_ref() + .and_then(|documents| documents.get(rank)) + .and_then(|document| document.as_deref()) + .unwrap_or(""); + let topic = response.metadatas[search_index] + .as_ref() + .and_then(|metadatas| metadatas.get(rank)) + .and_then(|metadata| metadata.as_ref()) + .and_then(|metadata| metadata.get("topic")) + .map(|topic| format!("{topic:?}")) + .unwrap_or_else(|| "N/A".to_string()); + println!( + " {}. {} score={} topic={} document={}", + rank + 1, + id, + score, + topic, + document + ); + } +} diff --git a/rust/chroma/src/client/chroma_http_client.rs b/rust/chroma/src/client/chroma_http_client.rs index d7de6831560..536e1ed1227 100644 --- a/rust/chroma/src/client/chroma_http_client.rs +++ b/rust/chroma/src/client/chroma_http_client.rs @@ -60,6 +60,15 @@ pub enum ChromaHttpClientError { /// validation error from the where clause parser. #[error("Invalid where clause")] InvalidWhere, + /// No embedding function is configured for a collection that needs to embed documents. + #[error("You must provide an embedding function to compute embeddings from documents")] + MissingEmbeddingFunction, + /// Documents were required because embeddings were omitted. + #[error("Documents are required when embeddings are not provided")] + MissingDocumentsForEmbedding, + /// The configured embedding function failed. + #[error("Embedding function error: {0}")] + EmbeddingFunctionError(String), } impl From for ChromaHttpClientError { @@ -121,7 +130,10 @@ impl FailurePredicate for BackendFailurePredicate { ChromaHttpClientError::CouldNotResolveDatabaseId(_) | ChromaHttpClientError::ValidationError(_) | ChromaHttpClientError::NoBackendAvailable - | ChromaHttpClientError::InvalidWhere => false, + | ChromaHttpClientError::InvalidWhere + | ChromaHttpClientError::MissingEmbeddingFunction + | ChromaHttpClientError::MissingDocumentsForEmbedding + | ChromaHttpClientError::EmbeddingFunctionError(_) => false, } } } @@ -175,6 +187,7 @@ pub struct ChromaHttpClient { tenant_id: Arc>>, database_name: Arc>>, resolve_tenant_or_database_lock: Arc>, + chroma_cloud_api_key: Option, } impl Default for ChromaHttpClient { @@ -192,6 +205,7 @@ impl Clone for ChromaHttpClient { tenant_id: Arc::new(Mutex::new(self.tenant_id.lock().clone())), database_name: Arc::new(Mutex::new(self.database_name.lock().clone())), resolve_tenant_or_database_lock: Arc::new(tokio::sync::Mutex::new(())), + chroma_cloud_api_key: self.chroma_cloud_api_key.clone(), } } } @@ -232,6 +246,10 @@ impl ChromaHttpClient { /// # } /// ``` pub fn new(options: ChromaHttpClientOptions) -> Self { + let chroma_cloud_api_key = options + .auth_method + .chroma_cloud_api_key() + .map(ToOwned::to_owned); let mut headers = options.headers(); headers.append("user-agent", USER_AGENT.try_into().unwrap()); @@ -263,9 +281,14 @@ impl ChromaHttpClient { tenant_id: Arc::new(Mutex::new(options.tenant_id)), database_name: Arc::new(Mutex::new(options.database_name)), resolve_tenant_or_database_lock: Arc::new(tokio::sync::Mutex::new(())), + chroma_cloud_api_key, } } + pub(crate) fn chroma_cloud_api_key(&self) -> Option<&str> { + self.chroma_cloud_api_key.as_deref() + } + /// Constructs a client from environment variables. /// /// Reads configuration from `CHROMA_ENDPOINT`, `CHROMA_HOST`, `CHROMA_TENANT`, @@ -657,10 +680,7 @@ impl ChromaHttpClient { ) .await?; - Ok(ChromaCollection { - client: self.clone(), - collection: Arc::new(collection), - }) + Ok(ChromaCollection::new(self.clone(), collection)) } /// Retrieves an existing collection by its ID. @@ -706,10 +726,7 @@ impl ChromaHttpClient { ) .await?; - Ok(ChromaCollection { - client: self.clone(), - collection: Arc::new(collection), - }) + Ok(ChromaCollection::new(self.clone(), collection)) } /// Removes a collection and all its records from the database. @@ -852,10 +869,7 @@ impl ChromaHttpClient { Ok(collections .into_iter() - .map(|collection| ChromaCollection { - client: self.clone(), - collection: Arc::new(collection), - }) + .map(|collection| ChromaCollection::new(self.clone(), collection)) .collect()) } @@ -1060,10 +1074,7 @@ impl ChromaHttpClient { ) .await?; - Ok(ChromaCollection { - client: self.clone(), - collection: Arc::new(collection), - }) + Ok(ChromaCollection::new(self.clone(), collection)) } pub(crate) async fn send< diff --git a/rust/chroma/src/client/options.rs b/rust/chroma/src/client/options.rs index 1c282e70d24..fdbdcdb820f 100644 --- a/rust/chroma/src/client/options.rs +++ b/rust/chroma/src/client/options.rs @@ -91,6 +91,17 @@ impl ChromaAuthMethod { value, }) } + + pub(crate) fn chroma_cloud_api_key(&self) -> Option<&str> { + match self { + ChromaAuthMethod::HeaderAuth { header, value } + if header.as_str().eq_ignore_ascii_case("x-chroma-token") => + { + value.to_str().ok() + } + ChromaAuthMethod::HeaderAuth { .. } | ChromaAuthMethod::None => None, + } + } } /// Errors that occur during client configuration construction. diff --git a/rust/chroma/src/collection.rs b/rust/chroma/src/collection.rs index 1f865e2fb93..892cafc93d2 100644 --- a/rust/chroma/src/collection.rs +++ b/rust/chroma/src/collection.rs @@ -13,21 +13,23 @@ //! - **Write operations**: [`add()`](ChromaCollection::add), [`update()`](ChromaCollection::update), [`upsert()`](ChromaCollection::upsert), [`delete()`](ChromaCollection::delete), [`modify()`](ChromaCollection::modify) //! - **Attached functions**: [`attach_function()`](ChromaCollection::attach_function), [`get_attached_function()`](ChromaCollection::get_attached_function), [`detach_function()`](ChromaCollection::detach_function) -use std::sync::Arc; +use std::{error::Error, fmt::Display, sync::Arc}; use chroma_api_types::ForkCollectionPayload; use chroma_types::{ plan::{ReadLevel, SearchPayload}, AddCollectionRecordsRequest, AddCollectionRecordsResponse, AttachFunctionResponse, Collection, CollectionUuid, DeleteCollectionRecordsRequest, DeleteCollectionRecordsResponse, - DetachFunctionResponse, GetAttachedFunctionResponse, GetRequest, GetResponse, IncludeList, - IndexStatusResponse, Metadata, QueryRequest, QueryResponse, Schema, SearchRequest, - SearchResponse, UpdateCollectionRecordsRequest, UpdateCollectionRecordsResponse, - UpdateMetadata, UpsertCollectionRecordsRequest, UpsertCollectionRecordsResponse, Where, + DetachFunctionResponse, EmbeddingFunctionConfiguration, GetAttachedFunctionResponse, + GetRequest, GetResponse, IncludeList, IndexStatusResponse, Metadata, QueryRequest, + QueryResponse, Schema, SearchRequest, SearchResponse, UpdateCollectionRecordsRequest, + UpdateCollectionRecordsResponse, UpdateMetadata, UpsertCollectionRecordsRequest, + UpsertCollectionRecordsResponse, Where, EMBEDDING_KEY, }; use reqwest::Method; use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use crate::embed::{chroma_cloud::ChromaCloudQwenEmbeddingFunction, EmbeddingFunction}; use crate::{client::ChromaHttpClientError, ChromaHttpClient}; #[derive(Deserialize)] @@ -65,6 +67,126 @@ struct ForkCountResponse { pub struct ChromaCollection { pub(crate) client: ChromaHttpClient, pub(crate) collection: Arc, + pub(crate) embedding_function: Option>, +} + +type ErasedEmbeddingFunction = + dyn EmbeddingFunction, Error = BoxedEmbeddingError>; + +#[derive(Debug)] +pub(crate) struct BoxedEmbeddingError(Box); + +impl Display for BoxedEmbeddingError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl Error for BoxedEmbeddingError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(self.0.as_ref()) + } +} + +struct ErrorErasedEmbeddingFunction { + inner: E, +} + +#[async_trait::async_trait] +impl EmbeddingFunction for ErrorErasedEmbeddingFunction +where + E: EmbeddingFunction>, + E::Error: Send + Sync + 'static, +{ + type Embedding = Vec; + type Error = BoxedEmbeddingError; + + async fn embed_strs(&self, batches: &[&str]) -> Result, Self::Error> { + self.inner + .embed_strs(batches) + .await + .map_err(|err| BoxedEmbeddingError(Box::new(err))) + } + + async fn embed_query_strs( + &self, + batches: &[&str], + ) -> Result, Self::Error> { + self.inner + .embed_query_strs(batches) + .await + .map_err(|err| BoxedEmbeddingError(Box::new(err))) + } +} + +fn erase_embedding_function(embedding_function: E) -> Arc +where + E: EmbeddingFunction>, + E::Error: Send + Sync + 'static, +{ + Arc::new(ErrorErasedEmbeddingFunction { + inner: embedding_function, + }) as Arc +} + +fn default_embedding_function( + client: &ChromaHttpClient, + collection: &Collection, +) -> Option> { + let config = dense_embedding_function_configuration(collection)?; + match config { + EmbeddingFunctionConfiguration::Known(config) + if config.name == ChromaCloudQwenEmbeddingFunction::name() => + { + ChromaCloudQwenEmbeddingFunction::try_from_config(config, client.chroma_cloud_api_key()) + .map(erase_embedding_function) + .ok() + } + EmbeddingFunctionConfiguration::Legacy + | EmbeddingFunctionConfiguration::Known(_) + | EmbeddingFunctionConfiguration::Unknown => None, + } +} + +fn dense_embedding_function_configuration( + collection: &Collection, +) -> Option<&EmbeddingFunctionConfiguration> { + let schema = collection.schema.as_ref()?; + schema + .keys + .get(EMBEDDING_KEY) + .and_then(|value_types| value_types.float_list.as_ref()) + .and_then(|float_list| float_list.vector_index.as_ref()) + .and_then(|vector_index| vector_index.config.embedding_function.as_ref()) + .or_else(|| { + schema + .defaults + .float_list + .as_ref() + .and_then(|float_list| float_list.vector_index.as_ref()) + .and_then(|vector_index| vector_index.config.embedding_function.as_ref()) + }) +} + +/// Converts an embeddings argument into the optional dense embedding form used by writes. +/// +/// This keeps the write methods to a single public method while accepting either explicit +/// embeddings or `None` to request embedding-function fallback. +pub trait IntoOptionalEmbeddings { + /// Converts this value into optional dense embeddings. + fn into_optional_embeddings(self) -> Option>>; +} + +impl IntoOptionalEmbeddings for Vec> { + fn into_optional_embeddings(self) -> Option>> { + Some(self) + } +} + +impl IntoOptionalEmbeddings for Option>> { + fn into_optional_embeddings(self) -> Option>> { + self + } } impl std::fmt::Debug for ChromaCollection { @@ -80,6 +202,28 @@ impl std::fmt::Debug for ChromaCollection { } impl ChromaCollection { + pub(crate) fn new(client: ChromaHttpClient, collection: Collection) -> Self { + let embedding_function = default_embedding_function(&client, &collection); + Self { + client, + collection: Arc::new(collection), + embedding_function, + } + } + + /// Sets the embedding function used when record embeddings are omitted. + /// + /// Passing `None` clears the callback. When set, [`add`](Self::add), + /// [`update`](Self::update), and [`upsert`](Self::upsert) can compute dense + /// embeddings from provided documents. + pub fn set_embedding_function(&mut self, embedding_function: Option) + where + E: EmbeddingFunction>, + E::Error: Send + Sync + 'static, + { + self.embedding_function = embedding_function.map(erase_embedding_function); + } + /// Returns the database ID that contains this collection. pub fn database(&self) -> &str { &self.collection.database @@ -617,8 +761,10 @@ impl ChromaCollection { /// Inserts new records into the collection. /// - /// All provided vectors must have lengths equal: `ids`, `embeddings`, and optionally + /// All provided vectors must have lengths equal: `ids`, optional `embeddings`, and optionally /// `documents`, `uris`, and `metadatas`. Records with duplicate IDs will cause an error. + /// If `embeddings` is `None`, the collection's embedding function computes embeddings from + /// `documents`. /// /// # Errors /// @@ -634,7 +780,7 @@ impl ChromaCollection { /// # async fn example(collection: ChromaCollection) -> Result<(), Box> { /// let response = collection.add( /// vec!["doc1".to_string(), "doc2".to_string()], - /// vec![vec![0.1, 0.2], vec![0.3, 0.4]], + /// Some(vec![vec![0.1, 0.2], vec![0.3, 0.4]]), /// Some(vec![Some("First document".to_string()), Some("Second document".to_string())]), /// None, /// None @@ -646,11 +792,14 @@ impl ChromaCollection { pub async fn add( &self, ids: Vec, - embeddings: Vec>, + embeddings: impl IntoOptionalEmbeddings, documents: Option>>, uris: Option>>, metadatas: Option>>, ) -> Result { + let embeddings = self + .resolve_embeddings(embeddings.into_optional_embeddings(), &documents) + .await?; let request = AddCollectionRecordsRequest::try_new( self.collection.tenant.clone(), self.collection.database.clone(), @@ -669,7 +818,10 @@ impl ChromaCollection { /// Modifies existing records in the collection. /// /// Updates only the specified fields for records matching the provided IDs. Fields set to - /// `None` or `Some(None)` remain unchanged. All non-`None` vectors must match the length of `ids`. + /// `None` or `Some(None)` remain unchanged. All non-`None` vectors must match the length of + /// `ids`. If `embeddings` is `None`, `documents` contains values, and an embedding function is + /// configured, the collection's embedding function computes updated embeddings for those + /// documents. /// /// # Errors /// @@ -702,6 +854,9 @@ impl ChromaCollection { uris: Option>>, metadatas: Option>>, ) -> Result { + let embeddings = self + .resolve_update_embeddings(embeddings, &documents) + .await?; let request = UpdateCollectionRecordsRequest::try_new( self.collection.tenant.clone(), self.collection.database.clone(), @@ -721,6 +876,8 @@ impl ChromaCollection { /// /// For each ID: if the record exists, updates it; otherwise, inserts a new record. /// This combines the semantics of [`add`](Self::add) and [`update`](Self::update) in a single operation. + /// If `embeddings` is `None`, the collection's embedding function computes embeddings from + /// `documents`. /// /// # Errors /// @@ -736,7 +893,7 @@ impl ChromaCollection { /// # async fn example(collection: ChromaCollection) -> Result<(), Box> { /// let response = collection.upsert( /// vec!["doc1".to_string(), "doc2".to_string()], - /// vec![vec![0.1, 0.2], vec![0.3, 0.4]], + /// Some(vec![vec![0.1, 0.2], vec![0.3, 0.4]]), /// Some(vec![Some("Document 1".to_string()), Some("Document 2".to_string())]), /// None, /// None @@ -748,11 +905,14 @@ impl ChromaCollection { pub async fn upsert( &self, ids: Vec, - embeddings: Vec>, + embeddings: impl IntoOptionalEmbeddings, documents: Option>>, uris: Option>>, metadatas: Option>>, ) -> Result { + let embeddings = self + .resolve_embeddings(embeddings.into_optional_embeddings(), &documents) + .await?; let request = UpsertCollectionRecordsRequest::try_new( self.collection.tenant.clone(), self.collection.database.clone(), @@ -849,6 +1009,7 @@ impl ChromaCollection { Ok(ChromaCollection { client: self.client.clone(), collection: Arc::new(collection), + embedding_function: self.embedding_function.clone(), }) } @@ -994,6 +1155,90 @@ impl ChromaCollection { Ok(response.success) } + async fn resolve_embeddings( + &self, + embeddings: Option>>, + documents: &Option>>, + ) -> Result>, ChromaHttpClientError> { + if let Some(embeddings) = embeddings { + return Ok(embeddings); + } + + let documents = documents + .as_ref() + .ok_or(ChromaHttpClientError::MissingDocumentsForEmbedding)?; + let input = documents + .iter() + .map(|document| { + document + .as_deref() + .ok_or(ChromaHttpClientError::MissingDocumentsForEmbedding) + }) + .collect::, _>>()?; + + self.embed_documents(&input).await + } + + async fn resolve_update_embeddings( + &self, + embeddings: Option>>>, + documents: &Option>>, + ) -> Result>>>, ChromaHttpClientError> { + if embeddings.is_some() || documents.is_none() || self.embedding_function.is_none() { + return Ok(embeddings); + } + + let documents = documents.as_ref().expect("checked above"); + let input = documents + .iter() + .filter_map(|document| document.as_deref()) + .collect::>(); + if input.is_empty() { + return Ok(None); + } + + let mut embeddings = self.embed_documents(&input).await?.into_iter(); + documents + .iter() + .map(|document| { + document + .as_ref() + .map(|_| { + embeddings.next().ok_or_else(|| { + ChromaHttpClientError::EmbeddingFunctionError( + "Embedding function returned fewer embeddings than documents" + .to_string(), + ) + }) + }) + .transpose() + }) + .collect::, _>>() + .map(Some) + } + + async fn embed_documents( + &self, + input: &[&str], + ) -> Result>, ChromaHttpClientError> { + let embedding_function = self + .embedding_function + .as_ref() + .ok_or(ChromaHttpClientError::MissingEmbeddingFunction)?; + let embeddings = embedding_function + .embed_strs(input) + .await + .map_err(|err| ChromaHttpClientError::EmbeddingFunctionError(err.to_string()))?; + if embeddings.len() != input.len() { + return Err(ChromaHttpClientError::EmbeddingFunctionError(format!( + "Embedding function returned {} embeddings for {} inputs", + embeddings.len(), + input.len() + ))); + } + Ok(embeddings) + } + /// Internal transport method that constructs collection-specific API paths and delegates to the client. async fn send( &self, @@ -1045,12 +1290,133 @@ impl ChromaCollection { #[cfg(test)] mod tests { use crate::tests::{unique_collection_name, with_client}; + use crate::{ + client::{ChromaAuthMethod, ChromaHttpClientError, ChromaHttpClientOptions}, + embed::{chroma_cloud::ChromaCloudQwenEmbeddingFunction, EmbeddingFunction}, + ChromaCollection, ChromaHttpClient, + }; use chroma_types::operator::{Key, QueryVector, RankExpr}; use chroma_types::plan::{ReadLevel, SearchPayload}; use chroma_types::{ - Include, IncludeList, Metadata, MetadataComparison, MetadataExpression, MetadataValue, - PrimitiveOperator, UpdateMetadata, UpdateMetadataValue, Where, + Collection, Include, IncludeList, Metadata, MetadataComparison, MetadataExpression, + MetadataValue, PrimitiveOperator, Schema, UpdateMetadata, UpdateMetadataValue, Where, }; + use std::sync::Arc; + + #[derive(Debug, thiserror::Error)] + #[error("test embedding failed")] + struct TestEmbeddingError; + + struct TestEmbeddingFunction; + + #[async_trait::async_trait] + impl EmbeddingFunction for TestEmbeddingFunction { + type Embedding = Vec; + type Error = TestEmbeddingError; + + async fn embed_strs(&self, batches: &[&str]) -> Result>, Self::Error> { + Ok(batches + .iter() + .map(|document| vec![document.len() as f32]) + .collect()) + } + } + + fn test_collection() -> ChromaCollection { + let collection = Collection { + tenant: "tenant".to_string(), + database: "database".to_string(), + ..Default::default() + }; + ChromaCollection { + client: ChromaHttpClient::default(), + collection: Arc::new(collection), + embedding_function: None, + } + } + + #[test] + fn new_auto_attaches_chroma_cloud_qwen_embedding_function() { + let client = ChromaHttpClient::new(ChromaHttpClientOptions { + auth_method: ChromaAuthMethod::cloud_api_key("test-api-key").unwrap(), + ..Default::default() + }); + let collection = Collection { + tenant: "tenant".to_string(), + database: "database".to_string(), + schema: Some(Schema::default_with_embedding_function( + ChromaCloudQwenEmbeddingFunction::configuration().build(), + )), + ..Default::default() + }; + + let collection = ChromaCollection::new(client, collection); + + assert!(collection.embedding_function.is_some()); + } + + #[tokio::test] + async fn resolve_embeddings_uses_embedding_function() { + let mut collection = test_collection(); + collection.set_embedding_function(Some(TestEmbeddingFunction)); + + let embeddings = collection + .resolve_embeddings( + None, + &Some(vec![Some("alpha".to_string()), Some("beta".to_string())]), + ) + .await + .unwrap(); + + assert_eq!(embeddings, vec![vec![5.0], vec![4.0]]); + } + + #[tokio::test] + async fn resolve_embeddings_requires_documents() { + let mut collection = test_collection(); + collection.set_embedding_function(Some(TestEmbeddingFunction)); + + let err = collection + .resolve_embeddings(None, &None) + .await + .unwrap_err(); + + assert!(matches!( + err, + ChromaHttpClientError::MissingDocumentsForEmbedding + )); + } + + #[tokio::test] + async fn resolve_update_embeddings_embeds_only_present_documents() { + let mut collection = test_collection(); + collection.set_embedding_function(Some(TestEmbeddingFunction)); + + let embeddings = collection + .resolve_update_embeddings( + None, + &Some(vec![Some("alpha".to_string()), None, Some("z".to_string())]), + ) + .await + .unwrap(); + + assert_eq!( + embeddings, + Some(vec![Some(vec![5.0]), None, Some(vec![1.0])]) + ); + } + + #[tokio::test] + async fn resolve_update_embeddings_allows_documents_without_embedding_function() { + let collection = test_collection(); + + let embeddings = collection + .resolve_update_embeddings(None, &Some(vec![Some("updated document".to_string())])) + .await + .unwrap(); + + assert_eq!(embeddings, None); + } #[tokio::test] #[test_log::test] diff --git a/rust/chroma/src/embed/chroma_cloud.rs b/rust/chroma/src/embed/chroma_cloud.rs new file mode 100644 index 00000000000..18d7daa0e6f --- /dev/null +++ b/rust/chroma/src/embed/chroma_cloud.rs @@ -0,0 +1,901 @@ +//! Chroma Cloud embedding function implementations. +//! +//! This module mirrors the Python `chromadb` Chroma Cloud embedding functions: +//! `chroma-cloud-qwen` for dense embeddings and `chroma-cloud-splade` for sparse embeddings. + +use std::{collections::HashMap, env}; + +use chroma_types::{ + EmbeddingFunctionConfiguration, EmbeddingFunctionNewConfiguration, SparseVector, + SparseVectorLengthMismatch, +}; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue}; +use serde::{Deserialize, Serialize}; +use serde_json::{from_value, json, Value}; +use thiserror::Error; + +use crate::embed::EmbeddingFunction; + +const DEFAULT_CHROMA_EMBED_URL: &str = "https://embed.trychroma.com"; +const DEFAULT_API_KEY_ENV_VAR: &str = "CHROMA_API_KEY"; +const QWEN_NAME: &str = "chroma-cloud-qwen"; +const SPLADE_NAME: &str = "chroma-cloud-splade"; + +/// Errors returned by Chroma Cloud embedding functions. +#[derive(Debug, Error)] +pub enum ChromaCloudEmbeddingError { + /// No API key was supplied and the configured environment variable was unset. + #[error("API key not found in environment variable {env_var}")] + MissingApiKey { + /// Environment variable checked for the API key. + env_var: String, + }, + /// An API key could not be converted to an HTTP header value. + #[error("invalid API key header value: {0}")] + InvalidHeaderValue(#[from] InvalidHeaderValue), + /// The configured model is not supported by this Rust client. + #[error("unsupported Chroma Cloud embedding model: {0}")] + UnsupportedModel(String), + /// The HTTP request failed. + #[error("request failed: {0}")] + Request(#[from] reqwest::Error), + /// Chroma Cloud returned an error payload or an unexpected payload shape. + #[error("Chroma Cloud embedding API error: {0}")] + Api(String), + /// Chroma Cloud returned an invalid sparse vector. + #[error("invalid sparse vector: {0}")] + SparseVector(#[from] SparseVectorLengthMismatch), + /// A Chroma Cloud embedding function configuration was invalid. + #[error("invalid Chroma Cloud embedding function config: {0}")] + InvalidConfig(String), +} + +/// Dense Chroma Cloud Qwen embedding model. +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)] +pub enum ChromaCloudQwenEmbeddingModel { + /// `Qwen/Qwen3-Embedding-0.6B`. + #[serde(rename = "Qwen/Qwen3-Embedding-0.6B")] + #[default] + Qwen3Embedding0p6b, +} + +impl ChromaCloudQwenEmbeddingModel { + /// Returns the model identifier sent to Chroma Cloud. + pub fn as_str(self) -> &'static str { + match self { + Self::Qwen3Embedding0p6b => "Qwen/Qwen3-Embedding-0.6B", + } + } +} + +/// Sparse Chroma Cloud Splade embedding model. +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)] +pub enum ChromaCloudSpladeEmbeddingModel { + /// `prithivida/Splade_PP_en_v1`. + #[serde(rename = "prithivida/Splade_PP_en_v1")] + #[default] + SpladePpEnV1, +} + +impl ChromaCloudSpladeEmbeddingModel { + /// Returns the model identifier sent to Chroma Cloud. + pub fn as_str(self) -> &'static str { + match self { + Self::SpladePpEnV1 => "prithivida/Splade_PP_en_v1", + } + } +} + +/// Per-task Qwen instructions for document and query embeddings. +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub struct ChromaCloudQwenTaskInstructions { + /// Instruction used when embedding documents. + pub documents: String, + /// Instruction used when embedding queries. + pub query: String, +} + +/// Chroma Cloud Qwen embedding function. +/// +/// This implementation calls the Chroma Cloud embedding endpoint and produces dense +/// `Vec` embeddings. +pub struct ChromaCloudQwenEmbeddingFunction { + client: reqwest::Client, + api_url: String, + model: ChromaCloudQwenEmbeddingModel, + task: Option, + instructions: HashMap, + api_key_env_var: String, +} + +/// Fluent builder for [`ChromaCloudQwenEmbeddingFunction`]. +#[derive(Clone, Debug)] +pub struct ChromaCloudQwenEmbeddingFunctionBuilder { + api_key: Option, + client_api_key: Option, + api_key_env_var: String, + embed_url: Option, + model: ChromaCloudQwenEmbeddingModel, + task: Option, + instructions: HashMap, +} + +impl Default for ChromaCloudQwenEmbeddingFunctionBuilder { + fn default() -> Self { + Self { + api_key: None, + client_api_key: None, + api_key_env_var: DEFAULT_API_KEY_ENV_VAR.to_string(), + embed_url: None, + model: ChromaCloudQwenEmbeddingModel::default(), + task: None, + instructions: default_qwen_instructions(), + } + } +} + +impl ChromaCloudQwenEmbeddingFunctionBuilder { + /// Sets an explicit Chroma Cloud API key. + pub fn api_key(mut self, api_key: impl Into) -> Self { + self.api_key = Some(api_key.into()); + self + } + + pub(crate) fn client_api_key(mut self, api_key: impl Into) -> Self { + self.client_api_key = Some(api_key.into()); + self + } + + /// Sets the environment variable used to look up the Chroma Cloud API key. + /// + /// Defaults to `CHROMA_API_KEY`. + pub fn api_key_env_var(mut self, api_key_env_var: impl Into) -> Self { + self.api_key_env_var = api_key_env_var.into(); + self + } + + /// Sets the Chroma Cloud embedding endpoint. + /// + /// Defaults to `CHROMA_EMBED_URL` when set, otherwise `https://embed.trychroma.com`. + pub fn embed_url(mut self, embed_url: impl Into) -> Self { + self.embed_url = Some(embed_url.into()); + self + } + + /// Sets the Qwen model. + pub fn model(mut self, model: ChromaCloudQwenEmbeddingModel) -> Self { + self.model = model; + self + } + + /// Sets the task used to choose document/query instructions. + pub fn task(mut self, task: impl Into) -> Self { + self.task = Some(task.into()); + self + } + + /// Clears the configured task. + pub fn without_task(mut self) -> Self { + self.task = None; + self + } + + /// Replaces the full instruction map. + pub fn instructions( + mut self, + instructions: HashMap, + ) -> Self { + self.instructions = instructions; + self + } + + /// Adds or replaces instructions for one task. + pub fn instruction( + mut self, + task: impl Into, + documents: impl Into, + query: impl Into, + ) -> Self { + insert_qwen_instruction(&mut self.instructions, task, documents, query); + self + } + + /// Builds the embedding function. + /// + /// The API key is resolved in this order: explicit `api_key`, `api_key_env_var`, + /// then the client API key used by collection auto-configuration. + pub fn build(self) -> Result { + let api_key = resolve_api_key(self.api_key, &self.api_key_env_var, self.client_api_key)?; + let client = new_chroma_cloud_client(api_key, self.model.as_str())?; + Ok(ChromaCloudQwenEmbeddingFunction { + client, + api_url: trim_trailing_slash(self.embed_url.unwrap_or_else(chroma_embed_url_from_env)), + model: self.model, + task: self.task, + instructions: self.instructions, + api_key_env_var: self.api_key_env_var, + }) + } +} + +/// Fluent builder for the `chroma-cloud-qwen` known embedding function configuration. +#[derive(Clone, Debug)] +pub struct ChromaCloudQwenEmbeddingConfigurationBuilder { + api_key_env_var: String, + model: ChromaCloudQwenEmbeddingModel, + task: Option, + instructions: HashMap, +} + +impl Default for ChromaCloudQwenEmbeddingConfigurationBuilder { + fn default() -> Self { + Self { + api_key_env_var: DEFAULT_API_KEY_ENV_VAR.to_string(), + model: ChromaCloudQwenEmbeddingModel::default(), + task: None, + instructions: default_qwen_instructions(), + } + } +} + +impl ChromaCloudQwenEmbeddingConfigurationBuilder { + /// Sets the API-key environment variable serialized into the configuration. + pub fn api_key_env_var(mut self, api_key_env_var: impl Into) -> Self { + self.api_key_env_var = api_key_env_var.into(); + self + } + + /// Sets the Qwen model serialized into the configuration. + pub fn model(mut self, model: ChromaCloudQwenEmbeddingModel) -> Self { + self.model = model; + self + } + + /// Sets the task serialized into the configuration. + pub fn task(mut self, task: impl Into) -> Self { + self.task = Some(task.into()); + self + } + + /// Clears the task serialized into the configuration. + pub fn without_task(mut self) -> Self { + self.task = None; + self + } + + /// Replaces the full instruction map serialized into the configuration. + pub fn instructions( + mut self, + instructions: HashMap, + ) -> Self { + self.instructions = instructions; + self + } + + /// Adds or replaces instructions for one task in the serialized configuration. + pub fn instruction( + mut self, + task: impl Into, + documents: impl Into, + query: impl Into, + ) -> Self { + insert_qwen_instruction(&mut self.instructions, task, documents, query); + self + } + + /// Builds the known embedding function configuration. + pub fn build(self) -> EmbeddingFunctionConfiguration { + known_embedding_function_configuration( + QWEN_NAME, + qwen_config_value( + &self.api_key_env_var, + self.model, + self.task.as_deref(), + &self.instructions, + ), + ) + } +} + +impl ChromaCloudQwenEmbeddingFunction { + /// Returns the known embedding function name used in collection configuration. + pub fn name() -> &'static str { + QWEN_NAME + } + + /// Returns a fluent builder for Qwen embedding functions. + pub fn builder() -> ChromaCloudQwenEmbeddingFunctionBuilder { + ChromaCloudQwenEmbeddingFunctionBuilder::default() + } + + /// Returns a fluent builder for `chroma-cloud-qwen` collection configuration. + pub fn configuration() -> ChromaCloudQwenEmbeddingConfigurationBuilder { + ChromaCloudQwenEmbeddingConfigurationBuilder::default() + } + + /// Constructs a Qwen embedding function from a known embedding function configuration. + /// + /// The API key is read from the configured `api_key_env_var` first. If that environment + /// variable is unset, `client_api_key` is used. + /// + /// # Errors + /// + /// Returns an error if the configuration is not `chroma-cloud-qwen`, the model is + /// unsupported, or no API key is available. + pub(crate) fn try_from_config( + config: &EmbeddingFunctionNewConfiguration, + client_api_key: Option<&str>, + ) -> Result { + if config.name != QWEN_NAME { + return Err(ChromaCloudEmbeddingError::InvalidConfig(format!( + "expected {QWEN_NAME}, got {}", + config.name + ))); + } + let config: QwenConfig = from_value(config.config.clone()) + .map_err(|err| ChromaCloudEmbeddingError::InvalidConfig(err.to_string()))?; + let api_key_env_var = config + .api_key_env_var + .unwrap_or_else(|| DEFAULT_API_KEY_ENV_VAR.to_string()); + let mut builder = Self::builder() + .api_key_env_var(api_key_env_var) + .model(config.model) + .instructions( + config + .instructions + .unwrap_or_else(default_qwen_instructions), + ); + if let Some(api_key) = client_api_key { + builder = builder.client_api_key(api_key); + } + if let Some(task) = config.task { + builder = builder.task(task); + } + builder.build() + } + + /// Returns this embedding function's serializable configuration. + pub fn get_config(&self) -> Value { + qwen_config_value( + &self.api_key_env_var, + self.model, + self.task.as_deref(), + &self.instructions, + ) + } + + async fn embed_with_instruction( + &self, + batches: &[&str], + instruction: &str, + ) -> Result>, ChromaCloudEmbeddingError> { + if batches.is_empty() { + return Ok(Vec::new()); + } + let request = DenseEmbeddingRequest { + instructions: instruction, + texts: batches, + }; + let response = self + .client + .post(&self.api_url) + .json(&request) + .send() + .await? + .error_for_status()? + .json::() + .await?; + response.into_embeddings() + } + + fn document_instruction(&self) -> &str { + self.task + .as_ref() + .and_then(|task| self.instructions.get(task)) + .map(|instructions| instructions.documents.as_str()) + .unwrap_or("") + } + + fn query_instruction(&self) -> &str { + self.task + .as_ref() + .and_then(|task| self.instructions.get(task)) + .map(|instructions| instructions.query.as_str()) + .unwrap_or("") + } +} + +#[async_trait::async_trait] +impl EmbeddingFunction for ChromaCloudQwenEmbeddingFunction { + type Embedding = Vec; + type Error = ChromaCloudEmbeddingError; + + async fn embed_strs(&self, batches: &[&str]) -> Result>, Self::Error> { + self.embed_with_instruction(batches, self.document_instruction()) + .await + } + + async fn embed_query_strs(&self, batches: &[&str]) -> Result>, Self::Error> { + self.embed_with_instruction(batches, self.query_instruction()) + .await + } +} + +/// Chroma Cloud Splade sparse embedding function. +/// +/// This implementation calls the Chroma Cloud sparse embedding endpoint and produces +/// [`SparseVector`] embeddings. +pub struct ChromaCloudSpladeEmbeddingFunction { + client: reqwest::Client, + api_url: String, + model: ChromaCloudSpladeEmbeddingModel, + include_tokens: bool, + api_key_env_var: String, +} + +/// Fluent builder for [`ChromaCloudSpladeEmbeddingFunction`]. +#[derive(Clone, Debug)] +pub struct ChromaCloudSpladeEmbeddingFunctionBuilder { + api_key: Option, + api_key_env_var: String, + embed_url: Option, + model: ChromaCloudSpladeEmbeddingModel, + include_tokens: bool, +} + +impl Default for ChromaCloudSpladeEmbeddingFunctionBuilder { + fn default() -> Self { + Self { + api_key: None, + api_key_env_var: DEFAULT_API_KEY_ENV_VAR.to_string(), + embed_url: None, + model: ChromaCloudSpladeEmbeddingModel::default(), + include_tokens: false, + } + } +} + +impl ChromaCloudSpladeEmbeddingFunctionBuilder { + /// Sets an explicit Chroma Cloud API key. + pub fn api_key(mut self, api_key: impl Into) -> Self { + self.api_key = Some(api_key.into()); + self + } + + /// Sets the environment variable used to look up the Chroma Cloud API key. + /// + /// Defaults to `CHROMA_API_KEY`. + pub fn api_key_env_var(mut self, api_key_env_var: impl Into) -> Self { + self.api_key_env_var = api_key_env_var.into(); + self + } + + /// Sets the Chroma Cloud embedding endpoint. + /// + /// Defaults to `CHROMA_EMBED_URL` when set, otherwise `https://embed.trychroma.com`. + pub fn embed_url(mut self, embed_url: impl Into) -> Self { + self.embed_url = Some(embed_url.into()); + self + } + + /// Sets the Splade model. + pub fn model(mut self, model: ChromaCloudSpladeEmbeddingModel) -> Self { + self.model = model; + self + } + + /// Sets whether sparse vectors include token labels. + pub fn include_tokens(mut self, include_tokens: bool) -> Self { + self.include_tokens = include_tokens; + self + } + + /// Builds the embedding function. + /// + /// The API key is resolved in this order: explicit `api_key`, then `api_key_env_var`. + pub fn build(self) -> Result { + let api_key = resolve_api_key(self.api_key, &self.api_key_env_var, None)?; + let client = new_chroma_cloud_client(api_key, self.model.as_str())?; + Ok(ChromaCloudSpladeEmbeddingFunction { + client, + api_url: format!( + "{}/embed_sparse", + trim_trailing_slash(self.embed_url.unwrap_or_else(chroma_embed_url_from_env)) + ), + model: self.model, + include_tokens: self.include_tokens, + api_key_env_var: self.api_key_env_var, + }) + } +} + +/// Fluent builder for the `chroma-cloud-splade` known embedding function configuration. +#[derive(Clone, Debug)] +pub struct ChromaCloudSpladeEmbeddingConfigurationBuilder { + api_key_env_var: String, + model: ChromaCloudSpladeEmbeddingModel, + include_tokens: bool, +} + +impl Default for ChromaCloudSpladeEmbeddingConfigurationBuilder { + fn default() -> Self { + Self { + api_key_env_var: DEFAULT_API_KEY_ENV_VAR.to_string(), + model: ChromaCloudSpladeEmbeddingModel::default(), + include_tokens: false, + } + } +} + +impl ChromaCloudSpladeEmbeddingConfigurationBuilder { + /// Sets the API-key environment variable serialized into the configuration. + pub fn api_key_env_var(mut self, api_key_env_var: impl Into) -> Self { + self.api_key_env_var = api_key_env_var.into(); + self + } + + /// Sets the Splade model serialized into the configuration. + pub fn model(mut self, model: ChromaCloudSpladeEmbeddingModel) -> Self { + self.model = model; + self + } + + /// Sets whether token labels should be included in the serialized configuration. + pub fn include_tokens(mut self, include_tokens: bool) -> Self { + self.include_tokens = include_tokens; + self + } + + /// Builds the known embedding function configuration. + pub fn build(self) -> EmbeddingFunctionConfiguration { + known_embedding_function_configuration( + SPLADE_NAME, + splade_config_value(&self.api_key_env_var, self.model, self.include_tokens), + ) + } +} + +impl ChromaCloudSpladeEmbeddingFunction { + /// Returns the known embedding function name used in collection configuration. + pub fn name() -> &'static str { + SPLADE_NAME + } + + /// Returns a fluent builder for Splade embedding functions. + pub fn builder() -> ChromaCloudSpladeEmbeddingFunctionBuilder { + ChromaCloudSpladeEmbeddingFunctionBuilder::default() + } + + /// Returns a fluent builder for `chroma-cloud-splade` collection configuration. + pub fn configuration() -> ChromaCloudSpladeEmbeddingConfigurationBuilder { + ChromaCloudSpladeEmbeddingConfigurationBuilder::default() + } + + /// Returns this embedding function's serializable configuration. + pub fn get_config(&self) -> Value { + splade_config_value(&self.api_key_env_var, self.model, self.include_tokens) + } +} + +#[async_trait::async_trait] +impl EmbeddingFunction for ChromaCloudSpladeEmbeddingFunction { + type Embedding = SparseVector; + type Error = ChromaCloudEmbeddingError; + + async fn embed_strs(&self, batches: &[&str]) -> Result, Self::Error> { + if batches.is_empty() { + return Ok(Vec::new()); + } + let request = SparseEmbeddingRequest { + texts: batches, + task: "", + target: "", + fetch_tokens: if self.include_tokens { "true" } else { "false" }, + }; + let response = self + .client + .post(&self.api_url) + .json(&request) + .send() + .await? + .error_for_status()? + .json::() + .await?; + response + .embeddings + .into_iter() + .map(|embedding| embedding.into_sparse_vector(self.include_tokens)) + .collect() + } +} + +#[derive(Deserialize)] +struct QwenConfig { + model: ChromaCloudQwenEmbeddingModel, + task: Option, + api_key_env_var: Option, + instructions: Option>, +} + +#[derive(Serialize)] +struct DenseEmbeddingRequest<'a> { + instructions: &'a str, + texts: &'a [&'a str], +} + +#[derive(Deserialize)] +struct DenseEmbeddingResponse { + embeddings: Option>>, + error: Option, +} + +impl DenseEmbeddingResponse { + fn into_embeddings(self) -> Result>, ChromaCloudEmbeddingError> { + self.embeddings.ok_or_else(|| { + ChromaCloudEmbeddingError::Api( + self.error + .unwrap_or_else(|| "missing embeddings".to_string()), + ) + }) + } +} + +#[derive(Serialize)] +struct SparseEmbeddingRequest<'a> { + texts: &'a [&'a str], + task: &'a str, + target: &'a str, + fetch_tokens: &'a str, +} + +#[derive(Deserialize)] +struct SparseEmbeddingResponse { + embeddings: Vec, +} + +#[derive(Deserialize)] +struct SparseEmbedding { + indices: Vec, + values: Vec, + #[serde(default, alias = "tokens")] + labels: Option>, +} + +impl SparseEmbedding { + fn into_sparse_vector( + self, + include_tokens: bool, + ) -> Result { + if self.indices.len() != self.values.len() { + return Err(SparseVectorLengthMismatch.into()); + } + if include_tokens { + if let Some(labels) = self.labels { + if labels.len() != self.indices.len() { + return Err(SparseVectorLengthMismatch.into()); + } + let mut triples = self + .indices + .into_iter() + .zip(self.values) + .zip(labels) + .map(|((index, value), label)| (label, index, value)) + .collect::>(); + triples.sort_unstable_by_key(|(_, index, _)| *index); + return Ok(SparseVector::from_triples(triples)); + } + } + let mut pairs = self + .indices + .into_iter() + .zip(self.values) + .collect::>(); + pairs.sort_unstable_by_key(|(index, _)| *index); + Ok(SparseVector::from_pairs(pairs)) + } +} + +fn resolve_api_key( + api_key: Option, + api_key_env_var: &str, + client_api_key: Option, +) -> Result { + api_key + .or_else(|| env::var(api_key_env_var).ok()) + .or(client_api_key) + .ok_or_else(|| ChromaCloudEmbeddingError::MissingApiKey { + env_var: api_key_env_var.to_string(), + }) +} + +fn insert_qwen_instruction( + instructions: &mut HashMap, + task: impl Into, + documents: impl Into, + query: impl Into, +) { + instructions.insert( + task.into(), + ChromaCloudQwenTaskInstructions { + documents: documents.into(), + query: query.into(), + }, + ); +} + +fn known_embedding_function_configuration( + name: &str, + config: Value, +) -> EmbeddingFunctionConfiguration { + EmbeddingFunctionConfiguration::Known(EmbeddingFunctionNewConfiguration { + name: name.to_string(), + config, + }) +} + +fn qwen_config_value( + api_key_env_var: &str, + model: ChromaCloudQwenEmbeddingModel, + task: Option<&str>, + instructions: &HashMap, +) -> Value { + json!({ + "api_key_env_var": api_key_env_var, + "model": model.as_str(), + "task": task, + "instructions": instructions, + }) +} + +fn splade_config_value( + api_key_env_var: &str, + model: ChromaCloudSpladeEmbeddingModel, + include_tokens: bool, +) -> Value { + json!({ + "api_key_env_var": api_key_env_var, + "model": model.as_str(), + "include_tokens": include_tokens, + }) +} + +fn new_chroma_cloud_client( + api_key: String, + model: &str, +) -> Result { + let mut headers = HeaderMap::new(); + let mut api_key = HeaderValue::from_str(&api_key)?; + api_key.set_sensitive(true); + headers.insert(HeaderName::from_static("x-chroma-token"), api_key); + headers.insert( + HeaderName::from_static("x-chroma-embedding-model"), + HeaderValue::from_str(model)?, + ); + Ok(reqwest::Client::builder() + .default_headers(headers) + .build()?) +} + +fn chroma_embed_url_from_env() -> String { + env::var("CHROMA_EMBED_URL").unwrap_or_else(|_| DEFAULT_CHROMA_EMBED_URL.to_string()) +} + +fn trim_trailing_slash(url: String) -> String { + url.trim_end_matches('/').to_string() +} + +fn default_qwen_instructions() -> HashMap { + let mut instructions = HashMap::new(); + instructions.insert( + "nl_to_code".to_string(), + ChromaCloudQwenTaskInstructions { + documents: String::new(), + query: "Given a question about coding, retrieval code or passage that can solve user's question".to_string(), + }, + ); + instructions +} + +#[cfg(test)] +mod tests { + use super::*; + use httpmock::MockServer; + use serde_json::json; + + #[tokio::test] + async fn qwen_embeds_documents_and_queries_with_expected_instructions() { + let server = MockServer::start_async().await; + let documents = server + .mock_async(|when, then| { + when.method("POST") + .path("/") + .header("x-chroma-token", "test-api-key") + .header("x-chroma-embedding-model", "Qwen/Qwen3-Embedding-0.6B") + .json_body(json!({ + "instructions": "", + "texts": ["doc"], + })); + then.status(200) + .json_body(json!({"embeddings": [[1.0, 2.0]]})); + }) + .await; + let queries = server + .mock_async(|when, then| { + when.method("POST") + .path("/") + .json_body(json!({ + "instructions": "Given a question about coding, retrieval code or passage that can solve user's question", + "texts": ["query"], + })); + then.status(200).json_body(json!({"embeddings": [[3.0, 4.0]]})); + }) + .await; + + let embedding_function = ChromaCloudQwenEmbeddingFunction::builder() + .api_key("test-api-key") + .embed_url(server.base_url()) + .model(ChromaCloudQwenEmbeddingModel::default()) + .task("nl_to_code") + .build() + .unwrap(); + + assert_eq!( + embedding_function.embed_strs(&["doc"]).await.unwrap(), + vec![vec![1.0, 2.0]] + ); + assert_eq!( + embedding_function + .embed_query_strs(&["query"]) + .await + .unwrap(), + vec![vec![3.0, 4.0]] + ); + assert_eq!(documents.calls(), 1); + assert_eq!(queries.calls(), 1); + } + + #[tokio::test] + async fn splade_embeds_sparse_vectors() { + let server = MockServer::start_async().await; + let mock = server + .mock_async(|when, then| { + when.method("POST") + .path("/embed_sparse") + .header("x-chroma-token", "test-api-key") + .header("x-chroma-embedding-model", "prithivida/Splade_PP_en_v1") + .json_body(json!({ + "texts": ["doc"], + "task": "", + "target": "", + "fetch_tokens": "true", + })); + then.status(200).json_body(json!({ + "embeddings": [{ + "indices": [3, 1], + "values": [0.3, 0.1], + "labels": ["three", "one"], + }] + })); + }) + .await; + + let embedding_function = ChromaCloudSpladeEmbeddingFunction::builder() + .api_key("test-api-key") + .embed_url(server.base_url()) + .model(ChromaCloudSpladeEmbeddingModel::default()) + .include_tokens(true) + .build() + .unwrap(); + + let embeddings = embedding_function.embed_strs(&["doc"]).await.unwrap(); + + assert_eq!(embeddings.len(), 1); + assert_eq!(embeddings[0].indices, vec![1, 3]); + assert_eq!(embeddings[0].values, vec![0.1, 0.3]); + assert_eq!( + embeddings[0].tokens, + Some(vec!["one".to_string(), "three".to_string()]) + ); + assert_eq!(mock.calls(), 1); + } +} diff --git a/rust/chroma/src/embed/mod.rs b/rust/chroma/src/embed/mod.rs index 76da27c8a00..8c95f038be4 100644 --- a/rust/chroma/src/embed/mod.rs +++ b/rust/chroma/src/embed/mod.rs @@ -13,6 +13,8 @@ use std::{ pub mod bm25; /// Text tokenization utilities for BM25. pub mod bm25_tokenizer; +/// Chroma Cloud embedding function implementations. +pub mod chroma_cloud; /// MurmurHash3 absolute value hasher for token hashing. pub mod murmur3_abs_hasher; #[cfg(feature = "ollama")] @@ -72,6 +74,18 @@ pub trait EmbeddingFunction: Send + Sync + 'static { /// # } /// ``` async fn embed_strs(&self, batches: &[&str]) -> Result, Self::Error>; + + /// Converts query strings into embedding representations. + /// + /// Embedding models may choose to encode documents and queries differently. Implementations + /// that do not need separate query behavior can rely on the default, which delegates to + /// [`embed_strs`](Self::embed_strs). + async fn embed_query_strs( + &self, + batches: &[&str], + ) -> Result, Self::Error> { + self.embed_strs(batches).await + } } /// Generic tokenizer interface for text processing. diff --git a/rust/chroma/src/lib.rs b/rust/chroma/src/lib.rs index 16bd4992094..fa335fadd73 100644 --- a/rust/chroma/src/lib.rs +++ b/rust/chroma/src/lib.rs @@ -108,6 +108,7 @@ pub mod types; pub use client::ChromaHttpClient; pub use client::ChromaHttpClientOptions; pub use collection::ChromaCollection; +pub use collection::IntoOptionalEmbeddings; #[cfg(test)] mod tests {