Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ edition = "2021"

[dependencies]
arboard = "3.4.1"
backon = { workspace = true }
bytes = { workspace = true }
clap = { version = "4.5.28", features = ["derive"] }
chroma = { workspace = true }
Expand Down
228 changes: 201 additions & 27 deletions rust/cli/src/commands/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,27 @@ use crate::terminal::{SystemTerminal, Terminal};
use crate::utils::{
cloud_client, connect_local, CliError, ErrorResponse, LocalChromaArgs, Profile, UtilsError,
};
use chroma::client::Database;
use backon::{ExponentialBuilder, Retryable};
use chroma::client::{ChromaHttpClientError, Database};
use chroma::ChromaHttpClient;
use chroma_types::operator::Key;
use chroma_types::plan::SearchPayload;
use clap::Parser;
use crossterm::style::Stylize;
use futures::{stream, StreamExt};
use indicatif::{ProgressBar, ProgressStyle};
use reqwest::StatusCode;
use std::fmt::{self, Display};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::task::JoinHandle;

/// Value handed to `ExponentialBuilder::with_max_times` by `copy_retry_policy`.
const COPY_MAX_RETRIES: usize = 5;
/// Minimum backoff delay (250 ms) handed to `with_min_delay` by `copy_retry_policy`.
const COPY_INITIAL_RETRY_DELAY: Duration = Duration::from_millis(250);
/// Maximum backoff delay (5 s) handed to `with_max_delay` by `copy_retry_policy`.
const COPY_MAX_RETRY_DELAY: Duration = Duration::from_secs(5);

#[derive(Debug)]
enum Environment {
Local,
Expand Down Expand Up @@ -200,6 +207,69 @@ fn get_target_and_destination(
Ok((source, target))
}

/// Reports whether `error` is a Chroma client API error whose HTTP status is `404 Not Found`.
///
/// Every other `CliError` variant, and every API error with a different status, yields `false`.
fn is_not_found_error(error: &CliError) -> bool {
    match error {
        CliError::ChromaClient(ChromaHttpClientError::ApiError(_, code)) => {
            *code == StatusCode::NOT_FOUND
        }
        _ => false,
    }
}

/// Decides whether a Chroma HTTP client error should be retried.
///
/// `RequestError` and `NoBackendAvailable` are always retryable. An `ApiError` is retryable only
/// when its status is `REQUEST_TIMEOUT`, `TOO_MANY_REQUESTS`, or a server error. All remaining
/// variants are treated as permanent failures.
///
/// The match stays exhaustive (no `_` arm) so that a newly added variant forces a decision here.
fn is_retryable_chroma_error(error: &ChromaHttpClientError) -> bool {
    match error {
        ChromaHttpClientError::ApiError(_, code) => {
            let timed_out_or_throttled =
                [StatusCode::REQUEST_TIMEOUT, StatusCode::TOO_MANY_REQUESTS].contains(code);
            timed_out_or_throttled || code.is_server_error()
        }
        ChromaHttpClientError::RequestError(_) => true,
        ChromaHttpClientError::NoBackendAvailable => true,
        ChromaHttpClientError::CouldNotResolveDatabaseId(_) => false,
        ChromaHttpClientError::SerdeError(_) => false,
        ChromaHttpClientError::ValidationError(_) => false,
        ChromaHttpClientError::InvalidWhere => false,
    }
}

fn is_retryable_copy_error(error: &CliError) -> bool {
match error {
CliError::ChromaClient(error) => is_retryable_chroma_error(error),
_ => false,
}
}

/// Builds the exponential backoff configuration shared by every retried call in the copy command.
///
/// Starts from `ExponentialBuilder::new()` and overrides the retry count and the minimum and
/// maximum delays with the `COPY_*` constants; all other settings keep the builder's defaults.
fn copy_retry_policy() -> ExponentialBuilder {
    let policy = ExponentialBuilder::new();
    let policy = policy.with_max_times(COPY_MAX_RETRIES);
    let policy = policy.with_min_delay(COPY_INITIAL_RETRY_DELAY);
    policy.with_max_delay(COPY_MAX_RETRY_DELAY)
}

/// Confirms that `collection_name` is absent from the `target` deployment.
///
/// The lookup is retried with `copy_retry_policy()` whenever `is_retryable_copy_error` accepts
/// the failure.
///
/// # Errors
///
/// * `CopyError::CollectionAlreadyExists` when the lookup succeeds.
/// * The lookup error itself when it is anything other than a not-found API error (as judged by
///   `is_not_found_error`), including a retryable error that outlasted the retry budget.
async fn verify_target_collection_does_not_exist(
    target: &ChromaHttpClient,
    collection_name: &str,
) -> Result<(), CliError> {
    // Each attempt needs its own owned client and name because the future is `async move`.
    let probe = || {
        let client = target.clone();
        let name = collection_name.to_string();
        async move {
            match client.get_collection(name).await {
                Ok(_) => Ok(()),
                Err(err) => Err(CliError::from(err)),
            }
        }
    };

    let lookup = probe
        .retry(copy_retry_policy())
        .when(is_retryable_copy_error)
        .await;

    match lookup {
        // A not-found response is the success case here: the name is free on the target.
        Err(err) if is_not_found_error(&err) => Ok(()),
        Err(err) => Err(err),
        Ok(()) => Err(CopyError::CollectionAlreadyExists(collection_name.to_string()).into()),
    }
}

async fn copy_collections(
source: ChromaHttpClient,
target: ChromaHttpClient,
Expand All @@ -210,11 +280,34 @@ async fn copy_collections(
term: &mut dyn Terminal,
) -> Result<(), CliError> {
let collections = if all {
source.list_collections(10000, None).await?
(|| {
let source = source.clone();
async move {
source
.list_collections(10000, None)
.await
.map_err(CliError::from)
}
})
.retry(copy_retry_policy())
.when(is_retryable_copy_error)
.await?
} else {
let mut source_collections = vec![];
for collection in collections {
let source_collection = source.get_collection(&collection).await?;
let source_collection = (|| {
let source = source.clone();
let collection = collection.clone();
async move {
source
.get_collection(collection)
.await
.map_err(CliError::from)
}
})
.retry(copy_retry_policy())
.when(is_retryable_copy_error)
.await?;
source_collections.push(source_collection);
}
source_collections
Expand All @@ -232,24 +325,47 @@ async fn copy_collections(
term.println("Verifying collections...");
// Verify that collections don't exist on target
for collection in collections.clone() {
if target.get_collection(collection.name()).await.is_ok() {
return Err(CopyError::CollectionAlreadyExists(collection.name().to_string()).into());
}
verify_target_collection_does_not_exist(&target, collection.name()).await?;
}

for collection in collections {
let size = collection.count().await?;
let size = (|| {
let collection = collection.clone();
async move { collection.count().await.map_err(CliError::from) }
})
.retry(copy_retry_policy())
.when(is_retryable_copy_error)
.await?;

let offsets: Vec<u32> = (0..size).step_by(step as usize).collect();
let records_added = Arc::new(AtomicUsize::new(0));

let target_collection = target
.create_collection(
collection.name(),
collection.schema().clone(),
collection.metadata().clone(),
)
.await?;
let target_collection = (|| {
let target = target.clone();
let collection_name = collection.name().to_string();
let schema = collection.schema().clone();
let metadata = collection.metadata().clone();
async move {
target
.get_or_create_collection(collection_name, schema, metadata)
.await
.map_err(CliError::from)
}
})
.retry(copy_retry_policy())
.when(is_retryable_copy_error)
.await?;

let target_size = (|| {
let target_collection = target_collection.clone();
async move { target_collection.count().await.map_err(CliError::from) }
})
.retry(copy_retry_policy())
.when(is_retryable_copy_error)
.await?;
if target_size != 0 {
return Err(CopyError::CollectionAlreadyExists(collection.name().to_string()).into());
}

term.println(&format!("Copying collection: {}", collection.name()));

Expand All @@ -274,7 +390,19 @@ async fn copy_collections(
Key::Metadata,
]);

let response = collection.search(vec![search]).await?;
let response = (|| {
let collection = collection.clone();
let search = search.clone();
async move {
collection
.search(vec![search])
.await
.map_err(CliError::from)
}
})
.retry(copy_retry_policy())
.when(is_retryable_copy_error)
.await?;

let ids = response.ids.into_iter().next().unwrap_or_default();
if ids.is_empty() {
Expand All @@ -294,18 +422,30 @@ async fn copy_collections(
.collect();
let metadatas = response.metadatas.into_iter().next().flatten();

target_collection
.add(ids, embeddings, documents, None, metadatas)
.await
.map_err(|e| {
if e.to_string().to_lowercase().contains("quota") {
let msg = serde_json::from_str::<ErrorResponse>(&e.to_string())
.unwrap_or_default()
.message;
return CliError::Utils(UtilsError::Quota(msg));
}
CliError::ChromaClient(e)
})?;
(|| {
let target_collection = target_collection.clone();
let ids = ids.clone();
let embeddings = embeddings.clone();
let documents = documents.clone();
let metadatas = metadatas.clone();
async move {
target_collection
.add(ids, embeddings, documents, None, metadatas)
.await
.map_err(|e| {
if e.to_string().to_lowercase().contains("quota") {
let msg = serde_json::from_str::<ErrorResponse>(&e.to_string())
.unwrap_or_default()
.message;
return CliError::Utils(UtilsError::Quota(msg));
}
CliError::ChromaClient(e)
})
}
})
.retry(copy_retry_policy())
.when(is_retryable_copy_error)
.await?;

let current_added =
records_added.fetch_add(num_records, Ordering::Relaxed) + num_records;
Expand Down Expand Up @@ -377,6 +517,40 @@ mod tests {
}
}

/// Test helper: wraps `status` in a `ChromaHttpClientError::ApiError` with a fixed message and
/// lifts it into `CliError::ChromaClient`.
fn api_error(status: StatusCode) -> CliError {
    let message = "request failed".to_string();
    let inner = ChromaHttpClientError::ApiError(message, status);
    CliError::ChromaClient(inner)
}

/// Timeouts, throttling, server errors and a missing backend must all be classified as retryable.
#[test]
fn test_is_retryable_copy_error_recognizes_transient_errors() {
    let timeout = api_error(StatusCode::REQUEST_TIMEOUT);
    assert!(is_retryable_copy_error(&timeout));

    let throttled = api_error(StatusCode::TOO_MANY_REQUESTS);
    assert!(is_retryable_copy_error(&throttled));

    let server_failure = api_error(StatusCode::INTERNAL_SERVER_ERROR);
    assert!(is_retryable_copy_error(&server_failure));

    let no_backend = CliError::ChromaClient(ChromaHttpClientError::NoBackendAvailable);
    assert!(is_retryable_copy_error(&no_backend));
}

/// Client-side API errors and non-client `CliError` variants must never be retried.
#[test]
fn test_is_retryable_copy_error_rejects_deterministic_errors() {
    let bad_request = api_error(StatusCode::BAD_REQUEST);
    assert!(!is_retryable_copy_error(&bad_request));

    let not_found = api_error(StatusCode::NOT_FOUND);
    assert!(!is_retryable_copy_error(&not_found));

    let no_collections = CliError::Copy(CopyError::NoCollections);
    assert!(!is_retryable_copy_error(&no_collections));
}

#[test]
fn test_get_target_and_destination_from_cloud() {
let mut args = default_args();
Expand Down
Loading