diff --git a/rust/numaflow-core/src/mapper/map.rs b/rust/numaflow-core/src/mapper/map.rs index 3fb4c4af5c..4e0b42842b 100644 --- a/rust/numaflow-core/src/mapper/map.rs +++ b/rust/numaflow-core/src/mapper/map.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; use std::sync::Arc; -use std::sync::atomic::Ordering; use std::time::{Duration, Instant}; use chrono::{DateTime, Utc}; @@ -17,7 +16,7 @@ use tracing::{error, info, warn}; use crate::config::pipeline::map::MapMode; use crate::config::{get_vertex_name, is_mono_vertex}; use crate::error::{self, Error}; -use crate::message::{AckHandle, Message, MessageID, Offset}; +use crate::message::{Message, MessageHandle, MessageID, Offset}; use crate::metadata::Metadata; use crate::shared::grpc::prost_timestamp_from_utc; use crate::tracker::Tracker; @@ -139,10 +138,10 @@ impl MapHandle { /// handle. pub(crate) async fn streaming_map( mut self, - input_stream: ReceiverStream, + input_stream: ReceiverStream, cln_token: CancellationToken, bypass_router: Option, - ) -> error::Result<(ReceiverStream, JoinHandle>)> { + ) -> error::Result<(ReceiverStream, JoinHandle>)> { let (output_tx, output_rx) = mpsc::channel(self.batch_size); let (error_tx, error_rx) = mpsc::channel(self.batch_size); let semaphore = Arc::new(Semaphore::new(self.concurrency)); @@ -223,7 +222,7 @@ impl MapHandle { /// Each message is processed in a separate spawned task. async fn process_concurrent_messages( &mut self, - input_stream: ReceiverStream, + input_stream: ReceiverStream, mut ctx: ConcurrentMapContext, ) -> error::Result<()> { let mut input_stream = input_stream; @@ -252,8 +251,8 @@ impl MapHandle { // if there are errors then we need to drain the stream and nack if self.shutting_down_on_err { - warn!(offset = ?read_msg.offset, error = ?self.final_result, "Map component is shutting down because of an error, not accepting the message"); - read_msg.ack_handle.as_ref().expect("ack handle should be present").is_failed.store(true, Ordering::Relaxed); + warn!(offset = ?read_msg.message().offset, error = ?self.final_result, "Map component is shutting down because of an error, not accepting the message"); + read_msg.mark_failed(self.final_result.as_ref().unwrap_err()); } else { let permit = Arc::clone(&ctx.semaphore).acquire_owned() .await.map_err(|e| Error::Mapper(format!("failed to acquire semaphore: {e}")))?; @@ -264,7 +263,7 @@ impl MapHandle { MapUnaryTask { mapper: mapper.clone(), permit, - message: read_msg, + msg_handle: read_msg, shared_ctx: Arc::clone(&ctx.shared_ctx), } .spawn(); @@ -273,7 +272,7 @@ impl MapHandle { MapStreamTask { mapper: mapper.clone(), permit, - message: read_msg, + msg_handle: read_msg, shared_ctx: Arc::clone(&ctx.shared_ctx), } .spawn(); @@ -290,7 +289,7 @@ impl MapHandle { /// Messages are collected into batches and processed synchronously. async fn process_batch_messages( &mut self, - input_stream: ReceiverStream, + input_stream: ReceiverStream, ctx: BatchMapContext, ) { let timeout_duration = self.read_timeout; @@ -300,27 +299,20 @@ impl MapHandle { let is_mono_vertex = is_mono_vertex(); // we don't need to tokio spawn here because, unlike unary and stream, batch is // a blocking operation, and we process one batch at a time. - while let Some(batch) = chunked_stream.next().await { + while let Some(read_batch) = chunked_stream.next().await { // if there are errors then we need to drain the stream and nack if self.shutting_down_on_err { - for msg in batch { - warn!(offset = ?msg.offset, error = ?self.final_result, "Map component is shutting down because of an error, not accepting the message"); - msg.ack_handle - .as_ref() - .expect("ack handle should be present") - .is_failed - .store(true, Ordering::Relaxed); + for read_msg in read_batch { + warn!(offset = ?read_msg.message().offset, error = ?self.final_result, "Map component is shutting down because of an error, not accepting the message"); + read_msg.mark_failed(self.final_result.as_ref().unwrap_err()); } continue; } - let ack_handles: Vec>> = - batch.iter().map(|msg| msg.ack_handle.clone()).collect(); - - if !batch.is_empty() + if !read_batch.is_empty() && let Err(e) = (MapBatchTask { mapper: ctx.batch_mapper.clone(), - batch, + msg_handles: read_batch, output_tx: ctx.output_tx.clone(), tracker: self.tracker.clone(), bypass_router: ctx.bypass_router.clone(), @@ -331,15 +323,6 @@ impl MapHandle { .await { error!(?e, "error received while performing batch map operation"); - // if there is an error, discard all the messages in the tracker and - // return the error. - for ack_handle in ack_handles { - ack_handle - .as_ref() - .expect("ack handle should be present") - .is_failed - .store(true, Ordering::Relaxed); - } ctx.cln_token.cancel(); self.shutting_down_on_err = true; self.final_result = Err(e); @@ -350,7 +333,7 @@ impl MapHandle { /// Shared context for concurrent map tasks. pub(in crate::mapper) struct SharedMapTaskContext { - pub output_tx: mpsc::Sender, + pub output_tx: mpsc::Sender, pub error_tx: mpsc::Sender, pub tracker: Tracker, pub bypass_router: Option, @@ -369,7 +352,7 @@ struct ConcurrentMapContext { /// Context for batch map processing. struct BatchMapContext { - output_tx: mpsc::Sender, + output_tx: mpsc::Sender, cln_token: CancellationToken, bypass_router: Option, batch_mapper: UserDefinedBatchMap, @@ -462,7 +445,6 @@ pub(crate) struct ParentMessageInfo { /// one response for a single request. pub(crate) current_index: i32, pub(crate) metadata: Option>, - pub(crate) ack_handle: Option>, } impl From<&Message> for ParentMessageInfo { @@ -475,7 +457,6 @@ impl From<&Message> for ParentMessageInfo { start_time: Instant::now(), current_index: 0, metadata: message.metadata.clone(), - ack_handle: message.ack_handle.clone(), } } } @@ -532,7 +513,6 @@ impl From> for Message { } Some(Arc::new(metadata)) }, - ack_handle: value.1.ack_handle.clone(), } } } @@ -578,19 +558,19 @@ async fn create_response_stream( #[cfg(test)] mod tests { - use std::time::Duration; - use super::*; use crate::mapper::test_utils::MapperTestHandle; use crate::message::ReadAck; use crate::{ Result, - message::{MessageID, Offset, StringOffset}, + message::{MessageHandle, MessageID, Offset, StringOffset}, shared::grpc::create_rpc_channel, }; + use futures::StreamExt; use numaflow::shared::ServerExtras; use numaflow::{batchmap, map, mapstream}; use numaflow_pb::clients::map::map_client::MapClient; + use std::time::Duration; use tempfile::TempDir; use tokio::sync::mpsc::Sender; use tokio::sync::oneshot; @@ -656,6 +636,7 @@ mod tests { }, ..Default::default() }; + let msg_handle: MessageHandle = message.into(); let (output_tx, mut output_rx) = mpsc::channel(10); @@ -681,7 +662,7 @@ mod tests { MapUnaryTask { mapper: unary_mapper, permit, - message, + msg_handle, shared_ctx, } .spawn(); @@ -690,7 +671,7 @@ mod tests { assert!(error_rx.recv().await.is_none()); let mapped_message = output_rx.recv().await.unwrap(); - assert_eq!(mapped_message.value, "hello"); + assert_eq!(mapped_message.message().value, "hello"); // we need to drop the mapper, because if there are any in-flight requests // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 @@ -763,7 +744,7 @@ mod tests { ) .await?; - let messages = vec![ + let messages: Vec = vec![ Message { typ: Default::default(), keys: Arc::from(vec!["first".into()]), @@ -778,7 +759,8 @@ mod tests { index: 0, }, ..Default::default() - }, + } + .into(), Message { typ: Default::default(), keys: Arc::from(vec!["second".into()]), @@ -793,7 +775,8 @@ mod tests { index: 1, }, ..Default::default() - }, + } + .into(), ]; let (input_tx, input_rx) = mpsc::channel(10); @@ -810,10 +793,10 @@ mod tests { let mut output_rx = output_stream.into_inner(); let mapped_message1 = output_rx.recv().await.unwrap(); - assert_eq!(mapped_message1.value, "hello"); + assert_eq!(mapped_message1.message().value, "hello"); let mapped_message2 = output_rx.recv().await.unwrap(); - assert_eq!(mapped_message2.value, "world"); + assert_eq!(mapped_message2.message().value, "world"); shutdown_tx .send(()) @@ -887,7 +870,7 @@ mod tests { ) .await?; - let message = Message { + let message: MessageHandle = Message { typ: Default::default(), keys: Arc::from(vec!["first".into()]), tags: None, @@ -901,7 +884,8 @@ mod tests { index: 0, }, ..Default::default() - }; + } + .into(); let (input_tx, input_rx) = mpsc::channel(10); let input_stream = ReceiverStream::new(input_rx); @@ -922,7 +906,7 @@ mod tests { // convert the bytes value to string and compare let values: Vec = responses .iter() - .map(|r| String::from_utf8(Vec::from(r.value.clone())).unwrap()) + .map(|r| String::from_utf8(Vec::from(r.message().value.clone())).unwrap()) .collect(); assert_eq!(values, vec!["test", "map", "stream"]); @@ -989,7 +973,6 @@ mod tests { let mut ack_rxs = vec![]; // send 10 requests to the mapper for i in 0..10 { - let (ack_tx, ack_rx) = oneshot::channel(); let message = Message { typ: Default::default(), keys: Arc::from(vec![format!("key_{}", i)]), @@ -1003,10 +986,11 @@ mod tests { offset: i.to_string().into(), index: i, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), ..Default::default() }; - input_tx.send(message).await.unwrap(); + let (ack_tx, ack_rx) = oneshot::channel(); + let read_message = MessageHandle::new(message, ack_tx); + input_tx.send(read_message).await.unwrap(); ack_rxs.push(ack_rx); } @@ -1084,8 +1068,7 @@ mod tests { .await?; let (ack_tx1, ack_rx1) = oneshot::channel(); - let (ack_tx2, ack_rx2) = oneshot::channel(); - let messages = vec![ + let msg1 = MessageHandle::new( Message { typ: Default::default(), keys: Arc::from(vec!["first".into()]), @@ -1099,9 +1082,12 @@ mod tests { offset: "0".to_string().into(), index: 0, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx1))), ..Default::default() }, + ack_tx1, + ); + let (ack_tx2, ack_rx2) = oneshot::channel(); + let msg2 = MessageHandle::new( Message { typ: Default::default(), keys: Arc::from(vec!["second".into()]), @@ -1115,16 +1101,17 @@ mod tests { offset: "1".to_string().into(), index: 1, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx2))), ..Default::default() }, - ]; + ack_tx2, + ); + let read_messages = vec![msg1, msg2]; let (input_tx, input_rx) = mpsc::channel(10); let input_stream = ReceiverStream::new(input_rx); - for message in messages { - input_tx.send(message).await.unwrap(); + for read_message in read_messages { + input_tx.send(read_message).await.unwrap(); } let (_output_stream, map_handle) = mapper @@ -1210,7 +1197,6 @@ mod tests { let mut ack_rxs = vec![]; // send 10 requests to the mapper for i in 0..10 { - let (ack_tx, ack_rx) = oneshot::channel(); let message = Message { typ: Default::default(), keys: Arc::from(vec![format!("key_{}", i)]), @@ -1224,11 +1210,12 @@ mod tests { offset: i.to_string().into(), index: i, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), ..Default::default() }; + let (ack_tx, ack_rx) = oneshot::channel(); + let read_message = MessageHandle::new(message, ack_tx); ack_rxs.push(ack_rx); - input_tx.send(message).await.unwrap(); + input_tx.send(read_message).await.unwrap(); } cln_token.cancelled().await; @@ -1255,8 +1242,8 @@ mod tests { Ok(()) } - fn create_default_msg(i: i32, ack_tx: oneshot::Sender) -> Message { - Message { + fn create_default_msg(i: i32) -> (MessageHandle, oneshot::Receiver) { + let message = Message { typ: Default::default(), keys: Arc::from(vec![format!("key_{}", i)]), tags: None, @@ -1269,9 +1256,10 @@ mod tests { offset: i.to_string().into(), index: i, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), ..Default::default() - } + }; + let (ack_tx, ack_rx) = oneshot::channel(); + (MessageHandle::new(message, ack_tx), ack_rx) } #[tokio::test(flavor = "multi_thread")] @@ -1304,8 +1292,7 @@ mod tests { let mut ack_rxs = vec![]; // send 10 requests to the mapper for i in 0..10 { - let (ack_tx, ack_rx) = oneshot::channel(); - let message = create_default_msg(i, ack_tx); + let (message, ack_rx) = create_default_msg(i); input_tx.send(message).await.unwrap(); ack_rxs.push(ack_rx); } @@ -1346,8 +1333,7 @@ mod tests { .await; let (ack_tx1, ack_rx1) = oneshot::channel(); - let (ack_tx2, ack_rx2) = oneshot::channel(); - let messages = vec![ + let msg1 = MessageHandle::new( Message { typ: Default::default(), keys: Arc::from(vec!["first".into()]), @@ -1361,9 +1347,12 @@ mod tests { offset: "0".to_string().into(), index: 0, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx1))), ..Default::default() }, + ack_tx1, + ); + let (ack_tx2, ack_rx2) = oneshot::channel(); + let msg2 = MessageHandle::new( Message { typ: Default::default(), keys: Arc::from(vec!["second".into()]), @@ -1377,10 +1366,11 @@ mod tests { offset: "1".to_string().into(), index: 1, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx2))), ..Default::default() }, - ]; + ack_tx2, + ); + let messages = vec![msg1, msg2]; let (input_tx, input_rx) = mpsc::channel(10); let input_stream = ReceiverStream::new(input_rx); @@ -1438,8 +1428,7 @@ mod tests { let mut ack_rxs = vec![]; // send 10 requests to the mapper for i in 0..10 { - let (ack_tx, ack_rx) = oneshot::channel(); - let message = create_default_msg(i, ack_tx); + let (message, ack_rx) = create_default_msg(i); ack_rxs.push(ack_rx); input_tx .send(message) @@ -1620,7 +1609,6 @@ mod tests { start_time: std::time::Instant::now(), current_index: 0, metadata: None, - ack_handle: None, }; update_udf_write_metric(true, &msg_info, 5); @@ -1654,7 +1642,6 @@ mod tests { start_time: std::time::Instant::now(), current_index: 0, metadata: None, - ack_handle: None, }; update_udf_write_metric(false, &msg_info, 5); diff --git a/rust/numaflow-core/src/mapper/map/batch.rs b/rust/numaflow-core/src/mapper/map/batch.rs index b4bd426469..99a150a6d0 100644 --- a/rust/numaflow-core/src/mapper/map/batch.rs +++ b/rust/numaflow-core/src/mapper/map/batch.rs @@ -5,9 +5,10 @@ use super::{ use crate::config::is_mono_vertex; use crate::config::pipeline::VERTEX_TYPE_MAP_UDF; use crate::error::{Error, Result}; -use crate::message::Message; +use crate::message::{Message, MessageHandle}; use crate::monovertex::bypass_router::MvtxBypassRouter; use crate::tracker::Tracker; +use crate::{mark_failed, mark_success}; use numaflow_pb::clients::map::{self, MapRequest, MapResponse, map_client::MapClient}; use std::collections::HashMap; use std::sync::Arc; @@ -42,8 +43,8 @@ pub(in crate::mapper) struct BatchSenderMapState { /// MapBatchTask encapsulates all the context needed to execute a batch map operation. pub(in crate::mapper) struct MapBatchTask { pub mapper: UserDefinedBatchMap, - pub batch: Vec, - pub output_tx: mpsc::Sender, + pub msg_handles: Vec, + pub output_tx: mpsc::Sender, pub tracker: Tracker, pub bypass_router: Option, pub is_mono_vertex: bool, @@ -55,10 +56,18 @@ impl MapBatchTask { /// Returns an error if any message in the batch fails to be processed. pub async fn execute(self) -> Result<()> { // Store parent message info for each message before sending to UDF - let parent_infos: Vec = self.batch.iter().map(|m| m.into()).collect(); + let parent_infos: Vec = self + .msg_handles + .iter() + .map(|rm| rm.message().into()) + .collect(); // Convert Messages to MapRequests - let requests: Vec = self.batch.into_iter().map(|m| m.into()).collect(); + let requests: Vec = self + .msg_handles + .iter() + .map(|rm| rm.message().clone().into()) + .collect(); // Update read metrics for each request for _ in &requests { @@ -68,7 +77,10 @@ impl MapBatchTask { // Call the UDF and get results directly let results = self.mapper.batch(requests, self.cln_token).await; - for (result, parent_info) in results.into_iter().zip(parent_infos) { + for (result, (msg_handle, parent_info)) in results + .into_iter() + .zip(self.msg_handles.into_iter().zip(parent_infos)) + { match result { Ok(results) => { // Convert raw results to Messages using parent info @@ -96,29 +108,44 @@ impl MapBatchTask { .await?; for mapped_message in mapped_messages { - let bypassed = if let Some(ref bypass_router) = self.bypass_router { - bypass_router - .try_bypass(mapped_message.clone()) + // Each downstream handle shares the original ack tracking — ACK is + // deferred until all mapped messages are written to ISB/sink. + let downstream_handle = msg_handle.with_message(mapped_message); + + // Try to bypass the message. If bypassed, try_bypass takes ownership and returns None. + // If not bypassed, it returns Some(downstream_handle) for us to send downstream. + let downstream_handle = if let Some(ref bypass_router) = self.bypass_router + { + match bypass_router + .try_bypass(downstream_handle) .await .expect("failed to send message to bypass channel") + { + Some(msg) => msg, + None => continue, // Message was bypassed, move to next + } } else { - false + downstream_handle }; - if !bypassed { - self.output_tx - .send(mapped_message) - .await - .expect("failed to send response"); - } + self.output_tx + .send(downstream_handle) + .await + .expect("failed to send response"); } + + // Decrement the original ref_count for this message now that all downstream + // handles have been created and sent. + mark_success!(msg_handle); } Err(e) => { error!(err=?e, "failed to map message"); + mark_failed!(msg_handle, &e); return Err(e); } } } + Ok(()) } } diff --git a/rust/numaflow-core/src/mapper/map/stream.rs b/rust/numaflow-core/src/mapper/map/stream.rs index 170a5a2f6e..f3367610f1 100644 --- a/rust/numaflow-core/src/mapper/map/stream.rs +++ b/rust/numaflow-core/src/mapper/map/stream.rs @@ -1,11 +1,11 @@ use std::collections::HashMap; use std::sync::Arc; use std::sync::Mutex; -use std::sync::atomic::Ordering; use crate::config::is_mono_vertex; use crate::error::{Error, Result}; -use crate::message::Message; +use crate::message::{Message, MessageHandle}; +use crate::{mark_failed, mark_success}; use numaflow_pb::clients::map::{self, MapRequest, MapResponse, map_client::MapClient}; use tokio::sync::{OwnedSemaphorePermit, mpsc}; use tokio_stream::StreamExt; @@ -44,7 +44,7 @@ pub(in crate::mapper) struct StreamSenderMapState { pub(in crate::mapper) struct MapStreamTask { pub mapper: UserDefinedStreamMap, pub permit: OwnedSemaphorePermit, - pub message: Message, + pub msg_handle: MessageHandle, pub shared_ctx: Arc, } @@ -64,9 +64,9 @@ impl MapStreamTask { // Store parent message info before sending to UDF // parent_info contains offset, so we don't need to clone it separately - let mut parent_info: ParentMessageInfo = (&self.message).into(); + let mut parent_info: ParentMessageInfo = self.msg_handle.message().into(); - let request: MapRequest = self.message.into(); + let request: MapRequest = self.msg_handle.message().clone().into(); update_udf_read_metric(self.shared_ctx.is_mono_vertex); // Call the UDF and get receiver for raw results @@ -106,53 +106,49 @@ impl MapStreamTask { .await .expect("failed to update tracker"); - let bypassed = + // Each downstream handle shares the original ack tracking — ACK is + // deferred until all mapped messages are written to ISB/sink. + let msg_handle = self.msg_handle.with_message(mapped_message); + + // Try to bypass the message. If bypassed, try_bypass takes ownership and returns None. + // If not bypassed, it returns Some(msg_handle) for us to send downstream. + let msg_handle = if let Some(ref bypass_router) = self.shared_ctx.bypass_router { - bypass_router - .try_bypass(mapped_message.clone()) + match bypass_router + .try_bypass(msg_handle) .await .expect("failed to send message to bypass channel") + { + Some(msg) => msg, + None => continue, // Message was bypassed, move to next + } } else { - false + msg_handle }; - if !bypassed { - self.shared_ctx - .output_tx - .send(mapped_message) - .await - .expect("failed to send response"); - } + self.shared_ctx + .output_tx + .send(msg_handle) + .await + .expect("failed to send response"); } } Some(Err(e)) => { error!(?e, "failed to map message"); - parent_info - .ack_handle - .as_ref() - .expect("ack handle should be present") - .is_failed - .store(true, Ordering::Relaxed); + mark_failed!(self.msg_handle, &e); let _ = self.shared_ctx.error_tx.send(e).await; return; } None => { - // Channel closed. If no results were ever sent (current_index == 0), - // this means the UDF stream may have closed unexpectedly (e.g., panic or gRPC - // stream error where the sender was dropped without delivering an error). - // Mark the message as failed so that it gets nacked. - if parent_info.current_index == 0 { - parent_info - .ack_handle - .as_ref() - .expect("ack handle should be present") - .is_failed - .store(true, Ordering::Relaxed); - } + // Channel closed — stream ended cleanly (e.g., UDF returned empty results or + // finished after sending all results). Fall through to mark_success below. break; } } } + + // Decrement the original ref_count now that we've accounted for all downstream messages. + mark_success!(self.msg_handle); } } diff --git a/rust/numaflow-core/src/mapper/map/unary.rs b/rust/numaflow-core/src/mapper/map/unary.rs index 66e0fdbc22..00d662dfe7 100644 --- a/rust/numaflow-core/src/mapper/map/unary.rs +++ b/rust/numaflow-core/src/mapper/map/unary.rs @@ -1,11 +1,11 @@ use std::collections::HashMap; use std::sync::Arc; use std::sync::Mutex; -use std::sync::atomic::Ordering; use crate::config::is_mono_vertex; use crate::error::{Error, Result}; -use crate::message::Message; +use crate::message::{Message, MessageHandle}; +use crate::{mark_failed, mark_success}; use numaflow_pb::clients::map::{self, MapRequest, MapResponse, map_client::MapClient}; use tokio::sync::{OwnedSemaphorePermit, mpsc, oneshot}; use tokio_stream::StreamExt; @@ -45,7 +45,7 @@ pub(in crate::mapper) struct MapUnaryTask { /// Permit to achieve structured concurrency by ensuring we do not exceed the concurrency limit /// and all the tasks are cleaned up when the component is shutting down. pub permit: OwnedSemaphorePermit, - pub message: Message, + pub msg_handle: MessageHandle, pub shared_ctx: Arc, } @@ -65,9 +65,9 @@ impl MapUnaryTask { // Store parent message info before sending to UDF // parent_info contains offset, so we don't need to clone it separately - let parent_info: ParentMessageInfo = (&self.message).into(); + let parent_info: ParentMessageInfo = self.msg_handle.message().into(); - let request: MapRequest = self.message.into(); + let request: MapRequest = self.msg_handle.message().clone().into(); update_udf_read_metric(self.shared_ctx.is_mono_vertex); // Call the UDF and get raw results @@ -79,12 +79,7 @@ impl MapUnaryTask { Ok(results) => results, Err(e) => { error!(?e, offset = ?parent_info.offset, "failed to map message"); - parent_info - .ack_handle - .as_ref() - .expect("ack handle should be present") - .is_failed - .store(true, Ordering::Relaxed); + mark_failed!(self.msg_handle, &e); let _ = self.shared_ctx.error_tx.send(e).await; return; } @@ -115,25 +110,40 @@ impl MapUnaryTask { .await .expect("failed to update tracker"); - // Send messages downstream for mapped_message in mapped_messages { - let bypassed = if let Some(ref bypass_router) = self.shared_ctx.bypass_router { - bypass_router - .try_bypass(mapped_message.clone()) + // Each downstream handle shares the original ack tracking — ACK is deferred until + // all mapped messages are written to ISB/sink. + let msg_handle = self.msg_handle.with_message(mapped_message); + + // Try to bypass the message. If bypassed, try_bypass takes ownership and returns None. + // If not bypassed, it returns Some(msg_handle) for us to send downstream. + let msg_handle = if let Some(ref bypass_router) = self.shared_ctx.bypass_router { + match bypass_router + .try_bypass(msg_handle) .await .expect("failed to send message to bypass channel") + { + Some(msg) => msg, + None => { + // Message was bypassed (already acked by bypass_router), move to next. + continue; + } + } } else { - false + msg_handle }; - if !bypassed { - self.shared_ctx - .output_tx - .send(mapped_message) - .await - .expect("failed to send response"); - } + self.shared_ctx + .output_tx + .send(msg_handle) + .await + .expect("failed to send response"); } + + // Decrement the original ref_count now that we've accounted for all downstream messages. + // The original msg_handle held ref_count=1; mark_success brings it to 0 contribution, + // and the downstream handles will each call mark_success when written to ISB/sink. + mark_success!(self.msg_handle); } } diff --git a/rust/numaflow-core/src/message.rs b/rust/numaflow-core/src/message.rs index b42efb2d75..0a3b13e68a 100644 --- a/rust/numaflow-core/src/message.rs +++ b/rust/numaflow-core/src/message.rs @@ -6,13 +6,80 @@ //! The spawned task exposes an [AckHandle] which implements [Drop] trait. As the message is processed, //! and cloned (e.g., flat-map), the reference counted Handle will keep track and eventually will be //! dropped once the all the copies of [Message] are dropped. This trigger the final ack/nak. +//! +//! [MessageHandle] is a wrapper around [Message] that holds the [AckHandle]. It should be explicitly +//! marked as success after processing via [MessageHandle::mark_success], so that it can be acked. +//! By default, it will be nacked if not marked as success. +//! +//! ## Macros +//! - [mark_success!] - Marks a single [MessageHandle] as success (consumes the handle). +//! - [mark_success_batch!] - Marks a batch of [MessageHandle]s as success (consumes the batch). +//! - [mark_failed!] - Marks a single [MessageHandle] as failed with a reason (consumes the handle). + +/// Marks a single [MessageHandle] as success (consumes the handle). +/// +/// # Example +/// ```ignore +/// mark_success!(message_handle); +/// ``` +#[macro_export] +macro_rules! mark_success { + ($msg:expr) => {{ + $msg.mark_success(); + }}; +} + +/// Marks a single [MessageHandle] as failed (consumes the handle), recording the failure reason. +/// +/// # Example +/// ```ignore +/// mark_failed!(message_handle, error); +/// ``` +#[macro_export] +macro_rules! mark_failed { + ($msg:expr, $err:expr) => {{ + $msg.mark_failed($err); + }}; +} + +/// Marks a batch of [MessageHandle]s as success (consumes the batch). +/// +/// # Example +/// ```ignore +/// mark_success_batch!(message_handles); +/// ``` +#[macro_export] +macro_rules! mark_success_batch { + ($batch:expr) => {{ + for msg in $batch { + msg.mark_success(); + } + }}; +} + +/// Marks a batch of [MessageHandle]s as failed (consumes the batch), recording the failure reason. +/// +/// # Example +/// ```ignore +/// mark_failed_batch!(message_handles, error); +/// ``` +#[macro_export] +macro_rules! mark_failed_batch { + ($batch:expr, $err:expr) => {{ + for msg in $batch { + msg.mark_failed($err); + } + }}; +} use crate::Error; use std::cmp::{Ordering, PartialEq}; use std::collections::HashMap; use std::fmt; use std::sync::Arc; -use std::sync::atomic::AtomicBool; +use std::sync::OnceLock; +use std::sync::atomic::AtomicUsize; +use tracing::{error, warn}; use crate::metadata::Metadata; use crate::shared::grpc::prost_timestamp_from_utc; @@ -53,41 +120,147 @@ pub(crate) struct Message { /// is_late is used to indicate if the message is a late data. Late data is data that arrives /// after the watermark has passed. This is set only at source. pub(crate) is_late: bool, - /// ack_handle is used to send the ack/nak to the source. It is optional because it is not used - /// when the message is originated from the WAL (reduce vertex). - pub(crate) ack_handle: Option>, } -/// AckHandle is used to send the ack/nak to the source but it is reference counted and makes sure -/// when it is dropped, we send the ack/nak to the source. +/// AckHandle is used to send the ack/nak to the source. It uses a reference count to track +/// the number of active references. When dropped, it sends NAK if ref_count != 0 (not all +/// references were marked as success), or ACK if ref_count == 0 (all references were marked +/// as success). If [MessageHandle::mark_failed] was called, the failure reason is logged at +/// NAK time. #[derive(Debug)] -pub(crate) struct AckHandle { - pub(crate) ack_handle: Option>, - pub(crate) is_failed: AtomicBool, +struct AckHandle { + sender: Option>, + /// Reference count to track active references. Starts at 1 when created. + /// Incremented when cloned (via MessageHandle::clone or with_message), decremented when + /// mark_success is called. On drop: NAK if ref_count != 0, ACK if ref_count == 0. + ref_count: AtomicUsize, + /// Set by mark_failed to record why the message is being nacked. + /// Uses OnceLock to capture only the first failure reason without locking overhead. + failure_reason: OnceLock, } impl AckHandle { - /// create a new AckHandle for a message. - pub(crate) fn new(ack_handle: oneshot::Sender) -> Self { + fn new(sender: oneshot::Sender) -> Self { Self { - ack_handle: Some(ack_handle), - is_failed: AtomicBool::new(false), + sender: Some(sender), + ref_count: AtomicUsize::new(1), + failure_reason: OnceLock::new(), } } } impl Drop for AckHandle { fn drop(&mut self) { - if let Some(ack_handle) = self.ack_handle.take() { - if self.is_failed.load(std::sync::atomic::Ordering::Relaxed) { - ack_handle.send(ReadAck::Nak).expect("Failed to send nak"); + if let Some(sender) = self.sender.take() { + // NAK if ref_count is not 0 (meaning not all references were marked as success) + let ack = if self.ref_count.load(std::sync::atomic::Ordering::Relaxed) != 0 { + if let Some(reason) = self.failure_reason.get() { + error!(reason = reason.as_str(), "message nacked due to failure"); + } + ReadAck::Nak } else { - ack_handle.send(ReadAck::Ack).expect("Failed to send ack"); + ReadAck::Ack + }; + if sender.send(ack).is_err() { + warn!( + "ack/nak receiver exited before receiving ack/nak; the listener task may have exited prematurely" + ); } } } } +/// MessageHandle is the message read from the ISB/Source. MessageHandle should be explicitly marked as +/// success after processing via [MessageHandle::mark_success], so that it can be acked/nacked. +/// By default, it will be nacked if not marked as success. +/// +/// MessageHandle implements Clone - when cloned, the reference count is incremented, so each +/// clone must be marked as success for the original message to be ACK'd. +#[derive(Debug)] +pub(crate) struct MessageHandle { + pub(crate) message: Message, + ack_handle: Arc, +} + +impl Clone for MessageHandle { + fn clone(&self) -> Self { + self.ack_handle + .ref_count + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + Self { + message: self.message.clone(), + ack_handle: Arc::clone(&self.ack_handle), + } + } +} + +impl MessageHandle { + /// Creates a new MessageHandle with the given ack sender. + /// The caller owns the corresponding receiver and awaits it to know when to ack/nak upstream. + pub(crate) fn new(message: Message, ack_tx: oneshot::Sender) -> Self { + Self { + message, + ack_handle: Arc::new(AckHandle::new(ack_tx)), + } + } + + /// Mark the message as successfully processed (consumes the handle). + /// This decrements the reference count. When all references are marked as success + /// (ref_count reaches 0), the message will be ACK'd when the AckHandle is dropped. + pub(crate) fn mark_success(self) { + self.ack_handle + .ref_count + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } + + /// Mark the message as failed (consumes the handle), recording the reason it will be nacked. + /// ref_count is not decremented, so the message will be NAK'd when the AckHandle is dropped. + /// The error is logged at NAK time. + pub(crate) fn mark_failed(self, reason: impl fmt::Display) { + let _ = self.ack_handle.failure_reason.set(reason.to_string()); + } + + /// Creates a new MessageHandle with a different message but sharing this handle's ack tracking. + /// The ref_count is incremented so both handles must be marked as success for ACK. + /// Use this when fanning out one input message into multiple downstream messages (e.g., map). + pub(crate) fn with_message(&self, message: Message) -> Self { + self.ack_handle + .ref_count + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + Self { + message, + ack_handle: Arc::clone(&self.ack_handle), + } + } + + /// Get a reference to the inner message. + pub(crate) fn message(&self) -> &Message { + &self.message + } + + /// Get a mutable reference to the inner message. + pub(crate) fn message_mut(&mut self) -> &mut Message { + &mut self.message + } +} + +/// Converts a [Message] into a [MessageHandle] without ack tracking. +/// This is used for newly created messages (e.g., reduce output, watermark messages) +/// that don't need to be acked back to a source. +/// The ack handle is a no-op - mark_success() and drop will have no effect. +impl From for MessageHandle { + fn from(message: Message) -> Self { + Self { + message, + ack_handle: Arc::new(AckHandle { + sender: None, + ref_count: AtomicUsize::new(0), // Already "success" state + failure_reason: OnceLock::new(), + }), + } + } +} + /// Type of the [Message]. #[derive(Debug, Clone, Default, PartialEq, Eq)] pub(crate) enum MessageType { @@ -141,7 +314,6 @@ impl Default for Message { metadata: None, typ: Default::default(), is_late: false, - ack_handle: None, } } } @@ -497,4 +669,74 @@ mod tests { let offset_string = Offset::String(string_offset); assert_eq!(format!("{}", offset_string), "42-1"); } + + #[tokio::test] + async fn test_read_message_nack_by_default() { + // When MessageHandle is dropped without calling mark_success, it should NAK + let message = Message::default(); + let (ack_tx, ack_rx) = oneshot::channel(); + let read_message = MessageHandle::new(message, ack_tx); + + // Drop the MessageHandle without calling mark_success + drop(read_message); + + // Should receive NAK + let result = ack_rx.await.unwrap(); + assert_eq!(result, ReadAck::Nak); + } + + #[tokio::test] + async fn test_read_message_ack_on_mark_success() { + // When mark_success is called, it should ACK + let message = Message::default(); + let (ack_tx, ack_rx) = oneshot::channel(); + let read_message = MessageHandle::new(message, ack_tx); + + // Mark as success (consumes the handle) + read_message.mark_success(); + + // Should receive ACK + let result = ack_rx.await.unwrap(); + assert_eq!(result, ReadAck::Ack); + } + + #[tokio::test] + async fn test_read_message_clone_all_success() { + // When cloning, all clones must be marked as success for ACK + let message = Message::default(); + let (ack_tx, ack_rx) = oneshot::channel(); + let read_message = MessageHandle::new(message, ack_tx); + + // Clone (simulating message split in map/transformer) + let cloned = read_message.clone(); + + // Mark both as success (each call consumes the handle) + read_message.mark_success(); + cloned.mark_success(); + + // Should receive ACK since all were marked as success + let result = ack_rx.await.unwrap(); + assert_eq!(result, ReadAck::Ack); + } + + #[tokio::test] + async fn test_read_message_clone_partial_success() { + // When only some clones are marked as success, it should NAK + let message = Message::default(); + let (ack_tx, ack_rx) = oneshot::channel(); + let read_message = MessageHandle::new(message, ack_tx); + + // Clone + let _cloned = read_message.clone(); + + // Only mark the original as success (consumes it), not the clone + read_message.mark_success(); + // Explicitly drop the clone before awaiting — _cloned must be dropped to + // trigger the NAK send, otherwise awaiting the channel would deadlock. + drop(_cloned); + + // Should receive NAK since not all were marked as success + let result = ack_rx.await.unwrap(); + assert_eq!(result, ReadAck::Nak); + } } diff --git a/rust/numaflow-core/src/monovertex/bypass_router.rs b/rust/numaflow-core/src/monovertex/bypass_router.rs index 0cf569daee..1e3a3ae5dc 100644 --- a/rust/numaflow-core/src/monovertex/bypass_router.rs +++ b/rust/numaflow-core/src/monovertex/bypass_router.rs @@ -44,11 +44,12 @@ use crate::config::is_mono_vertex; use crate::config::monovertex::BypassConditions; use crate::error; use crate::error::Error; -use crate::message::Message; +use crate::mark_success; +use crate::mark_success_batch; +use crate::message::{Message, MessageHandle}; use crate::shared::forward::should_forward; use crate::sinker::sink::{SinkWriter, send_drop_metrics}; use numaflow_models::models::ForwardConditions; -use std::sync::atomic::Ordering; use std::time::Duration; use tokio::pin; use tokio::sync::mpsc; @@ -62,9 +63,9 @@ use tracing::info; /// appropriate sink based on the bypass condition. #[derive(Debug, Clone)] pub enum MessageToSink { - Primary(Message), - Fallback(Message), - OnSuccess(Message), + Primary(MessageHandle), + Fallback(MessageHandle), + OnSuccess(MessageHandle), } /// Returns a reference to the inner message wrapped by the enum. @@ -73,7 +74,7 @@ impl MessageToSink { match self { MessageToSink::Primary(msg) | MessageToSink::Fallback(msg) - | MessageToSink::OnSuccess(msg) => msg, + | MessageToSink::OnSuccess(msg) => msg.message(), } } } @@ -153,33 +154,38 @@ impl MvtxBypassRouter { /// Checks if the message should be bypassed based on the bypass conditions and routes it to /// the appropriate sink. - /// Returns a boolean wrapped in a Result. Returns Ok(true) if the message was bypassed, - /// Ok(false) if the message was not bypassed, and Err if the messages supposed to be bypassed - /// but there was an error in sending the message to the bypass channel. - pub(crate) async fn try_bypass(&self, msg: Message) -> error::Result { + /// + /// Takes ownership of the MessageHandle. Returns: + /// - `Ok(None)` if the message was bypassed (sent to bypass channel, caller should not process further) + /// - `Ok(Some(message_handle))` if the message was not bypassed (caller should continue processing) + /// - `Err` if the message was supposed to be bypassed but there was an error sending to bypass channel + pub(crate) async fn try_bypass( + &self, + message_handle: MessageHandle, + ) -> error::Result> { for bypass_condition in self.bypass_conditions.clone() { match bypass_condition { BypassConditionState::Sink(sink) => { - if should_forward(msg.tags.clone(), Some(sink)) { - return self.route(MessageToSink::Primary(msg)).await.map(|_| true); + if should_forward(message_handle.message().tags.clone(), Some(sink)) { + self.route(MessageToSink::Primary(message_handle)).await?; + return Ok(None); } } BypassConditionState::Fallback(fallback) => { - if should_forward(msg.tags.clone(), Some(fallback)) { - return self.route(MessageToSink::Fallback(msg)).await.map(|_| true); + if should_forward(message_handle.message().tags.clone(), Some(fallback)) { + self.route(MessageToSink::Fallback(message_handle)).await?; + return Ok(None); } } BypassConditionState::OnSuccess(on_success) => { - if should_forward(msg.tags.clone(), Some(on_success)) { - return self - .route(MessageToSink::OnSuccess(msg)) - .await - .map(|_| true); + if should_forward(message_handle.message().tags.clone(), Some(on_success)) { + self.route(MessageToSink::OnSuccess(message_handle)).await?; + return Ok(None); } } } } - Ok(false) + Ok(Some(message_handle)) } /// [route] method calls the bypass_tx send method. @@ -247,28 +253,32 @@ impl BypassRouterReceiver { // Main processing loop while let Some(batch) = chunk_stream.next().await { // we are in shutting down mode, we will not be writing to any sink, - // mark the messages as failed, and on Drop they will be nack'ed. + // messages will be nack'ed. if self.shutting_down_on_err { - for msg in &batch { - match msg { - MessageToSink::Primary(msg) - | MessageToSink::Fallback(msg) - | MessageToSink::OnSuccess(msg) => msg.ack_handle.as_ref(), - } - .expect("ack handle should be present") - .is_failed - .store(true, Ordering::Relaxed); + for msg in batch { + let handle = match msg { + MessageToSink::Primary(h) + | MessageToSink::Fallback(h) + | MessageToSink::OnSuccess(h) => h, + }; + handle.mark_failed(self.final_result.as_ref().unwrap_err()); } continue; } - // filter out messages that are marked for drop - let original_len = batch.len(); - let batch: Vec<_> = batch - .into_iter() - .filter(|msg| !msg.inner().dropped()) - .collect(); - let dropped_message_count = original_len - batch.len(); + // Separate messages marked for drop from those to be forwarded. + // Dropped messages must be ACK'd — drop is a successful outcome. + let (to_drop, batch): (Vec<_>, Vec<_>) = + batch.into_iter().partition(|msg| msg.inner().dropped()); + let dropped_message_count = to_drop.len(); + for msg in to_drop { + let handle = match msg { + MessageToSink::Primary(h) + | MessageToSink::Fallback(h) + | MessageToSink::OnSuccess(h) => h, + }; + mark_success!(handle); + } // skip if all were dropped if batch.is_empty() { @@ -278,28 +288,27 @@ impl BypassRouterReceiver { let mut primary_messages: Vec = vec![]; let mut fallback_messages: Vec = vec![]; let mut on_success_messages: Vec = vec![]; - let mut ack_handles = vec![]; + let mut msg_handles: Vec = vec![]; - // Convert MessageToSink to Message and create respective - // vectors of Messages and ack handles + // Convert MessageToSink to Message and collect MessageHandles for acking for msg in batch { match msg { - MessageToSink::Primary(msg) => { - ack_handles.push(msg.ack_handle.clone()); - primary_messages.push(msg); + MessageToSink::Primary(read_msg) => { + primary_messages.push(read_msg.message().clone()); + msg_handles.push(read_msg); } - MessageToSink::Fallback(msg) => { - ack_handles.push(msg.ack_handle.clone()); - fallback_messages.push(msg) + MessageToSink::Fallback(read_msg) => { + fallback_messages.push(read_msg.message().clone()); + msg_handles.push(read_msg); } - MessageToSink::OnSuccess(msg) => { - ack_handles.push(msg.ack_handle.clone()); - on_success_messages.push(msg) + MessageToSink::OnSuccess(read_msg) => { + on_success_messages.push(read_msg.message().clone()); + msg_handles.push(read_msg); } } } - if let Err(e) = self + match self .perform_write( primary_messages, fallback_messages, @@ -308,19 +317,19 @@ impl BypassRouterReceiver { ) .await { - error!(?e, "Error writing to sink, initiating shutdown."); - cln_token.cancel(); - - for ack_handle in ack_handles { - ack_handle - .as_ref() - .expect("ack handle should be present") - .is_failed - .store(true, Ordering::Relaxed); + Ok(()) => { + // Successfully written to sink, ACK all messages + mark_success_batch!(msg_handles); + } + Err(e) => { + error!(?e, "Error writing to sink, initiating shutdown."); + cln_token.cancel(); + for msg in msg_handles { + msg.mark_failed(&e); + } + self.final_result = Err(e); + self.shutting_down_on_err = true; } - - self.final_result = Err(e); - self.shutting_down_on_err = true; } send_drop_metrics(is_mono_vertex(), dropped_message_count); } @@ -332,7 +341,7 @@ impl BypassRouterReceiver { } /// Call the different write methods on the sink writer for different types of message vectors. - /// If any of the write methods fail, return immediately, ack_handles are marked as failed in the + /// If any of the write methods fail, return immediately, msg_handles are marked as failed in the /// caller. async fn perform_write( &mut self, @@ -357,7 +366,7 @@ impl BypassRouterReceiver { #[cfg(test)] mod tests { use super::*; - use crate::message::{AckHandle, IntOffset, MessageID, Offset, ReadAck}; + use crate::message::{IntOffset, MessageID, Offset, ReadAck}; use crate::shared::grpc::create_rpc_channel; use crate::sinker::sink::{SinkClientType, SinkWriterBuilder}; use crate::tracker::Tracker; @@ -373,20 +382,9 @@ mod tests { use tokio::sync::mpsc::Receiver; use tokio::sync::oneshot; - /// Creates a test message with optional tags and ack handle. - fn create_test_message( - id: i32, - tags: Option>, - with_ack_handle: bool, - ) -> (Message, Option>) { - let (ack_handle, ack_rx) = if with_ack_handle { - let (tx, rx) = oneshot::channel(); - (Some(Arc::new(AckHandle::new(tx))), Some(rx)) - } else { - (None, None) - }; - - let msg = Message { + /// Creates a test message with optional tags. + fn create_test_message(id: i32, tags: Option>) -> Message { + Message { typ: Default::default(), keys: Arc::from(vec![format!("key_{}", id)]), tags: tags.map(Arc::from), @@ -402,22 +400,31 @@ mod tests { headers: Arc::new(HashMap::new()), metadata: None, is_late: false, - ack_handle, - }; - (msg, ack_rx) + } + } + + /// Creates a test MessageHandle with optional tags and ack handle. + fn create_test_read_message( + id: i32, + tags: Option>, + ) -> (MessageHandle, oneshot::Receiver) { + let msg = create_test_message(id, tags); + let (ack_tx, ack_rx) = oneshot::channel(); + (MessageHandle::new(msg, ack_tx), ack_rx) } // ==================== MessageToSink Tests ==================== #[test] fn test_message_to_sink_inner() { - let (msg, _) = create_test_message(1, None, false); + let msg = create_test_message(1, None); + let (read_msg, _ack_rx) = create_test_read_message(1, None); // Test all variants return the correct inner message for msg_to_sink in [ - MessageToSink::Primary(msg.clone()), - MessageToSink::Fallback(msg.clone()), - MessageToSink::OnSuccess(msg.clone()), + MessageToSink::Primary(read_msg.clone()), + MessageToSink::Fallback(read_msg.clone()), + MessageToSink::OnSuccess(read_msg.clone()), ] { assert_eq!(msg_to_sink.inner().id, msg.id); assert_eq!(msg_to_sink.inner().value, msg.value); @@ -511,11 +518,9 @@ mod tests { .await; let mut ack_rxs = vec![]; - let messages: Vec = (0..10) + let messages: Vec = (0..10) .map(|i| { - let (ack_tx, ack_rx) = oneshot::channel(); - ack_rxs.push(ack_rx); - Message { + let message = Message { typ: Default::default(), keys: Arc::from(vec![format!("key_{}", i)]), tags: Some(Arc::from(vec![DROP.to_string()])), @@ -528,9 +533,12 @@ mod tests { offset: format!("offset_{}", i).into(), index: i as i32, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), ..Default::default() - } + }; + let (ack_tx, ack_rx) = oneshot::channel(); + let handle = MessageHandle::new(message, ack_tx); + ack_rxs.push(ack_rx); + handle }) .collect(); diff --git a/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs b/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs index 81b960b428..f1abdc0a31 100644 --- a/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs +++ b/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs @@ -421,7 +421,6 @@ mod simple_buffer_tests { headers: Arc::new(HashMap::new()), metadata: None, is_late: false, - ack_handle: None, }; writer.write(msg).await.expect("write should succeed"); } diff --git a/rust/numaflow-core/src/pipeline/isb/jetstream/js_reader.rs b/rust/numaflow-core/src/pipeline/isb/jetstream/js_reader.rs index 3bbb00010b..036d805925 100644 --- a/rust/numaflow-core/src/pipeline/isb/jetstream/js_reader.rs +++ b/rust/numaflow-core/src/pipeline/isb/jetstream/js_reader.rs @@ -89,7 +89,6 @@ impl JSWrappedMessage { watermark: None, metadata: header.metadata.map(|m| Arc::new(Metadata::from(m))), is_late: message_info.is_late, - ack_handle: None, }) } } diff --git a/rust/numaflow-core/src/pipeline/isb/reader.rs b/rust/numaflow-core/src/pipeline/isb/reader.rs index b51b00dcce..ecdaea44f9 100644 --- a/rust/numaflow-core/src/pipeline/isb/reader.rs +++ b/rust/numaflow-core/src/pipeline/isb/reader.rs @@ -16,7 +16,11 @@ use crate::config::get_vertex_name; use crate::config::pipeline::VertexType::ReduceUDF; use crate::config::pipeline::isb::{BufferReaderConfig, ISBConfig, Stream}; use crate::error::Error; -use crate::message::{AckHandle, IntOffset, Message, MessageType, Offset, ReadAck}; +#[cfg(test)] +use crate::mark_success; +#[cfg(test)] +use crate::mark_success_batch; +use crate::message::{IntOffset, Message, MessageHandle, MessageType, Offset, ReadAck}; use crate::metrics::{ PIPELINE_PARTITION_NAME_LABEL, jetstream_isb_error_metrics_labels, jetstream_isb_metrics_labels, pipeline_metric_labels, pipeline_metrics, @@ -114,14 +118,14 @@ impl ISBReaderOrchestrator { }) } - /// Streaming read from ISB, returns a ReceiverStream and a JoinHandle for monitoring errors. + /// Streaming read from ISB, returns a ReceiverStream of MessageHandle and a JoinHandle for monitoring errors. pub(crate) async fn streaming_read( mut self, cancel: CancellationToken, - ) -> Result<(ReceiverStream, JoinHandle>)> { + ) -> Result<(ReceiverStream, JoinHandle>)> { let max_ack_pending = self.cfg.max_ack_pending; let batch_size = std::cmp::min(self.batch_size, max_ack_pending); - let (tx, rx) = mpsc::channel(batch_size); + let (tx, rx) = mpsc::channel::(batch_size); let handle: JoinHandle> = tokio::spawn(async move { let semaphore = Arc::new(Semaphore::new(max_ack_pending)); @@ -313,7 +317,7 @@ impl ISBReaderOrchestrator { vertex_type: &str, partition: u16, idle_wmb: WMB, - tx: &mpsc::Sender, + tx: &mpsc::Sender, ) -> Result<()> { if vertex_type != ReduceUDF.as_str() { return Ok(()); @@ -333,7 +337,9 @@ impl ISBReaderOrchestrator { }, ..Default::default() }; - tx.send(msg).await.map_err(|_| { + + let read_msg: MessageHandle = msg.into(); + tx.send(read_msg).await.map_err(|_| { Error::ISB(crate::pipeline::isb::error::ISBError::Other( "Failed to send wmb message to channel".to_string(), )) @@ -444,7 +450,7 @@ impl ISBReaderOrchestrator { async fn handle_idle_watermarks( &mut self, batch_is_empty: bool, - tx: &mpsc::Sender, + tx: &mpsc::Sender, ) -> Result<()> { if batch_is_empty { if let Some(wm) = self.watermark.as_mut() { @@ -482,7 +488,7 @@ impl ISBReaderOrchestrator { async fn process_message_batch( &mut self, mut batch: Vec, - tx: &mpsc::Sender, + tx: &mpsc::Sender, permits: &mut tokio::sync::OwnedSemaphorePermit, cancel: CancellationToken, processing_start: Instant, @@ -504,11 +510,11 @@ impl ISBReaderOrchestrator { Self::publish_read_metrics(&self.metric_labels, &message); let (ack_tx, ack_rx) = oneshot::channel(); - message.ack_handle = Some(Arc::new(AckHandle::new(ack_tx))); + let read_message = MessageHandle::new(message, ack_tx); // Start message tracking and WIP loop self.start_message_tracking( - &message, + read_message.message(), permits.split(1).expect("Failed to split permit"), cancel.clone(), processing_start, @@ -516,8 +522,7 @@ impl ISBReaderOrchestrator { ) .await?; - // Send message to channel - if tx.send(message).await.is_err() { + if tx.send(read_message).await.is_err() { break; } } @@ -800,7 +805,8 @@ mod tests { 10, "Expected 10 messages from the jetstream reader" ); - drop(buffer); + // Mark all messages as success to ACK them + mark_success_batch!(buffer); reader_cancel_token.cancel(); js_reader_task.await.unwrap().unwrap(); @@ -901,9 +907,11 @@ mod tests { } for _ in 0..5 { - let Some(_val) = js_reader_rx.next().await else { + let Some(val) = js_reader_rx.next().await else { break; }; + // Mark as success to ACK the message + mark_success!(val); } // wait until the tracker becomes empty, don't wait more than 1 second @@ -1059,14 +1067,21 @@ mod tests { // Verify the message was correctly decompressed assert_eq!( - received_message.value.len(), + received_message.message().value.len(), 0, "Empty payload should remain empty after compression/decompression" ); - assert_eq!(received_message.keys.as_ref(), &["empty_key".to_string()]); - assert_eq!(received_message.offset.to_string(), offset.to_string()); + assert_eq!( + received_message.message().keys.as_ref(), + &["empty_key".to_string()] + ); + assert_eq!( + received_message.message().offset.to_string(), + offset.to_string() + ); - drop(received_message); + // Mark as success to ACK the message (this consumes the message) + mark_success!(received_message); reader_cancel_token.cancel(); js_reader_task.await.unwrap().unwrap(); @@ -1373,14 +1388,12 @@ mod tests { #[cfg(test)] mod simplebuffer_tests { use super::*; - use crate::pipeline::isb::simplebuffer::{SimpleBufferAdapter, WithSimpleBuffer}; - use numaflow_testing::simplebuffer::SimpleBuffer; - use std::collections::HashMap; - use std::sync::atomic::Ordering; - use crate::message::MessageID; + use crate::pipeline::isb::simplebuffer::{SimpleBufferAdapter, WithSimpleBuffer}; use bytes::Bytes; use chrono::Utc; + use numaflow_testing::simplebuffer::SimpleBuffer; + use std::collections::HashMap; use tokio::time::sleep; use tokio_stream::StreamExt; use tokio_util::sync::CancellationToken; @@ -1446,7 +1459,6 @@ mod simplebuffer_tests { headers: Arc::new(HashMap::new()), metadata: None, is_late: false, - ack_handle: None, }; writer.write(msg).await.expect("write should succeed"); } @@ -1471,8 +1483,10 @@ mod simplebuffer_tests { } assert_eq!(received.len(), 5, "Should receive all 5 messages"); - // Ack all messages by dropping them (default behavior is ack on drop) - drop(received); + // Ack all messages by marking them as success + for msg in received { + mark_success!(msg); + } // Wait for tracker to become empty (all acks processed) tokio::time::timeout(Duration::from_secs(2), async { @@ -1518,8 +1532,10 @@ mod simplebuffer_tests { ); assert_eq!(received.len(), 3); - // Ack all messages by dropping - drop(received); + // Ack all messages by marking them as success + for msg in received { + mark_success!(msg); + } cancel.cancel(); handle.await.unwrap().unwrap(); @@ -1667,8 +1683,8 @@ mod simplebuffer_tests { // Read message let msg = rx.next().await.expect("Should receive message"); - // Ack by dropping msg (is_failed defaults to false) - drop(msg); + // Ack by marking as success + mark_success!(msg); // Tracker should become empty tokio::time::timeout(Duration::from_secs(2), async { @@ -1697,12 +1713,9 @@ mod simplebuffer_tests { // Read message first time let msg1 = rx.next().await.expect("Should receive message first time"); - let payload1 = msg1.value.clone(); + let payload1 = msg1.message.value.clone(); - // Nack it by setting is_failed and dropping - if let Some(h) = &msg1.ack_handle { - h.is_failed.store(true, Ordering::Relaxed); - } + // Nack it by dropping without calling mark_success (ref_count != 0 causes NAK on drop) drop(msg1); // After nacking, the message goes back to Pending state and will be refetched. @@ -1713,12 +1726,12 @@ mod simplebuffer_tests { .expect("Stream should not end"); assert_eq!( - msg2.value, payload1, + msg2.message.value, payload1, "Redelivered message should have same payload" ); - // Ack it this time by dropping (is_failed defaults to false) - drop(msg2); + // Ack it this time by marking as success + mark_success!(msg2); // Wait for final ack let result = tokio::time::timeout(Duration::from_secs(1), async { @@ -1755,8 +1768,10 @@ mod simplebuffer_tests { // Cancel while messages are inflight cancel.cancel(); - // Ack messages after cancellation by dropping them - drop(messages); + // Ack messages after cancellation by marking them as success + for msg in messages { + mark_success!(msg); + } // Handle should complete (cleanup waits for inflight) let result = tokio::time::timeout(Duration::from_secs(2), handle) @@ -1799,9 +1814,9 @@ mod simplebuffer_tests { "4th message should block due to backpressure" ); - // Ack first message to free a permit by taking it out and dropping + // Ack first message to free a permit let first_msg = inflight.remove(0); - drop(first_msg); + mark_success!(first_msg); // Now 4th message should come through let fourth = tokio::time::timeout(Duration::from_millis(500), rx.next()) @@ -1810,8 +1825,10 @@ mod simplebuffer_tests { .expect("Stream should not end"); inflight.push(fourth); - // Cleanup - ack remaining by dropping - drop(inflight); + // Cleanup - ack remaining + for msg in inflight { + mark_success!(msg); + } cancel.cancel(); let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; @@ -1838,8 +1855,10 @@ mod simplebuffer_tests { } assert_eq!(messages.len(), 10); - // Drop all messages to trigger ack (is_failed defaults to false) - drop(messages); + // Ack all messages by marking as success + for msg in messages { + mark_success!(msg); + } // Wait for all ack operations to complete let result = tokio::time::timeout(Duration::from_secs(2), async { diff --git a/rust/numaflow-core/src/pipeline/isb/simplebuffer.rs b/rust/numaflow-core/src/pipeline/isb/simplebuffer.rs index 835fc36479..3909c27e3f 100644 --- a/rust/numaflow-core/src/pipeline/isb/simplebuffer.rs +++ b/rust/numaflow-core/src/pipeline/isb/simplebuffer.rs @@ -138,7 +138,6 @@ fn convert_message(read_msg: ReadMessage) -> Message { headers: Arc::new(read_msg.headers), metadata: None, is_late: false, - ack_handle: None, } } @@ -311,7 +310,6 @@ mod tests { headers: Arc::new(HashMap::new()), metadata: None, is_late: false, - ack_handle: None, } } diff --git a/rust/numaflow-core/src/pipeline/isb/writer.rs b/rust/numaflow-core/src/pipeline/isb/writer.rs index 9bbf3345fa..c9c38f7c06 100644 --- a/rust/numaflow-core/src/pipeline/isb/writer.rs +++ b/rust/numaflow-core/src/pipeline/isb/writer.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; -use std::sync::atomic::Ordering; +use tokio::sync::OwnedSemaphorePermit; use tokio::sync::Semaphore; use tokio::task::JoinHandle; use tokio::time::{Duration, Instant, sleep}; @@ -9,11 +9,11 @@ use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; use tracing::{debug, error, warn}; -use crate::Result; use crate::config::pipeline::isb::{BufferFullStrategy, Stream}; use crate::config::pipeline::{ToVertexConfig, VertexType}; use crate::error::Error; -use crate::message::{Message, Offset}; +use crate::mark_success; +use crate::message::{Message, MessageHandle, Offset}; use crate::metrics::{ PIPELINE_PARTITION_NAME_LABEL, pipeline_drop_metric_labels, pipeline_metric_labels, pipeline_metrics, @@ -23,6 +23,7 @@ use crate::pipeline::isb::{ISBWriter, PendingWrite, WriteError, WriteResult}; use crate::shared::forward; use crate::typ::NumaflowTypeConfig; use crate::watermark::WatermarkHandle; +use crate::{Result, mark_failed}; const DEFAULT_RETRY_INTERVAL_MILLIS: u64 = 10; @@ -39,6 +40,92 @@ struct PendingWriteResult { write_start: Instant, } +/// ISBWriteTask encapsulates all the context needed to execute a write operation per message. +/// Similar to MapUnaryTask, it spawns as a tokio task and calls mark_success() at the end. +struct ISBWriteTask { + orchestrator: ISBWriterOrchestrator, + /// Permit to achieve structured concurrency by ensuring we do not exceed the concurrency limit + /// and all the tasks are cleaned up when the component is shutting down. + permit: OwnedSemaphorePermit, + msg_handle: MessageHandle, + cln_token: CancellationToken, +} + +impl ISBWriteTask { + /// Spawns the ISB write task as a tokio task. + /// The task will write the message to ISB, resolve PAFs, publish watermarks, and mark success. + fn spawn(self) { + tokio::spawn(async move { + self.execute().await; + }); + } + + /// Executes the ISB write operation. + /// Flow: check dropped -> route and write -> resolve PAFs -> publish watermarks -> mark_success + async fn execute(self) { + // Hold the permit until the task completes + let _permit = self.permit; + + let message = self.msg_handle.message().clone(); + + // Handle dropped messages + if message.dropped() { + // Increment metric for user-initiated drops via DROP tag + pipeline_metrics() + .forwarder + .udf_drop_total + .get_or_create(pipeline_metric_labels( + self.orchestrator.vertex_type.as_str(), + )) + .inc(); + // Mark as success since dropped messages are intentionally dropped + mark_success!(self.msg_handle); + return; + } + + // Route and write to appropriate streams + let write_results = self + .orchestrator + .route_and_write_message(&message, self.cln_token.clone()) + .await; + + let n = write_results.len(); + + // Resolve all PAFs + let resolved_offsets = self + .orchestrator + .resolve_all_pafs(write_results, &message, self.cln_token.clone()) + .await; + + // If any of the writes failed, NAK the message + if resolved_offsets.len() != n { + warn!( + expected = n, + actual = resolved_offsets.len(), + "Some writes failed during PAF resolution, message will be NAK'd" + ); + mark_failed!( + self.msg_handle, + format!( + "PAF resolution failed: {}/{} writes succeeded", + resolved_offsets.len(), + n + ) + ); + return; + } + + // Publish watermarks for successful writes + self.orchestrator + .clone() + .publish_watermarks_for_offsets(resolved_offsets, &message) + .await; + + // All PAFs resolved successfully, mark as success to ACK + mark_success!(self.msg_handle); + } +} + /// Components needed to create an ISBWriterOrchestrator. pub(crate) struct ISBWriterOrchestratorComponents { pub config: Vec, @@ -108,24 +195,40 @@ impl ISBWriterOrchestrator { } /// Starts reading messages from the stream and writes them to Jetstream ISB. + /// Each message is processed by spawning an ISBWriteTask that handles the full write flow: + /// check dropped -> route and write -> resolve PAFs -> publish watermarks -> mark_success pub(crate) async fn streaming_write( self, - messages_stream: ReceiverStream, + messages_stream: ReceiverStream, cln_token: CancellationToken, ) -> Result>> { let handle: JoinHandle> = tokio::spawn(async move { let mut messages_stream = messages_stream; - while let Some(message) = messages_stream.next().await { - self.write_to_isb(message, cln_token.clone()) - .await - .inspect_err(|e| { - error!(?e, "Failed to process message"); + while let Some(msg_handle) = messages_stream.next().await { + // Acquire permit for structured concurrency + let permit = match Arc::clone(&self.sem).acquire_owned().await { + Ok(permit) => permit, + Err(_) => { + error!("Failed to acquire semaphore permit"); cln_token.cancel(); - })?; + return Err(Error::ISB(ISBError::Other( + "Failed to acquire semaphore permit".to_string(), + ))); + } + }; + + // Spawn the write task + ISBWriteTask { + orchestrator: self.clone(), + permit, + msg_handle, + cln_token: cln_token.clone(), + } + .spawn(); } - // Wait for all the PAF resolvers to complete before returning + // Wait for all the write tasks to complete before returning self.wait_for_paf_resolvers().await?; Ok(()) @@ -133,32 +236,6 @@ impl ISBWriterOrchestrator { Ok(handle) } - /// Writes a single message to the ISB. It will keep retrying until it succeeds or is cancelled. - /// Writes are ordered only if PAF concurrency is 1 because during retries we cannot guarantee order. - /// This calls `write_to_stream` internally once it has figured out the target streams. This Write - /// is like a `flap-map` operation. It could end in 0, 1, or more writes based on conditions. - async fn write_to_isb(&self, message: Message, cln_token: CancellationToken) -> Result<()> { - // Handle dropped messages - if message.dropped() { - // Increment metric for user-initiated drops via DROP tag - pipeline_metrics() - .forwarder - .udf_drop_total - .get_or_create(pipeline_metric_labels(self.vertex_type.as_str())) - .inc(); - return Ok(()); - } - - // Route and write to appropriate streams - let write_results = self - .route_and_write_message(&message, cln_token.clone()) - .await; - - // Resolve PAFs and finalize - self.resolve_and_finalize(write_results, message, cln_token) - .await - } - /// Routes a message to appropriate streams and writes to each. /// Returns a list of PendingWriteResults (one per successful write) with PAFs to be resolved. async fn route_and_write_message( @@ -368,54 +445,6 @@ impl ISBWriterOrchestrator { .observe(write_processing_start.elapsed().as_micros() as f64); } - /// Resolves PAFs and finalizes the message processing. - /// Spawns a background task to resolve all PAFs, publish watermarks, and update tracker. - async fn resolve_and_finalize( - &self, - write_results: Vec, - message: Message, - cln_token: CancellationToken, - ) -> Result<()> { - let permit = Arc::clone(&self.sem).acquire_owned().await.map_err(|_e| { - Error::ISB(ISBError::Other( - "Failed to acquire semaphore permit".to_string(), - )) - }); - - let mut this = self.clone(); - tokio::spawn(async move { - let _permit = permit; - - let n = write_results.len(); - // Resolve all PAFs - let resolved_offsets = this - .resolve_all_pafs(write_results, &message, cln_token) - .await; - - // If any of the writes failed, NAK the message so it can be retried - if resolved_offsets.len() != n { - message - .ack_handle - .as_ref() - .expect("ack handle should be present") - .is_failed - .store(true, Ordering::Relaxed); - warn!( - expected = n, - actual = resolved_offsets.len(), - "Some writes failed during PAF resolution, message will be NAK'd" - ); - return; - } - - // Publish watermarks for successful writes - this.publish_watermarks_for_offsets(resolved_offsets, &message) - .await; - }); - - Ok(()) - } - /// Resolves all PAFs and returns the offsets for successful writes. async fn resolve_all_pafs( &self, @@ -573,7 +602,7 @@ impl ISBWriterOrchestrator { mod tests { use super::*; use crate::config::pipeline::isb::BufferWriterConfig; - use crate::message::{AckHandle, IntOffset, Message, MessageID, Offset, ReadAck}; + use crate::message::{IntOffset, Message, MessageHandle, MessageID, Offset, ReadAck}; use crate::pipeline::isb::jetstream::js_writer::JetStreamWriter; use crate::typ::WithoutRateLimiter; use async_nats::jetstream; @@ -678,11 +707,11 @@ mod tests { offset: format!("offset_{}", i).into(), index: i as i32, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), ..Default::default() }; + let read_message = MessageHandle::new(message, ack_tx); ack_rxs.push(ack_rx); - tx.send(message).await.unwrap(); + tx.send(read_message).await.unwrap(); } drop(tx); @@ -789,11 +818,11 @@ mod tests { offset: format!("offset_{}", i).into(), index: i as i32, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), ..Default::default() }; + let read_message = MessageHandle::new(message, ack_tx); ack_rxs.push(ack_rx); - tx.send(message).await.unwrap(); + tx.send(read_message).await.unwrap(); } // Cancel after sending some messages @@ -964,11 +993,11 @@ mod tests { offset: format!("offset_{}", i).into(), index: i as i32, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), ..Default::default() }; + let read_message = MessageHandle::new(message, ack_tx); ack_rxs.push(ack_rx); - tx.send(message).await.unwrap(); + tx.send(read_message).await.unwrap(); } drop(tx); @@ -1030,7 +1059,7 @@ mod tests { mod simple_buffer_tests { use super::*; use crate::config::pipeline::isb::BufferWriterConfig; - use crate::message::{AckHandle, IntOffset, MessageID, ReadAck}; + use crate::message::{IntOffset, MessageHandle, MessageID, ReadAck}; use crate::pipeline::isb::simplebuffer::{SimpleBufferAdapter, WithSimpleBuffer}; use bytes::Bytes; use chrono::Utc; @@ -1043,7 +1072,7 @@ mod simple_buffer_tests { id: i64, value: &str, tags: Option>, - ) -> (Message, tokio::sync::oneshot::Receiver) { + ) -> (MessageHandle, tokio::sync::oneshot::Receiver) { let (ack_tx, ack_rx) = tokio::sync::oneshot::channel(); let message = Message { typ: Default::default(), @@ -1061,9 +1090,8 @@ mod simple_buffer_tests { headers: Arc::new(HashMap::new()), metadata: None, is_late: false, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), }; - (message, ack_rx) + (MessageHandle::new(message, ack_tx), ack_rx) } /// Helper to create ISBWriterOrchestrator with a single SimpleBuffer @@ -1886,7 +1914,6 @@ mod simple_buffer_tests { headers: Arc::new(HashMap::new()), metadata: None, is_late: false, - ack_handle: None, } } diff --git a/rust/numaflow-core/src/reduce/pbq.rs b/rust/numaflow-core/src/reduce/pbq.rs index 52c57058a2..8d6787a4f4 100644 --- a/rust/numaflow-core/src/reduce/pbq.rs +++ b/rust/numaflow-core/src/reduce/pbq.rs @@ -1,4 +1,5 @@ use crate::error::Result; +use crate::mark_success; use crate::message::Message; use crate::pipeline::isb::reader::ISBReaderOrchestrator; use crate::reduce::wal::WalMessage; @@ -137,20 +138,22 @@ impl PBQ { .streaming_write(ReceiverStream::new(wal_rx)) .await?; - while let Some(msg) = isb_stream.next().await { - // Send the message to WAL - it will be converted to bytes internally. - // The message will be kept alive until the write completes, then dropped - // (triggering ack via Arc). + while let Some(read_msg) = isb_stream.next().await { + // Clone the message for downstream before passing MessageHandle to WAL + let message = read_msg.message().clone(); + + // Send the MessageHandle to WAL - WAL will ACK after successful write wal_tx .send(SegmentWriteMessage::WriteMessage { - message: msg.clone(), + read_message: read_msg, }) .await .map_err(|_| { crate::error::Error::Reduce("PBQ WAL writer: receiver dropped".to_string()) })?; - tx.send(msg).await.map_err(|_| { + // Send cloned message downstream + tx.send(message).await.map_err(|_| { crate::error::Error::Reduce( "PBQ ISB reader: downstream receiver dropped".to_string(), ) @@ -188,13 +191,16 @@ impl PBQ { let (mut isb_stream, isb_handle) = isb_reader.streaming_read(cancellation_token).await?; // Process messages from ISB stream - while let Some(msg) = isb_stream.next().await { - // Forward the message to the output channel - tx.send(msg).await.map_err(|_| { + while let Some(read_msg) = isb_stream.next().await { + // Extract the message and forward to output channel + let message = read_msg.message().clone(); + tx.send(message).await.map_err(|_| { crate::error::Error::Reduce( "PBQ ISB reader: downstream receiver dropped".to_string(), ) })?; + // Mark the read message as success after processing + mark_success!(read_msg); } // Wait for the ISB reader task to complete @@ -615,9 +621,11 @@ mod tests { ..Default::default() }; - tx.send(SegmentWriteMessage::WriteMessage { message }) - .await - .unwrap(); + tx.send(SegmentWriteMessage::WriteMessage { + read_message: message.into(), + }) + .await + .unwrap(); } drop(tx); diff --git a/rust/numaflow-core/src/reduce/reducer/aligned/user_defined.rs b/rust/numaflow-core/src/reduce/reducer/aligned/user_defined.rs index 034606c4ee..83a8dd4e06 100644 --- a/rust/numaflow-core/src/reduce/reducer/aligned/user_defined.rs +++ b/rust/numaflow-core/src/reduce/reducer/aligned/user_defined.rs @@ -1,5 +1,5 @@ use crate::config::{get_vertex_name, get_vertex_replica}; -use crate::message::{IntOffset, Message, MessageID, Offset}; +use crate::message::{IntOffset, Message, MessageHandle, MessageID, Offset}; use crate::reduce::reducer::aligned::windower::{AlignedWindowMessage, AlignedWindowOperation}; use crate::shared::grpc::{prost_timestamp_from_utc, utc_from_timestamp}; use crate::{Result, jh_abort_guard}; @@ -171,7 +171,7 @@ impl UserDefinedAlignedReduce { pub(crate) async fn reduce_fn( &mut self, stream: ReceiverStream, - result_tx: tokio::sync::mpsc::Sender, + result_tx: tokio::sync::mpsc::Sender, cln_token: CancellationToken, ) -> Result<()> { // Convert AlignedWindowMessage stream to ReduceRequest stream @@ -245,8 +245,9 @@ impl UserDefinedAlignedReduce { } .into(); + // Send to ISB writer (message is converted to MessageHandle via From trait) result_tx - .send(message) + .send(message.into()) .await .expect("failed to send response"); @@ -456,8 +457,11 @@ mod tests { .expect("no result received"); // Verify the result - assert_eq!(result.keys.to_vec(), vec!["key1"]); - assert_eq!(String::from_utf8(result.value.to_vec()).unwrap(), "3"); // Counter should be 3 + assert_eq!(result.message().keys.to_vec(), vec!["key1"]); + assert_eq!( + String::from_utf8(result.message().value.to_vec()).unwrap(), + "3" + ); // Counter should be 3 // Shutdown the server shutdown_tx @@ -682,17 +686,23 @@ mod tests { assert_eq!(results.len(), 2); // Sort results by key for deterministic testing - results.sort_by_key(|a| a.keys.to_vec()); + results.sort_by_key(|a| a.message().keys.to_vec()); // Check key1 result let result0 = results.first().expect("Expected result for key1"); - assert_eq!(result0.keys.to_vec(), vec!["key1"]); - assert_eq!(String::from_utf8(result0.value.to_vec()).unwrap(), "2"); // Counter should be 2 + assert_eq!(result0.message().keys.to_vec(), vec!["key1"]); + assert_eq!( + String::from_utf8(result0.message().value.to_vec()).unwrap(), + "2" + ); // Counter should be 2 // Check key2 result let result1 = results.get(1).expect("Expected result for key2"); - assert_eq!(result1.keys.to_vec(), vec!["key2"]); - assert_eq!(String::from_utf8(result1.value.to_vec()).unwrap(), "3"); // Counter should be 3 + assert_eq!(result1.message().keys.to_vec(), vec!["key2"]); + assert_eq!( + String::from_utf8(result1.message().value.to_vec()).unwrap(), + "3" + ); // Counter should be 3 // Shutdown the server shutdown_tx @@ -846,8 +856,11 @@ mod tests { .expect("no result received"); // Verify the result - assert_eq!(result.keys.to_vec(), vec!["key1"]); - assert_eq!(String::from_utf8(result.value.to_vec()).unwrap(), "3"); // Counter should be 3 + assert_eq!(result.message().keys.to_vec(), vec!["key1"]); + assert_eq!( + String::from_utf8(result.message().value.to_vec()).unwrap(), + "3" + ); // Counter should be 3 // Shutdown the server shutdown_tx diff --git a/rust/numaflow-core/src/reduce/reducer/unaligned/reducer.rs b/rust/numaflow-core/src/reduce/reducer/unaligned/reducer.rs index cebf11ee2c..4396e9a76d 100644 --- a/rust/numaflow-core/src/reduce/reducer/unaligned/reducer.rs +++ b/rust/numaflow-core/src/reduce/reducer/unaligned/reducer.rs @@ -169,8 +169,10 @@ impl ReduceTask { // accumulator acts like a Global Window. let window = response.window.clone().expect("Window not set in response"); let window : Window = window.into(); + // Send to ISB writer (message is converted to MessageHandle via From trait) + let message: Message = response.into(); writer_tx - .send(response.into()) + .send(message.into()) .await .expect("Failed to send response to writer"); @@ -284,8 +286,10 @@ impl ReduceTask { continue; } + // Send to ISB writer (message is converted to MessageHandle via From trait) + let message: Message = response.into(); writer_tx - .send(response.into()) + .send(message.into()) .await .expect("Failed to send response to writer"); } diff --git a/rust/numaflow-core/src/reduce/reducer/unaligned/user_defined/accumulator.rs b/rust/numaflow-core/src/reduce/reducer/unaligned/user_defined/accumulator.rs index b95b3fca24..7ef9a6bba7 100644 --- a/rust/numaflow-core/src/reduce/reducer/unaligned/user_defined/accumulator.rs +++ b/rust/numaflow-core/src/reduce/reducer/unaligned/user_defined/accumulator.rs @@ -106,7 +106,6 @@ impl From for Message { headers: Arc::new(result.headers), metadata: None, is_late: false, - ack_handle: None, } } } diff --git a/rust/numaflow-core/src/reduce/reducer/unaligned/user_defined/session.rs b/rust/numaflow-core/src/reduce/reducer/unaligned/user_defined/session.rs index dd6f740043..f5e00f53cd 100644 --- a/rust/numaflow-core/src/reduce/reducer/unaligned/user_defined/session.rs +++ b/rust/numaflow-core/src/reduce/reducer/unaligned/user_defined/session.rs @@ -155,7 +155,6 @@ impl From for Message { headers: Arc::new(HashMap::new()), // reset headers since it is a new message metadata: None, is_late: false, - ack_handle: None, } } } diff --git a/rust/numaflow-core/src/reduce/wal/segment.rs b/rust/numaflow-core/src/reduce/wal/segment.rs index 5492026f68..84156465d0 100644 --- a/rust/numaflow-core/src/reduce/wal/segment.rs +++ b/rust/numaflow-core/src/reduce/wal/segment.rs @@ -202,9 +202,11 @@ mod tests { }; // Send message - conversion to bytes happens internally - tx.send(SegmentWriteMessage::WriteMessage { message }) - .await - .unwrap(); + tx.send(SegmentWriteMessage::WriteMessage { + read_message: message.into(), + }) + .await + .unwrap(); } drop(tx); @@ -351,9 +353,11 @@ mod tests { }; // Send message - conversion to bytes happens internally - tx.send(SegmentWriteMessage::WriteMessage { message }) - .await - .unwrap(); + tx.send(SegmentWriteMessage::WriteMessage { + read_message: message.into(), + }) + .await + .unwrap(); } drop(tx); diff --git a/rust/numaflow-core/src/reduce/wal/segment/append.rs b/rust/numaflow-core/src/reduce/wal/segment/append.rs index 1f5691811c..5ab4d8d6b2 100644 --- a/rust/numaflow-core/src/reduce/wal/segment/append.rs +++ b/rust/numaflow-core/src/reduce/wal/segment/append.rs @@ -1,11 +1,11 @@ -use crate::message::Message; +use crate::mark_success; +use crate::message::MessageHandle; use crate::reduce::wal::error::WalResult; use crate::reduce::wal::segment::WalType; use bytes::Bytes; use chrono::{DateTime, Utc}; use std::path::Path; use std::path::PathBuf; -use std::sync::atomic::Ordering; use tokio::io::BufWriter; use tokio::task::JoinHandle; use tokio::{ @@ -23,12 +23,8 @@ const ROTATE_IF_STALE_DURATION: chrono::Duration = chrono::Duration::seconds(30) /// The Command that has to be operated on the Segment. pub(crate) enum SegmentWriteMessage { /// Writes a message to the WAL. The message will be converted to bytes internally. - /// The message is kept alive until the write completes, ensuring Arc is not - /// dropped prematurely. - WriteMessage { - /// Message to be written. Will be dropped after successful write. - message: Message, - }, + /// After successful write, mark_success() is called on the MessageHandle to ACK. + WriteMessage { read_message: MessageHandle }, /// Writes GC Events to the WAL WriteGcEvent { /// Raw data to be written to the WAL. @@ -138,33 +134,32 @@ impl SegmentWriteActor { /// we should exit upon errors. async fn handle_message(&mut self, msg: SegmentWriteMessage) -> WalResult<()> { match msg { - SegmentWriteMessage::WriteMessage { message } => { + SegmentWriteMessage::WriteMessage { read_message } => { // Convert message to bytes let data: Bytes = crate::reduce::wal::WalMessage { - message: message.clone(), + message: read_message.message().clone(), } .try_into() .expect("Failed to convert message to bytes"); - // Message is dropped here after successful write, triggering ack/nack. - return match self.write_data(data).await { - Ok(_) => Ok(()), + // Write to WAL and ACK on success + match self.write_data(data).await { + Ok(_) => { + // Successfully written to WAL, ACK the message + mark_success!(read_message); + Ok(()) + } Err(e) => { - // message failed to write to WAL, mark it as failed so that it gets nacked. error!(?e, "Failed to write message to WAL"); - message - .ack_handle - .as_ref() - .expect("ack handle should be present") - .is_failed - .store(true, Ordering::Relaxed); + read_message.mark_failed(&e); Err(e) } - }; + } } SegmentWriteMessage::WriteGcEvent { data } => { // Just write the raw data self.write_data(data).await?; + Ok(()) } SegmentWriteMessage::Rotate { on_size } => { // Rotate if forced (`on_size` is false) OR if size threshold is met @@ -177,9 +172,9 @@ impl SegmentWriteActor { "Skipping rotation: size threshold not met and not forced." ); } + Ok(()) } } - Ok(()) } /// Writes the data to the Segment. @@ -419,7 +414,9 @@ mod tests { ..Default::default() }; wal_tx - .send(SegmentWriteMessage::WriteMessage { message: msg1 }) + .send(SegmentWriteMessage::WriteMessage { + read_message: msg1.into(), + }) .await .unwrap(); @@ -429,7 +426,9 @@ mod tests { ..Default::default() }; wal_tx - .send(SegmentWriteMessage::WriteMessage { message: msg2 }) + .send(SegmentWriteMessage::WriteMessage { + read_message: msg2.into(), + }) .await .unwrap(); @@ -441,7 +440,9 @@ mod tests { ..Default::default() }; wal_tx - .send(SegmentWriteMessage::WriteMessage { message: msg3 }) + .send(SegmentWriteMessage::WriteMessage { + read_message: msg3.into(), + }) .await .unwrap(); @@ -452,7 +453,9 @@ mod tests { ..Default::default() }; wal_tx - .send(SegmentWriteMessage::WriteMessage { message: msg4 }) + .send(SegmentWriteMessage::WriteMessage { + read_message: msg4.into(), + }) .await .unwrap(); @@ -462,7 +465,9 @@ mod tests { ..Default::default() }; wal_tx - .send(SegmentWriteMessage::WriteMessage { message: msg5 }) + .send(SegmentWriteMessage::WriteMessage { + read_message: msg5.into(), + }) .await .unwrap(); @@ -472,7 +477,9 @@ mod tests { ..Default::default() }; wal_tx - .send(SegmentWriteMessage::WriteMessage { message: msg6 }) + .send(SegmentWriteMessage::WriteMessage { + read_message: msg6.into(), + }) .await .unwrap(); @@ -538,7 +545,9 @@ mod tests { ..Default::default() }; wal_tx - .send(SegmentWriteMessage::WriteMessage { message: msg1 }) + .send(SegmentWriteMessage::WriteMessage { + read_message: msg1.into(), + }) .await .unwrap(); @@ -594,7 +603,9 @@ mod tests { ..Default::default() }; wal_tx - .send(SegmentWriteMessage::WriteMessage { message: msg1 }) + .send(SegmentWriteMessage::WriteMessage { + read_message: msg1.into(), + }) .await .unwrap(); @@ -611,7 +622,9 @@ mod tests { ..Default::default() }; wal_tx - .send(SegmentWriteMessage::WriteMessage { message: msg2 }) + .send(SegmentWriteMessage::WriteMessage { + read_message: msg2.into(), + }) .await .unwrap(); @@ -683,7 +696,9 @@ mod tests { ..Default::default() }; wal_tx - .send(SegmentWriteMessage::WriteMessage { message: msg1 }) + .send(SegmentWriteMessage::WriteMessage { + read_message: msg1.into(), + }) .await .unwrap(); @@ -697,7 +712,9 @@ mod tests { ..Default::default() }; wal_tx - .send(SegmentWriteMessage::WriteMessage { message: msg2 }) + .send(SegmentWriteMessage::WriteMessage { + read_message: msg2.into(), + }) .await .unwrap(); diff --git a/rust/numaflow-core/src/reduce/wal/segment/compactor.rs b/rust/numaflow-core/src/reduce/wal/segment/compactor.rs index 6c987ecfab..5c6365f08e 100644 --- a/rust/numaflow-core/src/reduce/wal/segment/compactor.rs +++ b/rust/numaflow-core/src/reduce/wal/segment/compactor.rs @@ -779,9 +779,11 @@ mod tests { }; // Send message - conversion to bytes happens internally - tx.send(SegmentWriteMessage::WriteMessage { message }) - .await - .unwrap(); + tx.send(SegmentWriteMessage::WriteMessage { + read_message: message.into(), + }) + .await + .unwrap(); // Rotate every 100 messages to create 10 files if i % 100 == 0 { @@ -954,13 +956,13 @@ mod tests { // Write the messages to the WAL - conversion to bytes happens internally tx.send(SegmentWriteMessage::WriteMessage { - message: before_message, + read_message: before_message.into(), }) .await .map_err(|e| format!("Failed to send data: {e}"))?; tx.send(SegmentWriteMessage::WriteMessage { - message: after_message, + read_message: after_message.into(), }) .await .map_err(|e| format!("Failed to send data: {e}"))?; @@ -1210,7 +1212,7 @@ mod tests { { // Send message - conversion to bytes happens internally tx.send(SegmentWriteMessage::WriteMessage { - message: message.clone(), + read_message: message.clone().into(), }) .await .map_err(|e| format!("Failed to send data: {e}"))?; diff --git a/rust/numaflow-core/src/sinker/sink.rs b/rust/numaflow-core/src/sinker/sink.rs index eb6c70b945..fb88897b6b 100644 --- a/rust/numaflow-core/src/sinker/sink.rs +++ b/rust/numaflow-core/src/sinker/sink.rs @@ -1,13 +1,14 @@ -use crate::Result; use crate::config::pipeline::VERTEX_TYPE_SINK; use crate::config::{get_vertex_name, is_mono_vertex}; use crate::error::Error; -use crate::message::Message; +use crate::mark_success_batch; +use crate::message::{Message, MessageHandle}; use crate::metrics::{ PIPELINE_PARTITION_NAME_LABEL, monovertex_metrics, mvtx_forward_metric_labels, pipeline_drop_metric_labels, pipeline_metric_labels, pipeline_metrics, }; use crate::sinker::actor::{SinkActorMessage, SinkActorResponse}; +use crate::{Result, mark_failed_batch}; use numaflow_kafka::sink::KafkaSink; use numaflow_pb::clients::sink::Status::{Failure, Fallback, OnSuccess, Serve, Success}; use numaflow_pb::clients::sink::sink_client::SinkClient; @@ -15,7 +16,6 @@ use numaflow_pb::clients::sink::sink_response; use numaflow_pulsar::sink::Sink as PulsarSink; use numaflow_sqs::sink::SqsSink; use serving::{DEFAULT_ID_HEADER, DEFAULT_POD_HASH_KEY}; -use std::sync::atomic::Ordering; use std::time::Duration; use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinHandle; @@ -197,7 +197,7 @@ impl SinkWriter { /// closed or the cancellation token is triggered. pub(crate) async fn streaming_write( mut self, - messages_stream: ReceiverStream, + messages_stream: ReceiverStream, cln_token: CancellationToken, ) -> Result>> { Ok(tokio::spawn({ @@ -210,63 +210,30 @@ impl SinkWriter { pin!(chunk_stream); // Main processing loop - while let Some(batch) = chunk_stream.next().await { - // If bypass conditions exist for primary sink, drop the batch - let batch = if let Some(conditions) = &self.bypass_conditions - && let Some(ref _sink) = conditions.sink - { - vec![] - } else { - batch - }; - - // we are in shutting down mode, we will not be writing to the sink, - // mark the messages as failed, and on Drop they will be nack'ed. + while let Some(read_batch) = chunk_stream.next().await { + // We are in shutting down mode, NAK all messages if self.shutting_down_on_err { - for msg in &batch { - msg.ack_handle - .as_ref() - .expect("ack handle should be present") - .is_failed - .store(true, Ordering::Relaxed); + for msg in read_batch { + msg.mark_failed(self.final_result.as_ref().unwrap_err()); } continue; } - // collect ack handles for later failure tracking - let ack_handles = batch - .iter() - .map(|msg| msg.ack_handle.clone()) - .collect::>(); - - let mut dropped_message_count = batch.len(); - // filter out messages that are marked for drop - let batch: Vec<_> = batch.into_iter().filter(|msg| !msg.dropped()).collect(); - dropped_message_count -= batch.len(); - - // skip if all were dropped - if batch.is_empty() { - continue; - } - - // perform the write operation - if let Err(e) = self.write_to_sink(batch, cln_token.clone()).await { - // critical error, cancel upstream and mark all acks as failed - error!(?e, "Error writing to sink, initiating shutdown."); - cln_token.cancel(); - - for ack_handle in ack_handles { - ack_handle - .as_ref() - .expect("ack handle should be present") - .is_failed - .store(true, Ordering::Relaxed); + let messages = read_batch.iter().map(|msg| msg.message().clone()).collect(); + match self.process_batch(messages, cln_token.clone()).await { + Ok(()) => { + // Batch processed successfully + mark_success_batch!(read_batch); + } + Err(e) => { + mark_failed_batch!(read_batch, &e); + // Critical error, cancel upstream and initiate shutdown + error!(?e, "Error writing to sink, initiating shutdown."); + cln_token.cancel(); + self.final_result = Err(e); + self.shutting_down_on_err = true; } - - self.final_result = Err(e); - self.shutting_down_on_err = true; } - send_drop_metrics(is_mono_vertex(), dropped_message_count); } // finalize @@ -275,6 +242,40 @@ impl SinkWriter { })) } + /// Processes a batch of messages: handles bypass, dropped messages, and writes to sink. + /// On success, all messages are ACK'd. On error, messages are NAK'd (dropped without ack). + async fn process_batch( + &mut self, + messages: Vec, + cln_token: CancellationToken, + ) -> Result<()> { + // If bypass conditions exist for primary sink, ack and skip + if let Some(conditions) = &self.bypass_conditions + && let Some(ref _sink) = conditions.sink + { + return Ok(()); + } + + // Separate dropped messages from messages to process + let (dropped, to_process): (Vec<_>, Vec<_>) = messages + .into_iter() + .partition(|read_msg| read_msg.dropped()); + + let dropped_count = dropped.len(); + + // If all messages were dropped, we're done + if to_process.is_empty() { + send_drop_metrics(is_mono_vertex(), dropped_count); + return Ok(()); + } + + // Perform the write operation + self.write_to_sink(to_process, cln_token.clone()).await?; + + send_drop_metrics(is_mono_vertex(), dropped_count); + Ok(()) + } + /// Write the messages to the Sink. /// Invokes the primary sink actor, handles fallback messages, serving messages, and errors. pub(crate) async fn write_to_sink( @@ -719,7 +720,7 @@ impl From for ResponseFromSink { mod tests { use super::*; use crate::config::pipeline::NatsStoreConfig; - use crate::message::{AckHandle, IntOffset, Message, MessageID, Offset, ReadAck}; + use crate::message::{IntOffset, Message, MessageHandle, MessageID, Offset, ReadAck}; use crate::shared::grpc::create_rpc_channel; use crate::sinker::sink::serve::nats::NatsServingStore; use crate::tracker::Tracker; @@ -807,11 +808,11 @@ mod tests { .unwrap(); let mut ack_rxs = vec![]; - let messages: Vec = (0..10) + let messages: Vec = (0..10) .map(|i| { let (ack_tx, ack_rx) = oneshot::channel(); ack_rxs.push(ack_rx); - Message { + let message = Message { typ: Default::default(), keys: Arc::from(vec![format!("key_{}", i)]), tags: None, @@ -824,9 +825,9 @@ mod tests { offset: format!("offset_{}", i).into(), index: i as i32, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), ..Default::default() - } + }; + MessageHandle::new(message, ack_tx) }) .collect(); @@ -886,11 +887,11 @@ mod tests { .unwrap(); let mut ack_rxs = vec![]; - let messages: Vec = (0..10) + let messages: Vec = (0..10) .map(|i| { let (ack_tx, ack_rx) = oneshot::channel(); ack_rxs.push(ack_rx); - Message { + let message = Message { typ: Default::default(), keys: Arc::from(vec!["error".to_string()]), tags: None, @@ -903,9 +904,9 @@ mod tests { offset: format!("offset_{}", i).into(), index: i as i32, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), ..Default::default() - } + }; + MessageHandle::new(message, ack_tx) }) .collect(); @@ -975,11 +976,11 @@ mod tests { .unwrap(); let mut ack_rxs = vec![]; - let messages: Vec = (0..20) + let messages: Vec = (0..20) .map(|i| { let (ack_tx, ack_rx) = oneshot::channel(); ack_rxs.push(ack_rx); - Message { + let message = Message { typ: Default::default(), keys: Arc::from(vec!["fallback".to_string()]), tags: None, @@ -992,9 +993,9 @@ mod tests { offset: format!("offset_{}", i).into(), index: i as i32, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), ..Default::default() - } + }; + MessageHandle::new(message, ack_tx) }) .collect(); @@ -1057,11 +1058,11 @@ mod tests { .unwrap(); let mut ack_rxs = vec![]; - let messages: Vec = (0..20) + let messages: Vec = (0..20) .map(|i| { let (ack_tx, ack_rx) = oneshot::channel(); ack_rxs.push(ack_rx); - Message { + let message = Message { typ: Default::default(), keys: Arc::from(vec!["onSuccess".to_string()]), tags: None, @@ -1074,9 +1075,9 @@ mod tests { offset: format!("offset_{}", i).into(), index: i as i32, }, - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), ..Default::default() - } + }; + MessageHandle::new(message, ack_tx) }) .collect(); @@ -1167,14 +1168,14 @@ mod tests { .unwrap(); let mut ack_rxs = vec![]; - let messages: Vec = (0..10) + let read_messages: Vec = (0..10) .map(|i| { let mut headers = HashMap::new(); headers.insert(DEFAULT_ID_HEADER.to_string(), format!("id_{}", i)); headers.insert(DEFAULT_POD_HASH_KEY.to_string(), "abcd".to_string()); let (ack_tx, ack_rx) = oneshot::channel(); ack_rxs.push(ack_rx); - Message { + let message = Message { typ: Default::default(), keys: Arc::from(vec!["serve".to_string()]), tags: None, @@ -1188,14 +1189,14 @@ mod tests { index: i as i32, }, headers: Arc::new(headers), - ack_handle: Some(Arc::new(AckHandle::new(ack_tx))), ..Default::default() - } + }; + MessageHandle::new(message, ack_tx) }) .collect(); let (tx, rx) = mpsc::channel(10); - for msg in messages { + for msg in read_messages { let _ = tx.send(msg).await; } diff --git a/rust/numaflow-core/src/sinker/sink/sqs.rs b/rust/numaflow-core/src/sinker/sink/sqs.rs index 3c919b3c17..36552c83b9 100644 --- a/rust/numaflow-core/src/sinker/sink/sqs.rs +++ b/rust/numaflow-core/src/sinker/sink/sqs.rs @@ -162,7 +162,6 @@ mod unit_tests { headers: Arc::new(headers.clone()), metadata: None, is_late: false, - ack_handle: None, }; let sink_msg: SqsSinkMessage = msg.try_into().unwrap(); @@ -221,7 +220,6 @@ mod unit_tests { headers: Arc::new(headers), metadata: Some(Arc::new(metadata)), is_late: false, - ack_handle: None, }; let sink_msg: SqsSinkMessage = msg.try_into().unwrap(); @@ -285,7 +283,6 @@ mod unit_tests { headers: Arc::new(headers), metadata: Some(Arc::new(metadata)), is_late: false, - ack_handle: None, }; let sink_msg: SqsSinkMessage = msg.try_into().unwrap(); diff --git a/rust/numaflow-core/src/source.rs b/rust/numaflow-core/src/source.rs index bf7d90827d..1b5fee092f 100644 --- a/rust/numaflow-core/src/source.rs +++ b/rust/numaflow-core/src/source.rs @@ -7,7 +7,7 @@ use crate::config::pipeline::VERTEX_TYPE_SOURCE; use crate::config::{get_vertex_name, is_mono_vertex}; use crate::error::{Error, Result}; -use crate::message::{AckHandle, ReadAck}; +use crate::message::{MessageHandle, ReadAck}; use crate::metrics::{ PIPELINE_PARTITION_NAME_LABEL, SOURCE_PARTITION_NAME_LABEL, monovertex_metrics, mvtx_forward_metric_labels, pipeline_metric_labels, pipeline_metrics, @@ -29,7 +29,6 @@ use numaflow_pulsar::source::PulsarSource; use numaflow_sqs::source::SqsSource; use numaflow_throttling::RateLimiter; use std::sync::Arc; -use std::sync::atomic::Ordering; use tokio::sync::OwnedSemaphorePermit; use tokio::sync::Semaphore; use tokio::sync::{mpsc, oneshot}; @@ -67,7 +66,7 @@ pub(crate) mod kafka; pub(crate) mod test_utils; use crate::transformer::Transformer; -use crate::watermark::source::SourceWatermarkHandle; +use crate::watermark::source::{SourceWatermarkEntry, SourceWatermarkHandle}; const MAX_ACK_PENDING: usize = 10000; const ACK_RETRY_INTERVAL: u64 = 100; @@ -448,13 +447,13 @@ impl Source { .map_err(|e| Error::ActorPatternRecv(e.to_string()))? } - /// Starts streaming messages from the source. It returns a stream of messages and + /// Starts streaming messages from the source. It returns a stream of MessageHandles and /// a handle to the spawned task. pub(crate) fn streaming_read( mut self, cln_token: CancellationToken, bypass_router: Option, - ) -> Result<(ReceiverStream, JoinHandle>)> { + ) -> Result<(ReceiverStream, JoinHandle>)> { let (messages_tx, messages_rx) = mpsc::channel(2 * self.read_batch_size); let mut pipeline_labels = pipeline_metric_labels(VERTEX_TYPE_SOURCE).clone(); @@ -500,7 +499,7 @@ impl Source { } let read_start_time = Instant::now(); - let mut messages = match Self::read(self.sender.clone()).await { + let messages = match Self::read(self.sender.clone()).await { Some(Ok(messages)) => messages, None => { info!("Source returned None (end of stream). Stopping the source."); @@ -517,9 +516,10 @@ impl Source { let read_time = read_start_time.elapsed().as_micros() as f64; Self::record_batch_read_metrics(&pipeline_labels, mvtx_labels, read_time, msgs_len); - let mut ack_handles = vec![]; + let mut msg_handles = vec![]; let mut ack_batch = Vec::with_capacity(msgs_len); - for message in messages.iter_mut() { + + for message in messages.iter() { Self::record_partition_read_metrics( &pipeline_labels, mvtx_labels, @@ -527,15 +527,13 @@ impl Source { message.value.len(), ); - let (resp_ack_tx, resp_ack_rx) = oneshot::channel(); - message.ack_handle = Some(Arc::new(AckHandle::new(resp_ack_tx))); - - // insert the offset and the ack one shot in the tracker. + // insert the message (with offset) into the tracker. self.tracker.insert(message).await?; - // store the ack one shot in the batch to invoke ack later. - ack_batch.push((message.offset.clone(), resp_ack_rx)); - ack_handles.push(message.ack_handle.clone()); + let (ack_tx, ack_rx) = oneshot::channel(); + // store the ack receiver in the batch to invoke ack later. + ack_batch.push((message.offset.clone(), ack_rx)); + msg_handles.push(MessageHandle::new(message.clone(), ack_tx)); } // start a background task to invoke ack on the source for the offsets that are acked. @@ -569,25 +567,21 @@ impl Source { // transform the batch if the transformer is present, this need not // be streaming because transformation should be fast operation. - let mut messages = match self.transformer.as_mut() { - None => messages, + // transform_batch accepts MessageHandles and returns MessageHandles with ack + // tracking preserved — flatmap outputs share the parent's ack handle. + let mut msg_handles = match self.transformer.as_mut() { + None => msg_handles, Some(transformer) => match transformer - .transform_batch(messages, cln_token.clone()) + .transform_batch(msg_handles, cln_token.clone()) .await { - Ok(messages) => messages, + Ok(handles) => handles, Err(e) => { error!( ?e, "Error while transforming messages, sending nack to the batch" ); - for ack_handle in ack_handles { - ack_handle - .as_ref() - .expect("ack handle should be present") - .is_failed - .store(true, Ordering::Relaxed); - } + // handles dropped without mark_success, causing NAK result = Err(e); break; } @@ -595,36 +589,45 @@ impl Source { }; if let Some(watermark_handle) = self.watermark_handle.as_mut() { + let entries: Vec = + msg_handles.iter().map(SourceWatermarkEntry::from).collect(); watermark_handle - .generate_and_publish_source_watermark(&messages) + .generate_and_publish_source_watermark(&entries) .await; let watermark = watermark_handle.fetch_source_watermark().await; - // compare with the event time of the message and set is_late - for message in messages.iter_mut() { - message.is_late = message.event_time < watermark; + // set is_late on messages that arrived after the watermark + for msg_handle in msg_handles.iter_mut() { + if msg_handle.message().event_time < watermark { + msg_handle.message_mut().is_late = true; + } } } - // write the messages to downstream. - for message in messages { - let bypassed = if let Some(ref bypass_router) = bypass_router { - bypass_router - .try_bypass(message.clone()) + // write the messages to downstream as MessageHandles. + for read_message in msg_handles.into_iter() { + // Try to bypass the message. If bypassed, try_bypass takes ownership and returns None. + // If not bypassed, it returns Some(read_message) for us to send downstream. + let read_message = if let Some(ref bypass_router) = bypass_router { + match bypass_router + .try_bypass(read_message) .await .expect("failed to send message to bypass channel") + { + Some(msg) => msg, + None => continue, // Message was bypassed, move to next + } } else { - false + read_message }; - if !bypassed { - messages_tx - .send(message) - .await - .expect("send should not fail"); - } + messages_tx + .send(read_message) + .await + .expect("send should not fail"); } } + info!(status=?result, "Source stopped, waiting for inflight messages to be acked/nacked"); // wait for all the ack tasks to be completed before stopping the source, since we give // a permit for each ack task all the permits should be released when the ack tasks are @@ -897,6 +900,7 @@ impl Source { #[cfg(test)] mod tests { + use crate::mark_success; use crate::shared::grpc::create_rpc_channel; use crate::source::user_defined::new_source; use crate::source::{Source, SourceType}; @@ -1076,12 +1080,14 @@ mod tests { // we should read all the 100 messages for i in 0..100 { let message = stream.next().await.unwrap(); - assert_eq!(message.value, "hello".as_bytes()); - offsets.push(message.offset.clone()); + assert_eq!(message.message.value, "hello".as_bytes()); + offsets.push(message.message.offset.clone()); - // only store the last 50 messages, rest will be dropped and acknowledged. + // store last 50 messages; ACK the first 50 explicitly. if i >= 50 { messages.push(message); + } else { + mark_success!(message); } } @@ -1105,19 +1111,15 @@ mod tests { .unwrap(); assert_eq!(source_partitions.active_partitions, vec![1, 2]); - for message in messages.into_iter() { - // set failed to true so that the message is nacked - message - .ack_handle - .unwrap() - .is_failed - .store(true, Ordering::Relaxed); - } + // Drop messages without calling mark_success() to cause NAK + drop(messages); // read should return 50 nacked messages for _ in 0..50 { let message = stream.next().await.unwrap(); - assert_eq!(message.value, "hello".as_bytes()); + assert_eq!(message.message.value, "hello".as_bytes()); + // Mark as success so they get ACK'd (pending goes to 0) + mark_success!(message); } // pending should be 0 now diff --git a/rust/numaflow-core/src/source/generator.rs b/rust/numaflow-core/src/source/generator.rs index be17555956..870021d000 100644 --- a/rust/numaflow-core/src/source/generator.rs +++ b/rust/numaflow-core/src/source/generator.rs @@ -198,7 +198,6 @@ mod stream_generator { // Set default metadata so that metadata is always present. metadata: Some(Arc::new(crate::metadata::Metadata::default())), is_late: false, - ack_handle: None, } } diff --git a/rust/numaflow-core/src/source/http.rs b/rust/numaflow-core/src/source/http.rs index ceb5555ba8..05509364b9 100644 --- a/rust/numaflow-core/src/source/http.rs +++ b/rust/numaflow-core/src/source/http.rs @@ -41,7 +41,6 @@ impl From for Message { // Set default metadata so that metadata is always present. metadata: Some(Arc::new(Metadata::default())), is_late: false, - ack_handle: None, } } } diff --git a/rust/numaflow-core/src/source/jetstream.rs b/rust/numaflow-core/src/source/jetstream.rs index 9757f13dbc..7378746df0 100644 --- a/rust/numaflow-core/src/source/jetstream.rs +++ b/rust/numaflow-core/src/source/jetstream.rs @@ -37,7 +37,6 @@ impl From for Message { // Set default metadata so that metadata is always present. metadata: Some(Arc::new(Metadata::default())), is_late: false, - ack_handle: None, } } } diff --git a/rust/numaflow-core/src/source/kafka.rs b/rust/numaflow-core/src/source/kafka.rs index 2109365bd2..8418a59b03 100644 --- a/rust/numaflow-core/src/source/kafka.rs +++ b/rust/numaflow-core/src/source/kafka.rs @@ -53,7 +53,6 @@ impl TryFrom for Message { // Set default metadata so that metadata is always present. metadata: Some(Arc::new(Metadata::default())), is_late: false, - ack_handle: None, }) } } diff --git a/rust/numaflow-core/src/source/nats.rs b/rust/numaflow-core/src/source/nats.rs index ba45fc7346..068b2d3922 100644 --- a/rust/numaflow-core/src/source/nats.rs +++ b/rust/numaflow-core/src/source/nats.rs @@ -39,7 +39,6 @@ impl From for Message { // Set default metadata so that metadata is always present. metadata: Some(Arc::new(Metadata::default())), is_late: false, - ack_handle: None, } } } diff --git a/rust/numaflow-core/src/source/pulsar.rs b/rust/numaflow-core/src/source/pulsar.rs index aea8a707b4..fec0efd8ae 100644 --- a/rust/numaflow-core/src/source/pulsar.rs +++ b/rust/numaflow-core/src/source/pulsar.rs @@ -32,7 +32,6 @@ impl TryFrom for Message { // Set default metadata so that metadata is always present. metadata: Some(Arc::new(Metadata::default())), is_late: false, - ack_handle: None, }) } } diff --git a/rust/numaflow-core/src/source/sqs.rs b/rust/numaflow-core/src/source/sqs.rs index c723941d47..9e91b19e80 100644 --- a/rust/numaflow-core/src/source/sqs.rs +++ b/rust/numaflow-core/src/source/sqs.rs @@ -49,7 +49,6 @@ impl TryFrom for Message { headers: Arc::new(message.system_attributes), metadata, is_late: false, - ack_handle: None, }) } } diff --git a/rust/numaflow-core/src/source/user_defined.rs b/rust/numaflow-core/src/source/user_defined.rs index b5d7de4a37..22518ad4d9 100644 --- a/rust/numaflow-core/src/source/user_defined.rs +++ b/rust/numaflow-core/src/source/user_defined.rs @@ -179,7 +179,6 @@ impl TryFrom for Message { None => Metadata::default(), })), is_late: false, - ack_handle: None, }) } } diff --git a/rust/numaflow-core/src/tracker.rs b/rust/numaflow-core/src/tracker.rs index db30612be9..893b3e1b83 100644 --- a/rust/numaflow-core/src/tracker.rs +++ b/rust/numaflow-core/src/tracker.rs @@ -493,7 +493,6 @@ mod tests { ..Default::default() })), is_late: false, - ..Default::default() }; // Insert a new message diff --git a/rust/numaflow-core/src/transformer.rs b/rust/numaflow-core/src/transformer.rs index a3af1efb38..0128920cf4 100644 --- a/rust/numaflow-core/src/transformer.rs +++ b/rust/numaflow-core/src/transformer.rs @@ -9,17 +9,17 @@ use tonic::transport::Channel; use tonic::{Code, Status}; use tracing::error; -use crate::Result; use crate::config::pipeline::VERTEX_TYPE_SOURCE; use crate::config::{get_vertex_name, is_mono_vertex}; use crate::error::Error; -use crate::message::Message; +use crate::message::{Message, MessageHandle}; use crate::metrics::{ PIPELINE_PARTITION_NAME_LABEL, monovertex_metrics, mvtx_forward_metric_labels, pipeline_metric_labels, pipeline_metrics, }; use crate::tracker::Tracker; use crate::transformer::user_defined::UserDefinedTransformer; +use crate::{Result, mark_success}; /// User-Defined Transformer is a custom transformer that can be built by the user. /// @@ -155,11 +155,13 @@ impl Transformer { } /// Transforms a batch of messages concurrently. + /// Accepts MessageHandles so that ack tracking flows through to the transformed outputs — + /// each output message shares the ack handle of its parent input (flatmap is handled correctly). pub(crate) async fn transform_batch( &self, - messages: Vec, + msg_handles: Vec, cln_token: CancellationToken, - ) -> Result> { + ) -> Result> { let batch_start_time = tokio::time::Instant::now(); let transform_handle = self.sender.clone(); let tracker = self.tracker.clone(); @@ -195,20 +197,24 @@ impl Transformer { .source_forwarder .transformer_read_total .get_or_create(&labels) - .inc_by(messages.len() as u64); + .inc_by(msg_handles.len() as u64); } - let message_count = messages.len(); + let message_count = msg_handles.len(); - let transform_futs = messages.into_iter().map(|read_msg| { + let transform_futs = msg_handles.into_iter().map(|msg_handle| { let transform_handle = transform_handle.clone(); let tracker = tracker.clone(); let hard_shutdown_token = hard_shutdown_token.clone(); async move { - let offset = read_msg.offset.clone(); - let transformed_messages = - Transformer::transform(transform_handle, read_msg, hard_shutdown_token).await?; + let offset = msg_handle.message().offset.clone(); + let transformed_messages = Transformer::transform( + transform_handle, + msg_handle.message().clone(), + hard_shutdown_token, + ) + .await?; // update the tracker with the number of responses for each message tracker @@ -221,7 +227,15 @@ impl Transformer { ) .await?; - Ok::, Error>(transformed_messages) + // Fan out: each transformed message shares the parent's ack handle. + // mark_success on the parent decrements its ref_count contribution. + let output: Vec = transformed_messages + .into_iter() + .map(|m| msg_handle.with_message(m)) + .collect(); + + mark_success!(msg_handle); + Ok::, Error>(output) } }); @@ -229,11 +243,11 @@ impl Transformer { // This polls up to `concurrency` futures at a time, reducing scheduling overhead. let mut stream = stream::iter(transform_futs).buffered(self.concurrency); - let mut transformed_messages = Vec::with_capacity(message_count * 2); + let mut transformed_handles = Vec::with_capacity(message_count * 2); while let Some(result) = stream.next().await { match result { - Ok(mut msgs) => transformed_messages.append(&mut msgs), + Ok(mut handles) => transformed_handles.append(&mut handles), Err(e) => { // increment transform error metric for pipeline // error here indicates that there was some problem in transformation @@ -252,12 +266,12 @@ impl Transformer { // batch transformation was successful // send transformer metrics - let dropped_messages_count = transformed_messages + let dropped_messages_count = transformed_handles .iter() - .filter(|message| message.dropped()) + .filter(|h| h.message().dropped()) .count(); let elapsed_time = batch_start_time.elapsed().as_micros() as f64; - let write_messages_count = transformed_messages.len() - dropped_messages_count; + let write_messages_count = transformed_handles.len() - dropped_messages_count; Self::send_transformer_metrics( dropped_messages_count, elapsed_time, @@ -267,7 +281,7 @@ impl Transformer { // cleanup the shutdown handle shutdown_handle.abort(); - Ok(transformed_messages) + Ok(transformed_handles) } fn send_transformer_metrics( @@ -335,7 +349,7 @@ mod tests { use super::*; use crate::message::StringOffset; - use crate::message::{Message, MessageID, Offset}; + use crate::message::{Message, MessageHandle, MessageID, Offset}; use crate::shared::grpc::create_rpc_channel; struct SimpleTransformer; @@ -470,7 +484,8 @@ mod tests { }, ..Default::default() }; - messages.push(message); + let (ack_tx, _ack_rx) = tokio::sync::oneshot::channel(); + messages.push(MessageHandle::new(message, ack_tx)); } let transformed_messages = transformer @@ -478,7 +493,7 @@ mod tests { .await?; for (i, transformed_message) in transformed_messages.iter().enumerate() { - assert_eq!(transformed_message.value, format!("value_{}", i)); + assert_eq!(transformed_message.message().value, format!("value_{}", i)); } // we need to drop the transformer, because if there are any in-flight requests @@ -550,8 +565,12 @@ mod tests { ..Default::default() }; + let (ack_tx, _ack_rx) = tokio::sync::oneshot::channel(); let result = transformer - .transform_batch(vec![message], CancellationToken::new()) + .transform_batch( + vec![MessageHandle::new(message, ack_tx)], + CancellationToken::new(), + ) .await; assert!(result.is_err(), "Expected an error due to panic"); assert!(result.unwrap_err().to_string().contains("panic")); diff --git a/rust/numaflow-core/src/transformer/user_defined.rs b/rust/numaflow-core/src/transformer/user_defined.rs index c8b4036f49..361956596b 100644 --- a/rust/numaflow-core/src/transformer/user_defined.rs +++ b/rust/numaflow-core/src/transformer/user_defined.rs @@ -12,7 +12,7 @@ use tonic::{Request, Streaming}; use crate::config::get_vertex_name; use crate::error::{Error, Result}; -use crate::message::{AckHandle, Message, MessageID, Offset}; +use crate::message::{Message, MessageID, Offset}; use crate::metadata::Metadata; use crate::shared::grpc::{prost_timestamp_from_utc, utc_from_timestamp}; @@ -25,7 +25,6 @@ struct ParentMessageInfo { is_late: bool, headers: Arc>, metadata: Option>, - ack_handle: Option>, } // we are passing the reference for msg info because we can have more than 1 response for a single request and @@ -70,7 +69,6 @@ impl From> for Message { Some(Arc::new(metadata)) }, is_late: value.1.is_late, - ack_handle: value.1.ack_handle.clone(), } } } @@ -206,7 +204,6 @@ impl UserDefinedTransformer { headers: Arc::clone(&message.headers), is_late: message.is_late, metadata: message.metadata.clone(), - ack_handle: message.ack_handle.clone(), }; self.senders diff --git a/rust/numaflow-core/src/watermark/source.rs b/rust/numaflow-core/src/watermark/source.rs index 8e57dee2e6..b675054f6b 100644 --- a/rust/numaflow-core/src/watermark/source.rs +++ b/rust/numaflow-core/src/watermark/source.rs @@ -28,7 +28,7 @@ use crate::config::pipeline::isb::Stream; use crate::config::pipeline::watermark::SourceWatermarkConfig; use crate::config::pipeline::{ToVertexConfig, VertexType}; use crate::error::Result; -use crate::message::{IntOffset, Message, Offset}; +use crate::message::{IntOffset, MessageHandle, Offset}; use crate::watermark::idle::isb::ISBIdleDetector; use crate::watermark::idle::source::SourceIdleDetector; use crate::watermark::processor::manager::ProcessorManager; @@ -42,6 +42,26 @@ pub(crate) mod source_wm_fetcher; /// publisher for publishing the source watermark pub(crate) mod source_wm_publisher; +/// The minimal information needed to compute and publish source watermarks. +pub(crate) struct SourceWatermarkEntry { + pub(crate) partition_id: u16, + pub(crate) event_time_ms: i64, +} + +impl From<&MessageHandle> for SourceWatermarkEntry { + fn from(handle: &MessageHandle) -> Self { + let msg = handle.message(); + let partition_id = match &msg.offset { + Offset::Int(o) => o.partition_idx, + Offset::String(o) => o.partition_idx, + }; + Self { + partition_id, + event_time_ms: msg.event_time.timestamp_millis(), + } + } +} + /// Shared state for SourceWatermarkHandle. /// Contains all computation logic and data structures. struct SourceWatermarkState { @@ -76,21 +96,15 @@ impl SourceWatermarkState { /// Handles generating and publishing source watermark with computation async fn generate_and_publish_source_watermark( &mut self, - messages: Vec, + entries: &[SourceWatermarkEntry], ) -> Result<()> { // we need to build a hash-map of the lowest event time for each partition let partition_to_lowest_event_time = - messages.iter().fold(HashMap::new(), |mut acc, message| { - let partition_id = match &message.offset { - Offset::Int(offset) => offset.partition_idx, - Offset::String(offset) => offset.partition_idx, - }; - - let event_time = message.event_time.timestamp_millis(); - - let lowest_event_time = acc.entry(partition_id).or_insert(event_time); - if event_time < *lowest_event_time { - *lowest_event_time = event_time; + entries.iter().fold(HashMap::new(), |mut acc, entry| { + let lowest_event_time = + acc.entry(entry.partition_id).or_insert(entry.event_time_ms); + if entry.event_time_ms < *lowest_event_time { + *lowest_event_time = entry.event_time_ms; } acc }); @@ -313,14 +327,15 @@ impl SourceWatermarkHandle { Ok(source_watermark_handle) } - /// Generates and Publishes the source watermark for the given messages. - pub(crate) async fn generate_and_publish_source_watermark(&self, messages: &[Message]) { + /// Generates and Publishes the source watermark for the given entries. + pub(crate) async fn generate_and_publish_source_watermark( + &self, + entries: &[SourceWatermarkEntry], + ) { // Acquire lock, perform operation, and release immediately let result = { let mut state = self.state.lock().await; - state - .generate_and_publish_source_watermark(messages.to_vec()) - .await + state.generate_and_publish_source_watermark(entries).await }; if let Err(e) = result { @@ -427,14 +442,13 @@ mod tests { use async_nats::jetstream::kv::Config; use async_nats::jetstream::stream; use bytes::BytesMut; - use chrono::DateTime; use tokio::time::sleep; use super::*; use crate::config::pipeline::VertexType; use crate::config::pipeline::isb::BufferWriterConfig; use crate::config::pipeline::watermark::{BucketConfig, IdleConfig}; - use crate::message::{IntOffset, Message}; + use crate::message::IntOffset; use crate::watermark::wmb::WMB; #[cfg(feature = "nats-tests")] @@ -476,27 +490,17 @@ mod tests { .await .expect("Failed to create source watermark handle"); - let messages = vec![ - Message { - offset: Offset::Int(IntOffset { - offset: 1, - partition_idx: 0, - }), - event_time: DateTime::from_timestamp_millis(60000).unwrap(), - ..Default::default() - }, - Message { - offset: Offset::Int(IntOffset { - offset: 2, - partition_idx: 0, - }), - event_time: DateTime::from_timestamp_millis(70000).unwrap(), - ..Default::default() - }, - ]; - handle - .generate_and_publish_source_watermark(&messages) + .generate_and_publish_source_watermark(&[ + SourceWatermarkEntry { + partition_id: 0, + event_time_ms: 60000, + }, + SourceWatermarkEntry { + partition_id: 0, + event_time_ms: 70000, + }, + ]) .await; // try getting the value for the processor from the ot bucket to make sure @@ -626,27 +630,17 @@ mod tests { for i in 1..11 { // publish source watermarks before publishing edge watermarks - let messages = vec![ - Message { - offset: Offset::Int(IntOffset { - offset: 1, - partition_idx: 0, - }), - event_time: DateTime::from_timestamp_millis(10000 * i).unwrap(), - ..Default::default() - }, - Message { - offset: Offset::Int(IntOffset { - offset: 2, - partition_idx: 0, - }), - event_time: DateTime::from_timestamp_millis(20000 * i).unwrap(), - ..Default::default() - }, - ]; - handle - .generate_and_publish_source_watermark(&messages) + .generate_and_publish_source_watermark(&[ + SourceWatermarkEntry { + partition_id: 0, + event_time_ms: 10000 * i, + }, + SourceWatermarkEntry { + partition_id: 0, + event_time_ms: 20000 * i, + }, + ]) .await; let offset = Offset::Int(IntOffset { @@ -932,18 +926,12 @@ mod tests { .await .expect("Failed to create SourceWatermarkHandle"); - let messages = vec![Message { - offset: Offset::Int(IntOffset { - offset: 1, - partition_idx: 0, - }), - event_time: DateTime::from_timestamp_millis(100).unwrap(), - ..Default::default() - }]; - // generate some watermarks to make partition active handle - .generate_and_publish_source_watermark(&messages) + .generate_and_publish_source_watermark(&[SourceWatermarkEntry { + partition_id: 0, + event_time_ms: 100, + }]) .await; // get ot bucket for source and publish some wmb entries