diff --git a/crates/Cargo.lock b/crates/Cargo.lock index e8b1bbb932..310a865da0 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -182,6 +182,15 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "block2" version = "0.6.2" @@ -243,6 +252,17 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chacha20" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" +dependencies = [ + "cfg-if", + "cpufeatures", + "rand_core 0.10.1", +] + [[package]] name = "chrono" version = "0.4.44" @@ -295,19 +315,23 @@ version = "0.21.0-rc.3" dependencies = [ "async-trait", "base64", + "bytes", "coglet", "futures", "libc", + "object_store", "pyo3", "pyo3-async-runtimes", "pyo3-stub-gen", "sentry", "serde_json", "tempfile", + "thiserror 2.0.18", "tokio", "tokio-util", "tracing", "tracing-subscriber", + "url", ] [[package]] @@ -376,6 +400,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" +dependencies = [ + "libc", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -388,6 +421,16 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "dashmap" version = "6.1.0" @@ -439,6 +482,16 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "dispatch2" version = "0.3.1" @@ -503,7 +556,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -659,6 +712,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getopts" version = "0.2.24" @@ -675,8 +738,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -686,9 +751,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi 5.3.0", "wasip2", + "wasm-bindgen", ] [[package]] @@ -700,6 +767,7 @@ dependencies = [ "cfg-if", "libc", "r-efi 6.0.0", + "rand_core 0.10.1", "wasip2", "wasip3", ] @@ -824,6 +892,12 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humantime" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" + [[package]] name = "hyper" version = "1.8.1" @@ -857,6 +931,7 @@ dependencies = [ "hyper", "hyper-util", "rustls", + "rustls-native-certs", "rustls-pki-types", "tokio", "tokio-rustls", @@ -1249,6 +1324,12 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "maplit" version = "1.0.2" @@ -1280,6 +1361,16 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + [[package]] name = "memchr" version = "2.8.0" @@ -1658,6 +1749,45 @@ dependencies = [ "memchr", ] +[[package]] +name = "object_store" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622acbc9100d3c10e2ee15804b0caa40e55c933d5aa53814cd520805b7958a49" +dependencies = [ + "async-trait", + "base64", + "bytes", + "chrono", + "form_urlencoded", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body-util", + "httparse", + "humantime", + "hyper", + "itertools 0.14.0", + "md-5", + "parking_lot", + "percent-encoding", + "quick-xml", + "rand 0.10.1", + "reqwest 0.12.28", + "ring", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "thiserror 2.0.18", + "tokio", + "tracing", + "url", + "wasm-bindgen-futures", + "web-time", +] + [[package]] name = "once_cell" version = "1.21.4" @@ -1952,6 +2082,71 @@ dependencies = [ "syn", ] +[[package]] +name = "quick-xml" +version = "0.39.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdcc8dd4e2f670d309a5f0e83fe36dfdc05af317008fea29144da1a2ac858e5e" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash 2.1.1", + "rustls", + "socket2", + "thiserror 2.0.18", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash 2.1.1", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.18", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.52.0", +] + [[package]] name = "quote" version = "1.0.45" @@ -1994,6 +2189,17 @@ dependencies = [ "rand_core 0.9.5", ] +[[package]] +name = "rand" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" +dependencies = [ + "chacha20", + "getrandom 0.4.2", + "rand_core 0.10.1", +] + [[package]] name = "rand_chacha" version = "0.3.1" @@ -2032,6 +2238,12 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_core" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" + [[package]] name = "rawpointer" version = "0.2.1" @@ -2121,26 +2333,35 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", + "h2", "http", "http-body", "http-body-util", "hyper", + "hyper-rustls", "hyper-util", "js-sys", "log", "percent-encoding", "pin-project-lite", + "quinn", + "rustls", + "rustls-native-certs", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", "tokio", + "tokio-rustls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", ] @@ -2234,7 +2455,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -2270,6 +2491,7 @@ version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ + "web-time", "zeroize", ] @@ -2291,7 +2513,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -2727,7 +2949,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -2829,6 +3051,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinyvec" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.52.3" @@ -3046,6 +3283,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typenum" +version = "1.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6f5e870be6c3b371b77fe0ee0bafb859fa4964b4404c27de1d380043c4dda20" + [[package]] name = "uname" version = "0.1.1" @@ -3387,6 +3630,19 @@ dependencies = [ "wasmparser", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "wasmparser" version = "0.244.0" @@ -3409,6 +3665,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-root-certs" version = "1.0.6" @@ -3433,7 +3699,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/crates/coglet-python/Cargo.toml b/crates/coglet-python/Cargo.toml index 9cfd504968..0196c9a419 100644 --- a/crates/coglet-python/Cargo.toml +++ b/crates/coglet-python/Cargo.toml @@ -13,16 +13,21 @@ crate-type = ["cdylib", "rlib"] [dependencies] async-trait = "0.1.89" base64 = "0.22" +bytes = "1" coglet_core = { path = "../coglet", package = "coglet" } futures.workspace = true +object_store = { version = "0.13", default-features = false, features = ["aws", "gcp", "azure"] } pyo3.workspace = true pyo3-async-runtimes.workspace = true pyo3-stub-gen.workspace = true serde_json.workspace = true +tempfile = "3" +thiserror.workspace = true tokio.workspace = true tokio-util = { workspace = true, features = ["codec"] } tracing.workspace = true tracing-subscriber.workspace = true +url = "2" sentry.workspace = true [target.'cfg(unix)'.dependencies] @@ -30,7 +35,6 @@ libc = "0.2" [dev-dependencies] pyo3 = { workspace = true, features = ["auto-initialize"] } -tempfile = "3" [features] extension-module = ["pyo3/extension-module"] diff --git a/crates/coglet-python/src/cloud.rs b/crates/coglet-python/src/cloud.rs new file mode 100644 index 0000000000..78ad831738 --- /dev/null +++ b/crates/coglet-python/src/cloud.rs @@ -0,0 +1,312 @@ +//! Native cloud object-storage downloads for predictor inputs. +//! +//! Detects cloud-scheme URLs (`s3://`, `gs://`, `az://`) in input values and +//! downloads them to temp files using the `object_store` crate, so the rest of +//! the input pipeline can treat them as ordinary local paths. + +use std::io::Write as _; +use std::sync::Arc; + +use object_store::path::Path as ObjectPath; +use object_store::{ObjectStore, ObjectStoreExt}; + +/// Errors that can occur while resolving and downloading a cloud object. +#[derive(Debug, thiserror::Error)] +pub enum CloudError { + #[error("invalid cloud url '{0}'")] + InvalidUrl(String), + #[error("failed to build object store for '{url}': {source}")] + Store { + url: String, + #[source] + source: object_store::Error, + }, + #[error("failed to download '{url}': {source}")] + Download { + url: String, + #[source] + source: object_store::Error, + }, + #[error("failed to write temp file for '{url}': {source}")] + Io { + url: String, + #[source] + source: std::io::Error, + }, +} + +/// Build an object store for the given cloud URL, reading credentials and +/// endpoint configuration from the process environment, and return the store +/// together with the in-bucket object path to fetch. +/// +/// Credentials use each provider's standard environment variables, resolved by +/// `object_store`'s builders via `parse_url_opts`: +/// - S3 / R2 / MinIO: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, +/// `AWS_SESSION_TOKEN`, `AWS_REGION`, `AWS_ENDPOINT_URL` (set endpoint for R2/MinIO). +/// - GCS: `GOOGLE_SERVICE_ACCOUNT` / `GOOGLE_APPLICATION_CREDENTIALS`. +/// - Azure: `AZURE_STORAGE_ACCOUNT_NAME`, `AZURE_STORAGE_ACCOUNT_KEY`, etc. +pub fn build_store_for_url(url: &str) -> Result<(Arc, ObjectPath), CloudError> { + let parsed = url::Url::parse(url).map_err(|_| CloudError::InvalidUrl(url.to_string()))?; + // `object_store::parse_url_opts` builds the right store from the URL scheme + // and an iterator of (config-key, value) options, ignoring keys a given + // provider does not recognize. Passing the whole process environment lets + // each provider pick up its standard credentials (verified against + // object_store 0.13.2: signature is `I: IntoIterator, + // K: AsRef, V: Into`, so `Vec<(String, String)>` fits). + // + // Offline construction is fine: S3 defaults its region to `us-east-1` and + // GCS falls back to a lazy instance-credential provider, so neither builder + // performs network I/O at `build()` time. + let opts: Vec<(String, String)> = std::env::vars().collect(); + let (store, path) = + object_store::parse_url_opts(&parsed, opts).map_err(|source| CloudError::Store { + url: url.to_string(), + source, + })?; + Ok((Arc::from(store), path)) +} + +/// Async: fetch the full object body for an already-built store + path. +/// Composable so callers can run many fetches concurrently under one runtime. +async fn fetch_bytes( + store: &dyn ObjectStore, + path: &ObjectPath, + url: &str, +) -> Result { + let get_result = store + .get(path) + .await + .map_err(|source| CloudError::Download { + url: url.to_string(), + source, + })?; + get_result + .bytes() + .await + .map_err(|source| CloudError::Download { + url: url.to_string(), + source, + }) +} + +/// Write already-fetched bytes to a uniquely-named temp file, preserving the +/// suggested filename as the suffix (so the extension survives). Returns the +/// temp file path; cleanup happens later via PreparedInput's drop calling +/// unlink() on the path we hand to Python. +fn write_temp( + url: &str, + suggested_filename: &str, + data: &[u8], +) -> Result { + let suffix = sanitize_suffix(suggested_filename); + let mut temp = tempfile::Builder::new() + .suffix(&suffix) + .tempfile() + .map_err(|source| CloudError::Io { + url: url.to_string(), + source, + })?; + temp.write_all(data).map_err(|source| CloudError::Io { + url: url.to_string(), + source, + })?; + let (_file, pathbuf) = temp.keep().map_err(|e| CloudError::Io { + url: url.to_string(), + source: e.error, + })?; + Ok(pathbuf) +} + +/// Download MANY cloud URLs to temp files concurrently, returning the local +/// temp paths in the SAME ORDER as the input `urls`. +/// +/// This is synchronous (blocks on a single tokio runtime) but performs all +/// network transfers concurrently via `try_join_all`. If any download fails, +/// the whole call fails (first error wins) and already-written temp files are +/// best-effort removed. Callers holding the GIL should wrap this in +/// `py.allow_threads(...)`. +pub fn download_many_to_temp(urls: &[String]) -> Result, CloudError> { + if urls.is_empty() { + return Ok(Vec::new()); + } + // Build a store + object path for each URL up front (offline, no network). + let mut stores: Vec<(Arc, ObjectPath, String, String)> = + Vec::with_capacity(urls.len()); + for url in urls { + let (store, path) = build_store_for_url(url)?; + let filename = url.rsplit('/').next().unwrap_or("file").to_string(); + stores.push((store, path, url.clone(), filename)); + } + + let runtime = pyo3_async_runtimes::tokio::get_runtime(); + let bodies: Vec = runtime.block_on(async { + let fetches = stores + .iter() + .map(|(store, path, url, _)| fetch_bytes(store.as_ref(), path, url)); + futures::future::try_join_all(fetches).await + })?; + + // Write each body to a temp file, preserving order. Roll back on error. + let mut written: Vec = Vec::with_capacity(bodies.len()); + for (body, (_, _, url, filename)) in bodies.iter().zip(stores.iter()) { + match write_temp(url, filename, body) { + Ok(p) => written.push(p), + Err(e) => { + for p in &written { + std::fs::remove_file(p).ok(); + } + return Err(e); + } + } + } + Ok(written) +} + +/// Build a filesystem-safe temp-file suffix from a suggested filename. +fn sanitize_suffix(name: &str) -> String { + let base = name.rsplit('/').next().unwrap_or(name); + if base.is_empty() { + return String::new(); + } + format!("-{}", base.replace(['\0', '/'], "_")) +} + +/// Returns true if `s` is a cloud object-storage URL that this module can +/// download (`s3://`, `gs://`, `az://`/`azure://`). +/// +/// Note: Cloudflare R2 and MinIO are S3-compatible and have NO scheme of their +/// own. They are addressed with the `s3://` scheme and reached by setting +/// `AWS_ENDPOINT_URL` (and `AWS_REGION=auto` for R2) in the environment, which +/// `build_store_for_url` picks up. So there is intentionally no `r2://` here. +pub fn is_cloud_url(s: &str) -> bool { + s.starts_with("s3://") + || s.starts_with("gs://") + || s.starts_with("az://") + || s.starts_with("azure://") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn detects_supported_schemes() { + assert!(is_cloud_url("s3://bucket/key.png")); + assert!(is_cloud_url("gs://bucket/key.png")); + assert!(is_cloud_url("az://container/key.png")); + assert!(is_cloud_url("azure://container/key.png")); + } + + #[test] + fn r2_and_minio_use_s3_scheme() { + // R2/MinIO have no dedicated scheme; they are addressed via s3://. + assert!(is_cloud_url("s3://my-r2-bucket/inputs/img.png")); + } + + #[test] + fn rejects_non_cloud_schemes() { + assert!(!is_cloud_url("https://example.com/x.png")); + assert!(!is_cloud_url("http://example.com/x.png")); + assert!(!is_cloud_url("data:image/png;base64,AAAA")); + assert!(!is_cloud_url("/local/path.png")); + assert!(!is_cloud_url("file.png")); + } + + #[test] + fn builds_s3_store_and_extracts_path() { + // No real network call — building the store and parsing the path is offline. + let (_store, path) = + build_store_for_url("s3://my-bucket/inputs/img.png").expect("should build s3 store"); + assert_eq!(path.as_ref(), "inputs/img.png"); + } + + #[test] + fn builds_gcs_store_and_extracts_path() { + let (_store, path) = + build_store_for_url("gs://my-bucket/a/b/c.jpg").expect("should build gcs store"); + assert_eq!(path.as_ref(), "a/b/c.jpg"); + } + + #[test] + fn rejects_unparseable_url() { + let err = build_store_for_url("s3://").err(); + assert!(err.is_some(), "empty bucket url should error"); + } + + #[test] + fn fetch_and_write_temp_roundtrips() { + use object_store::ObjectStoreExt as _; + use object_store::memory::InMemory; + + let runtime = pyo3_async_runtimes::tokio::get_runtime(); + let store = InMemory::new(); + let obj_path = ObjectPath::from("inputs/hello.txt"); + runtime + .block_on(store.put(&obj_path, b"hello cloud".to_vec().into())) + .expect("seed object"); + + let body = runtime + .block_on(fetch_bytes( + &store, + &obj_path, + "s3://bucket/inputs/hello.txt", + )) + .expect("fetch should succeed"); + let temp = write_temp("s3://bucket/inputs/hello.txt", "hello.txt", &body) + .expect("write should succeed"); + + let contents = std::fs::read(&temp).expect("temp file should exist"); + assert_eq!(contents, b"hello cloud"); + assert!( + temp.to_string_lossy().ends_with("hello.txt"), + "temp file should preserve filename suffix, got {temp:?}" + ); + std::fs::remove_file(&temp).ok(); + } + + #[test] + fn parallel_download_preserves_order() { + // Seed three objects in a single InMemory store, then fetch them + // concurrently via try_join_all and assert ordering. This is the + // ordering contract for download_many_to_temp (which builds a fresh + // store per URL from env, so it cannot reuse this single InMemory). + use object_store::ObjectStoreExt as _; + use object_store::memory::InMemory; + + let runtime = pyo3_async_runtimes::tokio::get_runtime(); + let store = InMemory::new(); + for (i, name) in ["a.txt", "b.txt", "c.txt"].iter().enumerate() { + let p = ObjectPath::from(format!("in/{name}")); + runtime + .block_on(store.put(&p, format!("body-{i}").into_bytes().into())) + .expect("seed"); + } + let paths = [ + ObjectPath::from("in/a.txt"), + ObjectPath::from("in/b.txt"), + ObjectPath::from("in/c.txt"), + ]; + let bodies = runtime + .block_on(async { + let fs = paths + .iter() + .map(|p| fetch_bytes(&store, p, "s3://bucket/x")); + futures::future::try_join_all(fs).await + }) + .expect("all fetches succeed"); + assert_eq!(bodies[0].as_ref(), b"body-0"); + assert_eq!(bodies[1].as_ref(), b"body-1"); + assert_eq!(bodies[2].as_ref(), b"body-2"); + } + + #[test] + fn download_many_empty_is_ok() { + assert!(download_many_to_temp(&[]).expect("empty ok").is_empty()); + } + + #[test] + fn sanitize_suffix_strips_path_and_nulls() { + assert_eq!(sanitize_suffix("a/b/c.png"), "-c.png"); + assert_eq!(sanitize_suffix(""), ""); + } +} diff --git a/crates/coglet-python/src/input.rs b/crates/coglet-python/src/input.rs index 0c23336495..9547bdd068 100644 --- a/crates/coglet-python/src/input.rs +++ b/crates/coglet-python/src/input.rs @@ -79,8 +79,9 @@ pub fn prepare_input( func: &Bound<'_, PyAny>, ) -> PyResult { let fields = classify_fields(py, func)?; + let mut cleanup_paths = download_cloud_inputs_into_dict(py, input, &fields)?; coerce_url_strings(py, input, &fields)?; - let cleanup_paths = download_url_paths_into_dict(py, input)?; + cleanup_paths.extend(download_url_paths_into_dict(py, input)?); Ok(PreparedInput::new(input.clone().unbind(), cleanup_paths)) } @@ -233,6 +234,104 @@ fn classify_fields(py: Python<'_>, func: &Bound<'_, PyAny>) -> PyResult }, + /// payload[key] is a list; the URL is at list index `idx`. + ListItem { key: Py, idx: usize }, +} + +/// Download cloud-storage URLs (`s3://`, `gs://`, `az://`) for File/Path fields +/// to temp files, replacing each dict value with the local path string. +/// +/// Runs BEFORE `coerce_url_strings` so that by the time the rest of the +/// pipeline runs, cloud inputs look like ordinary local paths. All cloud +/// downloads happen CONCURRENTLY (see `cloud::download_many_to_temp`), matching +/// the parallelism of the existing Python ThreadPoolExecutor HTTP path. +/// +/// Returns the Python `pathlib.Path` objects (one per downloaded file) for +/// cleanup on drop. +fn download_cloud_inputs_into_dict( + py: Python<'_>, + payload: &Bound<'_, PyDict>, + fields: &FieldClassification, +) -> PyResult> { + // Phase 1: collect all cloud URLs and where each one lives. + let mut urls: Vec = Vec::new(); + let mut slots: Vec = Vec::new(); + + for (key, value) in payload.iter() { + let key_str: String = key.extract().unwrap_or_default(); + let is_file_or_path = + fields.file_fields.contains(&key_str) || fields.path_fields.contains(&key_str); + if !is_file_or_path { + continue; + } + + if let Ok(s) = value.extract::() { + if crate::cloud::is_cloud_url(&s) { + urls.push(s); + slots.push(CloudSlot::Single { + key: key.clone().unbind(), + }); + } + } else if let Ok(list) = value.extract::>() { + for (idx, item) in list.iter().enumerate() { + if let Ok(s) = item.extract::() + && crate::cloud::is_cloud_url(&s) + { + urls.push(s); + slots.push(CloudSlot::ListItem { + key: key.clone().unbind(), + idx, + }); + } + } + } + } + + if urls.is_empty() { + return Ok(Vec::new()); + } + + // Phase 2: download all URLs concurrently, releasing the GIL meanwhile. + // Returned paths are in the SAME ORDER as `urls`/`slots`. + let local_paths = py + .detach(|| crate::cloud::download_many_to_temp(&urls)) + .map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("cloud download failed: {e}")) + })?; + + // Phase 3: splice local paths back into the payload and build cleanup list. + let pathlib = py.import("pathlib")?; + let path_class = pathlib.getattr("Path")?; + let mut cleanup: Vec = Vec::with_capacity(local_paths.len()); + + for (slot, local) in slots.iter().zip(local_paths.iter()) { + let local_str = local.to_string_lossy().to_string(); + cleanup.push(path_class.call1((&local_str,))?.unbind()); + let new_val = pyo3::types::PyString::new(py, &local_str); + match slot { + CloudSlot::Single { key } => { + payload.set_item(key.bind(py), &new_val)?; + } + CloudSlot::ListItem { key, idx } => { + let key = key.bind(py); + let item = payload.get_item(key)?.ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err(format!( + "Input key '{key}' disappeared during cloud download" + )) + })?; + let list = item.extract::>()?; + list.set_item(*idx, &new_val)?; + } + } + } + Ok(cleanup) +} + /// Coerce URL string values in the input dict to the appropriate cog types. /// /// After `json.loads()`, all values are plain Python types. URL strings diff --git a/crates/coglet-python/src/lib.rs b/crates/coglet-python/src/lib.rs index 35130783ba..c05de99384 100644 --- a/crates/coglet-python/src/lib.rs +++ b/crates/coglet-python/src/lib.rs @@ -2,6 +2,7 @@ mod audit; mod cancel; +mod cloud; mod input; mod log_writer; mod metric_scope; diff --git a/docs/http.md b/docs/http.md index 81fa737d13..f0f8948f23 100644 --- a/docs/http.md +++ b/docs/http.md @@ -273,6 +273,77 @@ This produces a random identifier that is 26 ASCII characters long. 'wjx3whax6rf4vphkegkhcvpv6a' ``` +## File inputs + +A model's `run` function can accept file input through +[`cog.Path`](python.md#cogpath) or [`cog.File`](python.md#cogfile-deprecated) +parameters. In the request body, these inputs are passed as URLs: + +```http +POST /predictions HTTP/1.1 +Content-Type: application/json; charset=utf-8 + +{ + "input": {"image": "https://example.com/image.jpg"} +} +``` + +The following URL schemes are supported: + +- `data:` — a [data URL](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs) + carrying the file contents inline (for example `data:image/png;base64,...`). +- `http:` / `https:` — the server downloads the file over HTTP(S). +- `s3:`, `gs:`, `az:` — the server downloads the file directly from cloud + object storage (see [Cloud storage inputs](#cloud-storage-inputs) below). + +Before your `run` function is called, the server resolves each file input to a +local file on disk and passes that path (or file handle) to your model. Your +model never sees the original URL. + +### Cloud storage inputs + +File inputs can reference objects in cloud object storage. Cog downloads them +natively before the prediction runs. Supported schemes: + +- Amazon S3: `s3://bucket/key` +- Google Cloud Storage: `gs://bucket/key` +- Azure Blob Storage: `az://container/key` (also accepts `azure://container/key`) + +```http +POST /predictions HTTP/1.1 +Content-Type: application/json; charset=utf-8 + +{ + "input": {"image": "s3://my-bucket/inputs/cat.png"} +} +``` + +Credentials are read from the standard provider environment variables in the +model container — there is no Cog-specific configuration: + +- Amazon S3: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, + `AWS_SESSION_TOKEN`, `AWS_REGION`. +- Google Cloud Storage: `GOOGLE_SERVICE_ACCOUNT` or + `GOOGLE_APPLICATION_CREDENTIALS`. +- Azure Blob Storage: `AZURE_STORAGE_ACCOUNT_NAME`, + `AZURE_STORAGE_ACCOUNT_KEY`. + +S3-compatible stores such as **Cloudflare R2** and **MinIO** use the `s3://` +scheme together with a custom endpoint set via `AWS_ENDPOINT_URL`. For +Cloudflare R2, also set `AWS_REGION=auto`. + +> [!NOTE] +> Cloud storage is currently supported for **inputs only** (downloading). +> Uploading file outputs to cloud storage is a separate feature planned for a +> future release; for now, file outputs are returned as data URLs or uploaded +> over HTTP (see [File uploads](#file-uploads)). + +> [!NOTE] +> Cloud `cog.File` inputs are downloaded eagerly (the file is fetched before +> your `run` function is called), whereas `http:`/`https:` `cog.File` inputs are +> streamed lazily on first read. `cog.Path` inputs are always downloaded +> eagerly regardless of scheme. + ## File uploads A model's `run` function can produce file output by yielding or returning diff --git a/docs/llms.txt b/docs/llms.txt index 2044b2f6d7..e89b9ad960 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -1578,6 +1578,77 @@ This produces a random identifier that is 26 ASCII characters long. 'wjx3whax6rf4vphkegkhcvpv6a' ``` +## File inputs + +A model's `run` function can accept file input through +[`cog.Path`](python.md#cogpath) or [`cog.File`](python.md#cogfile-deprecated) +parameters. In the request body, these inputs are passed as URLs: + +```http +POST /predictions HTTP/1.1 +Content-Type: application/json; charset=utf-8 + +{ + "input": {"image": "https://example.com/image.jpg"} +} +``` + +The following URL schemes are supported: + +- `data:` — a [data URL](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs) + carrying the file contents inline (for example `data:image/png;base64,...`). +- `http:` / `https:` — the server downloads the file over HTTP(S). +- `s3:`, `gs:`, `az:` — the server downloads the file directly from cloud + object storage (see [Cloud storage inputs](#cloud-storage-inputs) below). + +Before your `run` function is called, the server resolves each file input to a +local file on disk and passes that path (or file handle) to your model. Your +model never sees the original URL. + +### Cloud storage inputs + +File inputs can reference objects in cloud object storage. Cog downloads them +natively before the prediction runs. Supported schemes: + +- Amazon S3: `s3://bucket/key` +- Google Cloud Storage: `gs://bucket/key` +- Azure Blob Storage: `az://container/key` (also accepts `azure://container/key`) + +```http +POST /predictions HTTP/1.1 +Content-Type: application/json; charset=utf-8 + +{ + "input": {"image": "s3://my-bucket/inputs/cat.png"} +} +``` + +Credentials are read from the standard provider environment variables in the +model container — there is no Cog-specific configuration: + +- Amazon S3: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, + `AWS_SESSION_TOKEN`, `AWS_REGION`. +- Google Cloud Storage: `GOOGLE_SERVICE_ACCOUNT` or + `GOOGLE_APPLICATION_CREDENTIALS`. +- Azure Blob Storage: `AZURE_STORAGE_ACCOUNT_NAME`, + `AZURE_STORAGE_ACCOUNT_KEY`. + +S3-compatible stores such as **Cloudflare R2** and **MinIO** use the `s3://` +scheme together with a custom endpoint set via `AWS_ENDPOINT_URL`. For +Cloudflare R2, also set `AWS_REGION=auto`. + +> [!NOTE] +> Cloud storage is currently supported for **inputs only** (downloading). +> Uploading file outputs to cloud storage is a separate feature planned for a +> future release; for now, file outputs are returned as data URLs or uploaded +> over HTTP (see [File uploads](#file-uploads)). + +> [!NOTE] +> Cloud `cog.File` inputs are downloaded eagerly (the file is fetched before +> your `run` function is called), whereas `http:`/`https:` `cog.File` inputs are +> streamed lazily on first read. `cog.Path` inputs are always downloaded +> eagerly regardless of scheme. + ## File uploads A model's `run` function can produce file output by yielding or returning