diff --git a/Cargo.lock b/Cargo.lock index 41392ce..cbf6c47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -995,7 +995,7 @@ dependencies = [ "serde", "serde_json", "tokio", - "tower 0.4.13", + "tower", "url", "uuid", "yaml-rust", @@ -1852,26 +1852,6 @@ dependencies = [ "sha2", ] -[[package]] -name = "pin-project" -version = "1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "pin-project-lite" version = "0.2.16" @@ -2150,7 +2130,7 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", - "tower 0.5.2", + "tower", "tower-http", "tower-service", "url", @@ -2795,23 +2775,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcc842091f2def52017664b53082ecbbeb5c7731092bad69d2c63050401dfd64" -[[package]] -name = "tower" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" -dependencies = [ - "futures-core", - "futures-util", - "pin-project", - "pin-project-lite", - "tokio", - "tokio-util", - "tower-layer", - "tower-service", - "tracing", -] - [[package]] name = "tower" version = "0.5.2" @@ -2823,8 +2786,10 @@ dependencies = [ "pin-project-lite", "sync_wrapper", "tokio", + "tokio-util", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -2840,7 +2805,7 @@ dependencies = [ "http-body", "iri-string", "pin-project-lite", - "tower 0.5.2", + "tower", "tower-layer", "tower-service", ] diff --git a/Cargo.toml b/Cargo.toml index 1f90973..060bcfa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,4 +34,4 @@ diesel = { version = "2.2", features = ["postgres", "r2d2", "serde_json"] } diesel_migrations = "2.2" uuid = { version = "1.18", features = ["v4"] } base64 = "0.22" -tower = { version = "0.4", features = ["limit", "util", "buffer"] } +tower = { version = "0.5", features = ["limit", "util", "buffer"] } diff --git a/src/sync/collections.rs b/src/sync/collections.rs index 2231a58..3b0a306 100644 --- a/src/sync/collections.rs +++ b/src/sync/collections.rs @@ -1,6 +1,7 @@ use super::{get_json, request}; use crate::models::{self, CollectionNew, CollectionVersionNew}; use crate::schema::collection_versions; +use crate::sync::utils::RateLimitedHttpService; use actix_web::web; use anyhow::{Context, Result}; use diesel::pg::upsert::excluded; @@ -11,13 +12,10 @@ use diesel::{ }; use futures::future::try_join_all; use log::info; -use reqwest::{Client, Request}; use serde_json::Value; use std::collections::{HashMap, HashSet}; use tokio::fs::File; use tokio::io::AsyncWriteExt; -use tower::buffer::Buffer; -use tower::limit::{ConcurrencyLimit, RateLimit}; #[allow(dead_code)] #[derive(Debug, Clone)] @@ -30,10 +28,7 @@ pub struct CollectionData { pub metadata: Value, } -pub async fn get_version( - url: String, - service: Buffer>, Request>, -) -> Result { +pub async fn get_version(url: String, service: RateLimitedHttpService) -> Result { let (service, resp) = request(url, service).await; let status = resp.status().as_str().to_string(); let json_response = resp.json::().await.unwrap(); @@ -69,34 +64,33 @@ pub async fn get_version( pub async fn sync_collections( pool: web::Data>>, response: &Value, - service: Buffer>, Request>, ) -> Result<()> { let results = response.as_object().unwrap()["data"].as_array().unwrap(); let galaxy_url = dotenv::var("GALAXY_URL").unwrap_or("https://galaxy.ansible.com/".to_string()); - let collection_version_futures: Vec<_> = results - .iter() - .map(|v| { - let nspace = v["collection_version"]["namespace"] - .as_str() - .unwrap() - .to_string(); - let n = v["collection_version"]["name"] - .as_str() - .unwrap() - .to_string(); - let vs = v["collection_version"]["version"] - .as_str() - .unwrap() - .to_string(); - get_version( - format!( - "{}api/v3/plugin/ansible/content/published/collections/index/{}/{}/versions/{}/", - galaxy_url, nspace, n, vs - ), - service.clone(), - ) - }) - .collect(); + let client = reqwest::Client::new(); + let mut collection_version_futures = Vec::new(); + for v in results.iter() { + let nspace = v["collection_version"]["namespace"] + .as_str() + .unwrap() + .to_string(); + let n = v["collection_version"]["name"] + .as_str() + .unwrap() + .to_string(); + let vs = v["collection_version"]["version"] + .as_str() + .unwrap() + .to_string(); + let service = crate::sync::utils::build_service(client.clone()); + collection_version_futures.push(get_version( + format!( + "{}api/v3/plugin/ansible/content/published/collections/index/{}/{}/versions/{}/", + galaxy_url, nspace, n, vs + ), + service, + )); + } let cversions = try_join_all(collection_version_futures) .await .context("Failed to join collection versions futures")?; @@ -167,7 +161,7 @@ pub async fn sync_collections( } pub async fn fetch_versions( - mut service: Buffer>, Request>, + mut service: RateLimitedHttpService, url: &Value, ) -> Result> { let mut versions: Vec = Vec::new(); @@ -186,19 +180,19 @@ pub async fn fetch_versions( .unwrap(); // Downloading - let collection_version_futures: Vec<_> = results - .iter() - .map(|v| { - get_version( - format!( - "{}{}", - galaxy_url.strip_suffix('/').unwrap(), - v["href"].as_str().unwrap() - ), - service.clone(), - ) - }) - .collect(); + let client = reqwest::Client::new(); + let mut collection_version_futures = Vec::new(); + for v in results.iter() { + let service = crate::sync::utils::build_service(client.clone()); + collection_version_futures.push(get_version( + format!( + "{}{}", + galaxy_url.strip_suffix('/').unwrap(), + v["href"].as_str().unwrap() + ), + service, + )); + } let cversions = try_join_all(collection_version_futures) .await .context("Failed to join collection versions futures")?; @@ -231,7 +225,6 @@ pub async fn fetch_versions( pub async fn process_collection_data( pool: web::Data>>, - service: Buffer>, Request>, data: Vec>, fetch_dependencies: bool, ) -> Result<()> { @@ -326,10 +319,12 @@ pub async fn process_collection_data( info!("Fetching collection dependencies"); let dependencies: Vec<_> = deps.keys().map(|url| get_json(url)).collect(); let deps_json = try_join_all(dependencies).await.unwrap(); - let to_fetch: Vec<_> = deps_json - .iter() - .map(|c| fetch_versions(service.clone(), &c["versions_url"])) - .collect(); + let client = reqwest::Client::new(); + let mut to_fetch = Vec::new(); + for c in deps_json.iter() { + let service = crate::sync::utils::build_service(client.clone()); + to_fetch.push(fetch_versions(service, &c["versions_url"])); + } to_process = try_join_all(to_fetch).await.unwrap(); } else { break; diff --git a/src/sync/common.rs b/src/sync/common.rs index dfdbdf7..c1ab30f 100644 --- a/src/sync/common.rs +++ b/src/sync/common.rs @@ -101,13 +101,16 @@ pub async fn process_requirements( } else { info!("Syncing collections"); let client = reqwest::Client::new(); - let service = build_service(client.clone()); - let to_fetch: Vec<_> = responses - .iter() - .map(|c| fetch_versions(service.clone(), &c["versions_url"])) - .collect(); + + // Create separate services for each fetch operation + let mut to_fetch = Vec::new(); + for c in responses.iter() { + let service = build_service(client.clone()); + to_fetch.push(fetch_versions(service, &c["versions_url"])); + } let data = try_join_all(to_fetch).await?; - process_collection_data(pool.clone(), service.clone(), data, true).await? + + process_collection_data(pool.clone(), data, true).await? }; } } @@ -141,8 +144,7 @@ pub async fn mirror_content( } else { panic!("Invalid content type!") }; - let client = reqwest::Client::new(); - let service = build_service(client.clone()); + loop { let results = get_json(target.as_str()).await.unwrap(); if content_type == "roles" { @@ -157,7 +159,7 @@ pub async fn mirror_content( .context("Failed to join next_link")? } else if content_type == "collections" { info!("Syncing collections"); - sync_collections(pool.clone(), &results, service.clone()).await?; + sync_collections(pool.clone(), &results).await?; if results.as_object().unwrap()["links"]["next"] .as_str() .is_none() diff --git a/src/sync/utils.rs b/src/sync/utils.rs index aa067ff..86292ad 100644 --- a/src/sync/utils.rs +++ b/src/sync/utils.rs @@ -2,13 +2,46 @@ use anyhow::{Context, Result}; use log::warn; use reqwest::{Client, Request, Response}; use serde_json::Value; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context as TaskContext, Poll}; use std::time::Duration; use tokio::fs::File; use tokio::io::AsyncWriteExt; use tokio::time; -use tower::buffer::Buffer; -use tower::limit::{ConcurrencyLimit, RateLimit}; -use tower::{Service, ServiceExt}; +use tower::util::BoxService; +use tower::{Service, ServiceBuilder, ServiceExt}; + +// Type alias for our rate-limited service +pub type RateLimitedHttpService = + BoxService>; + +// Wrapper around reqwest::Client to implement the Service trait +#[derive(Clone)] +pub struct HttpService { + client: Client, +} + +impl HttpService { + fn new(client: Client) -> Self { + Self { client } + } +} + +impl Service for HttpService { + type Response = Response; + type Error = reqwest::Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut TaskContext<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let client = self.client.clone(); + Box::pin(async move { client.execute(req).await }) + } +} pub async fn download_tar(filename: &str, response: reqwest::Response) -> Result<()> { let mut file = match File::create(filename).await { @@ -62,7 +95,7 @@ pub async fn get_json(url: &str) -> Result { Ok(values) } -pub fn build_service(client: Client) -> Buffer>, Request> { +pub fn build_service(client: Client) -> RateLimitedHttpService { let buffer = dotenv::var("GROOT_BUFFER") .unwrap_or("100".to_string()) .as_str() @@ -77,26 +110,29 @@ pub fn build_service(client: Client) -> Buffer() .unwrap(); - tower::ServiceBuilder::new() - .buffer(buffer) - .concurrency_limit(limit) + + let http_service = HttpService::new(client); + + ServiceBuilder::new() .rate_limit(total_req, Duration::from_secs(1)) - .service(client.clone()) + .concurrency_limit(limit) + .buffer(buffer) + .service(http_service) + .boxed() } pub async fn request( url: String, - mut service: Buffer>, Request>, -) -> ( - Buffer>, Request>, - Response, -) { - let client = reqwest::Client::new(); - let http_request = client.get(url).build().unwrap(); - let mut is_ready = service.ready().await.is_ok(); - while !is_ready { - is_ready = service.ready().await.is_ok(); - } - let response = service.call(http_request).await.unwrap(); + mut service: RateLimitedHttpService, +) -> (RateLimitedHttpService, Response) { + let client = Client::new(); + let request = client.get(&url).build().unwrap(); + + // Wait for the service to be ready + let ready_service = service.ready().await.unwrap(); + + // Make the request + let response = ready_service.call(request).await.unwrap(); + (service, response) }