diff --git a/Cargo.lock b/Cargo.lock index 645cf3a3159..7281da10834 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1667,6 +1667,7 @@ name = "chroma-cli" version = "1.4.4" dependencies = [ "arboard", + "backon", "bytes", "chroma", "chroma-config", diff --git a/rust/cli/Cargo.toml b/rust/cli/Cargo.toml index 48e97368f1a..0a871366ee6 100644 --- a/rust/cli/Cargo.toml +++ b/rust/cli/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] arboard = "3.4.1" +backon = { workspace = true } bytes = { workspace = true } clap = { version = "4.5.28", features = ["derive"] } chroma = { workspace = true } diff --git a/rust/cli/src/commands/copy.rs b/rust/cli/src/commands/copy.rs index 7dc31e63ca0..7a14b56baa3 100644 --- a/rust/cli/src/commands/copy.rs +++ b/rust/cli/src/commands/copy.rs @@ -6,7 +6,8 @@ use crate::terminal::{SystemTerminal, Terminal}; use crate::utils::{ cloud_client, connect_local, CliError, ErrorResponse, LocalChromaArgs, Profile, UtilsError, }; -use chroma::client::Database; +use backon::{ExponentialBuilder, Retryable}; +use chroma::client::{ChromaHttpClientError, Database}; use chroma::ChromaHttpClient; use chroma_types::operator::Key; use chroma_types::plan::SearchPayload; @@ -14,12 +15,18 @@ use clap::Parser; use crossterm::style::Stylize; use futures::{stream, StreamExt}; use indicatif::{ProgressBar, ProgressStyle}; +use reqwest::StatusCode; use std::fmt::{self, Display}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::time::Duration; use thiserror::Error; use tokio::task::JoinHandle; +const COPY_MAX_RETRIES: usize = 5; +const COPY_INITIAL_RETRY_DELAY: Duration = Duration::from_millis(250); +const COPY_MAX_RETRY_DELAY: Duration = Duration::from_secs(5); + #[derive(Debug)] enum Environment { Local, @@ -200,6 +207,69 @@ fn get_target_and_destination( Ok((source, target)) } +fn is_not_found_error(error: &CliError) -> bool { + matches!( + error, + CliError::ChromaClient(ChromaHttpClientError::ApiError(_, status)) + if *status == 
StatusCode::NOT_FOUND + ) +} + +fn is_retryable_chroma_error(error: &ChromaHttpClientError) -> bool { + match error { + ChromaHttpClientError::RequestError(_) | ChromaHttpClientError::NoBackendAvailable => true, + ChromaHttpClientError::ApiError(_, status) => { + *status == StatusCode::REQUEST_TIMEOUT + || *status == StatusCode::TOO_MANY_REQUESTS + || status.is_server_error() + } + ChromaHttpClientError::CouldNotResolveDatabaseId(_) + | ChromaHttpClientError::SerdeError(_) + | ChromaHttpClientError::ValidationError(_) + | ChromaHttpClientError::InvalidWhere => false, + } +} + +fn is_retryable_copy_error(error: &CliError) -> bool { + match error { + CliError::ChromaClient(error) => is_retryable_chroma_error(error), + _ => false, + } +} + +fn copy_retry_policy() -> ExponentialBuilder { + ExponentialBuilder::new() + .with_max_times(COPY_MAX_RETRIES) + .with_min_delay(COPY_INITIAL_RETRY_DELAY) + .with_max_delay(COPY_MAX_RETRY_DELAY) +} + +async fn verify_target_collection_does_not_exist( + target: &ChromaHttpClient, + collection_name: &str, +) -> Result<(), CliError> { + let target_collection = (|| { + let target = target.clone(); + let collection_name = collection_name.to_string(); + async move { + target + .get_collection(collection_name) + .await + .map(|_| ()) + .map_err(CliError::from) + } + }) + .retry(copy_retry_policy()) + .when(is_retryable_copy_error) + .await; + + match target_collection { + Ok(()) => Err(CopyError::CollectionAlreadyExists(collection_name.to_string()).into()), + Err(error) if is_not_found_error(&error) => Ok(()), + Err(error) => Err(error), + } +} + async fn copy_collections( source: ChromaHttpClient, target: ChromaHttpClient, @@ -210,11 +280,34 @@ async fn copy_collections( term: &mut dyn Terminal, ) -> Result<(), CliError> { let collections = if all { - source.list_collections(10000, None).await? 
+ (|| { + let source = source.clone(); + async move { + source + .list_collections(10000, None) + .await + .map_err(CliError::from) + } + }) + .retry(copy_retry_policy()) + .when(is_retryable_copy_error) + .await? } else { let mut source_collections = vec![]; for collection in collections { - let source_collection = source.get_collection(&collection).await?; + let source_collection = (|| { + let source = source.clone(); + let collection = collection.clone(); + async move { + source + .get_collection(collection) + .await + .map_err(CliError::from) + } + }) + .retry(copy_retry_policy()) + .when(is_retryable_copy_error) + .await?; source_collections.push(source_collection); } source_collections @@ -232,24 +325,47 @@ async fn copy_collections( term.println("Verifying collections..."); // Verify that collections don't exist on target for collection in collections.clone() { - if target.get_collection(collection.name()).await.is_ok() { - return Err(CopyError::CollectionAlreadyExists(collection.name().to_string()).into()); - } + verify_target_collection_does_not_exist(&target, collection.name()).await?; } for collection in collections { - let size = collection.count().await?; + let size = (|| { + let collection = collection.clone(); + async move { collection.count().await.map_err(CliError::from) } + }) + .retry(copy_retry_policy()) + .when(is_retryable_copy_error) + .await?; let offsets: Vec = (0..size).step_by(step as usize).collect(); let records_added = Arc::new(AtomicUsize::new(0)); - let target_collection = target - .create_collection( - collection.name(), - collection.schema().clone(), - collection.metadata().clone(), - ) - .await?; + let target_collection = (|| { + let target = target.clone(); + let collection_name = collection.name().to_string(); + let schema = collection.schema().clone(); + let metadata = collection.metadata().clone(); + async move { + target + .get_or_create_collection(collection_name, schema, metadata) + .await + .map_err(CliError::from) + } + 
}) + .retry(copy_retry_policy()) + .when(is_retryable_copy_error) + .await?; + + let target_size = (|| { + let target_collection = target_collection.clone(); + async move { target_collection.count().await.map_err(CliError::from) } + }) + .retry(copy_retry_policy()) + .when(is_retryable_copy_error) + .await?; + if target_size != 0 { + return Err(CopyError::CollectionAlreadyExists(collection.name().to_string()).into()); + } term.println(&format!("Copying collection: {}", collection.name())); @@ -274,7 +390,19 @@ async fn copy_collections( Key::Metadata, ]); - let response = collection.search(vec![search]).await?; + let response = (|| { + let collection = collection.clone(); + let search = search.clone(); + async move { + collection + .search(vec![search]) + .await + .map_err(CliError::from) + } + }) + .retry(copy_retry_policy()) + .when(is_retryable_copy_error) + .await?; let ids = response.ids.into_iter().next().unwrap_or_default(); if ids.is_empty() { @@ -294,18 +422,30 @@ async fn copy_collections( .collect(); let metadatas = response.metadatas.into_iter().next().flatten(); - target_collection - .add(ids, embeddings, documents, None, metadatas) - .await - .map_err(|e| { - if e.to_string().to_lowercase().contains("quota") { - let msg = serde_json::from_str::<ErrorResponse>(&e.to_string()) - .unwrap_or_default() - .message; - return CliError::Utils(UtilsError::Quota(msg)); - } - CliError::ChromaClient(e) - })?; + (|| { + let target_collection = target_collection.clone(); + let ids = ids.clone(); + let embeddings = embeddings.clone(); + let documents = documents.clone(); + let metadatas = metadatas.clone(); + async move { + target_collection + .add(ids, embeddings, documents, None, metadatas) + .await + .map_err(|e| { + if e.to_string().to_lowercase().contains("quota") { + let msg = serde_json::from_str::<ErrorResponse>(&e.to_string()) + .unwrap_or_default() + .message; + return CliError::Utils(UtilsError::Quota(msg)); + } + CliError::ChromaClient(e) + }) + } + }) + .retry(copy_retry_policy()) + 
.when(is_retryable_copy_error) + .await?; let current_added = records_added.fetch_add(num_records, Ordering::Relaxed) + num_records; @@ -377,6 +517,40 @@ mod tests { } } + fn api_error(status: StatusCode) -> CliError { + CliError::ChromaClient(ChromaHttpClientError::ApiError( + "request failed".to_string(), + status, + )) + } + + #[test] + fn test_is_retryable_copy_error_recognizes_transient_errors() { + assert!(is_retryable_copy_error(&api_error( + StatusCode::REQUEST_TIMEOUT + ))); + assert!(is_retryable_copy_error(&api_error( + StatusCode::TOO_MANY_REQUESTS + ))); + assert!(is_retryable_copy_error(&api_error( + StatusCode::INTERNAL_SERVER_ERROR + ))); + assert!(is_retryable_copy_error(&CliError::ChromaClient( + ChromaHttpClientError::NoBackendAvailable, + ))); + } + + #[test] + fn test_is_retryable_copy_error_rejects_deterministic_errors() { + assert!(!is_retryable_copy_error(&api_error( + StatusCode::BAD_REQUEST + ))); + assert!(!is_retryable_copy_error(&api_error(StatusCode::NOT_FOUND))); + assert!(!is_retryable_copy_error(&CliError::Copy( + CopyError::NoCollections, + ))); + } + #[test] fn test_get_target_and_destination_from_cloud() { let mut args = default_args();