From d27a7728003fd0fff01b10ddf4d0810d8ffe4141 Mon Sep 17 00:00:00 2001 From: Yuval Mandelboum Date: Mon, 16 Mar 2026 11:47:22 -0700 Subject: [PATCH] feat(consumer): Add commit log heartbeat for idle partitions When the events topic has low per-partition throughput, some partitions may have no messages in a given batch window. Since Snuba only produces commit log entries for partitions with data, idle partitions get no entry. This causes the post-process-forwarder (SynchronizedConsumer) to pause those partitions and wait, increasing end-to-end latency with random spikes. On each batch flush, now produce commit log entries for ALL assigned partitions. For idle partitions, re-emit the last known committed offset as a heartbeat. This caps the maximum stall to one batch window (~500ms) instead of being unbounded. Changes: - Python: Add CommitLogHeartbeatState shared across batch writers - Python: ProcessedMessageBatchWriter and MultistorageCollector emit heartbeats for idle partitions in close() - Python: Wire heartbeat state through strategy factory and consumer builder - Rust: ProduceMessage tracks assigned partitions and last offsets, emits heartbeats for idle partitions - Tests: 8 new Python tests + 1 Rust test for heartbeat behavior Co-Authored-By: Claude Opus 4.6 (1M context) --- rust_snuba/src/consumer.rs | 3 +- rust_snuba/src/factory_v2.rs | 9 +- rust_snuba/src/strategies/commit_log.rs | 200 +++++++++- snuba/consumers/consumer.py | 138 +++++-- snuba/consumers/consumer_builder.py | 5 + snuba/consumers/strategy_factory.py | 11 +- tests/test_consumer.py | 473 ++++++++++++++++++++++++ 7 files changed, 800 insertions(+), 39 deletions(-) diff --git a/rust_snuba/src/consumer.rs b/rust_snuba/src/consumer.rs index 6a0c7eaf4d2..6591155cd3b 100644 --- a/rust_snuba/src/consumer.rs +++ b/rust_snuba/src/consumer.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::Duration; use chrono::{DateTime, Utc}; @@ -270,6 +270,7 @@ pub fn consumer_impl( join_timeout_ms, health_check: health_check.to_string(), use_row_binary, + assigned_partitions: Arc::new(Mutex::new(Vec::new())), }; let processor = StreamProcessor::with_kafka(config, factory, topic, dlq_policy); diff --git a/rust_snuba/src/factory_v2.rs b/rust_snuba/src/factory_v2.rs index 76404072a24..50c25d3d7f6 100644 --- a/rust_snuba/src/factory_v2.rs +++ b/rust_snuba/src/factory_v2.rs @@ -1,5 +1,5 @@ use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::Duration; use sentry::{Hub, SentryFutureExt}; @@ -62,6 +62,7 @@ pub struct ConsumerStrategyFactoryV2 { pub join_timeout_ms: Option, pub health_check: String, pub use_row_binary: bool, + pub assigned_partitions: Arc>>, } impl ProcessingStrategyFactory for ConsumerStrategyFactoryV2 { @@ -82,6 +83,8 @@ impl ProcessingStrategyFactory for ConsumerStrategyFactoryV2 { Some(min) => set_global_tag("min_partition".to_owned(), min.to_string()), None => set_global_tag("min_partition".to_owned(), "none".to_owned()), } + + *self.assigned_partitions.lock().unwrap() = assigned_partitions; } fn create(&self) -> Box> { @@ -106,6 +109,7 @@ impl ProcessingStrategyFactory for ConsumerStrategyFactoryV2 { let next_step: Box>> = if let Some((ref producer, destination)) = self.commit_log_producer { + let partitions = self.assigned_partitions.lock().unwrap().clone(); Box::new(ProduceCommitLog::new( next_step, producer.clone(), @@ -114,6 +118,7 @@ impl ProcessingStrategyFactory for ConsumerStrategyFactoryV2 { self.physical_consumer_group.clone(), &self.commitlog_concurrency, false, + partitions, )) } else { Box::new(next_step) @@ -290,6 +295,7 @@ impl ConsumerStrategyFactoryV2 { let next_step: Box>> = if let Some((ref producer, destination)) = self.commit_log_producer { + let partitions = self.assigned_partitions.lock().unwrap().clone(); Box::new(ProduceCommitLog::new( next_step, producer.clone(), @@ -298,6 +304,7 @@ impl ConsumerStrategyFactoryV2 { self.physical_consumer_group.clone(), &self.commitlog_concurrency, false, + partitions, )) } else { Box::new(next_step) diff --git a/rust_snuba/src/strategies/commit_log.rs b/rust_snuba/src/strategies/commit_log.rs index cd6562e451c..9ec81d12fd6 100644 --- a/rust_snuba/src/strategies/commit_log.rs +++ b/rust_snuba/src/strategies/commit_log.rs @@ -10,8 +10,9 @@ use sentry_arroyo::processing::strategies::{ }; use sentry_arroyo::types::{Message, Topic, TopicOrPartition}; use serde::{Deserialize, Serialize}; +use std::collections::BTreeMap; use std::str; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::Duration; use thiserror::Error; @@ -72,6 +73,8 @@ struct ProduceMessage { topic: Topic, consumer_group: String, skip_produce: bool, + assigned_partitions: Vec, + last_produced_offsets: Arc)>>>, } impl ProduceMessage { @@ -81,6 +84,7 @@ impl ProduceMessage { topic: Topic, consumer_group: String, skip_produce: bool, + assigned_partitions: Vec, ) -> Self { ProduceMessage { producer, @@ -88,6 +92,8 @@ impl ProduceMessage { topic, consumer_group, skip_produce, + assigned_partitions, + last_produced_offsets: Arc::new(Mutex::new(BTreeMap::new())), } } } @@ -102,6 +108,8 @@ impl TaskRunner, BytesInsertBatch<()>, anyhow::Error> for P let topic = self.topic; let skip_produce = self.skip_produce; let consumer_group = self.consumer_group.clone(); + let assigned_partitions = self.assigned_partitions.clone(); + let last_produced_offsets = self.last_produced_offsets.clone(); let commit_log_offsets = message.payload().commit_log_offsets().clone(); @@ -110,6 +118,9 @@ impl TaskRunner, BytesInsertBatch<()>, anyhow::Error> for P return Ok(message); } + let partitions_in_batch: std::collections::BTreeSet = + commit_log_offsets.0.keys().copied().collect(); + for (partition, mut entry) in commit_log_offsets.0 { entry.received_p99.sort(); let received_p99 = entry @@ -132,6 +143,37 @@ impl TaskRunner, BytesInsertBatch<()>, anyhow::Error> for P tracing::error!(error, "Error producing message"); return Err(RunTaskError::RetryableError); } + + // Update last produced offset for this partition + last_produced_offsets + .lock() + .unwrap() + .insert(partition, (entry.offset, entry.orig_message_ts)); + } + + // Produce heartbeat entries for idle assigned partitions + let offsets = last_produced_offsets.lock().unwrap().clone(); + for &partition in &assigned_partitions { + if !partitions_in_batch.contains(&partition) { + if let Some(&(offset, orig_message_ts)) = offsets.get(&partition) { + let commit = Commit { + topic: topic.to_string(), + partition, + group: consumer_group.clone(), + orig_message_ts, + offset, + received_p99: None, + }; + + let payload = commit.try_into().unwrap(); + + if let Err(err) = producer.produce(&destination, payload) { + let error: &dyn std::error::Error = &err; + tracing::error!(error, "Error producing heartbeat message"); + return Err(RunTaskError::RetryableError); + } + } + } } Ok(message) @@ -155,10 +197,18 @@ where consumer_group: String, concurrency: &ConcurrencyConfig, skip_produce: bool, + assigned_partitions: Vec, ) -> Self { let inner = RunTaskInThreads::new( next_step, - ProduceMessage::new(producer, destination, topic, consumer_group, skip_produce), + ProduceMessage::new( + producer, + destination, + topic, + consumer_group, + skip_produce, + assigned_partitions, + ), concurrency, Some("produce_commit_log"), ); @@ -327,6 +377,7 @@ mod tests { "group1".to_string(), &concurrency, false, + vec![0, 1], ); for payload in payloads { @@ -345,4 +396,149 @@ mod tests { assert_eq!(produced[1].0, "test:0:group1"); assert_eq!(produced[2].0, "test:1:group1"); } + + #[test] + fn produce_commit_log_heartbeat_for_idle_partitions() { + // Assigned partitions 0-3, but only partitions 0 and 1 have data. + // After the first batch establishes offsets for partitions 2 and 3, + // a second batch with only partitions 0 and 1 should still produce + // heartbeat entries for partitions 2 and 3. + let produced_payloads = Arc::new(Mutex::new(Vec::new())); + + struct MockProducer { + pub payloads: Arc>>, + } + + impl Producer for MockProducer { + fn produce( + &self, + topic: &TopicOrPartition, + payload: KafkaPayload, + ) -> Result<(), ProducerError> { + assert_eq!(topic.topic().as_str(), "test-commitlog"); + self.payloads.lock().unwrap().push(( + str::from_utf8(payload.key().unwrap()).unwrap().to_owned(), + payload, + )); + Ok(()) + } + } + + // Batch 1: all 4 partitions have data (establishes last known offsets) + let batch1 = BytesInsertBatch::from_rows(()) + .with_message_timestamp(Utc::now()) + .with_commit_log_offsets(CommitLogOffsets(BTreeMap::from([ + ( + 0, + CommitLogEntry { + offset: 100, + orig_message_ts: Utc::now(), + received_p99: Vec::new(), + }, + ), + ( + 1, + CommitLogEntry { + offset: 200, + orig_message_ts: Utc::now(), + received_p99: Vec::new(), + }, + ), + ( + 2, + CommitLogEntry { + offset: 300, + orig_message_ts: Utc::now(), + received_p99: Vec::new(), + }, + ), + ( + 3, + CommitLogEntry { + offset: 400, + orig_message_ts: Utc::now(), + received_p99: Vec::new(), + }, + ), + ]))); + + // Batch 2: only partitions 0 and 1 have data + let batch2 = BytesInsertBatch::from_rows(()) + .with_message_timestamp(Utc::now()) + .with_commit_log_offsets(CommitLogOffsets(BTreeMap::from([ + ( + 0, + CommitLogEntry { + offset: 500, + orig_message_ts: Utc::now(), + received_p99: Vec::new(), + }, + ), + ( + 1, + CommitLogEntry { + offset: 600, + orig_message_ts: Utc::now(), + received_p99: Vec::new(), + }, + ), + ]))); + + let producer = MockProducer { + payloads: produced_payloads.clone(), + }; + + let next_step = TestStrategy::new(); + + let concurrency = ConcurrencyConfig::new(1); + let mut strategy = ProduceCommitLog::new( + next_step, + Arc::new(producer), + Topic::new("test-commitlog"), + Topic::new("test"), + "group1".to_string(), + &concurrency, + false, + vec![0, 1, 2, 3], // assigned partitions 0-3 + ); + + // Submit and process batch 1 + strategy + .submit(Message::new_any_message(batch1, BTreeMap::new())) + .unwrap(); + strategy.poll().unwrap(); + + // Submit and process batch 2 + strategy + .submit(Message::new_any_message(batch2, BTreeMap::new())) + .unwrap(); + strategy.poll().unwrap(); + + strategy.join(None).unwrap(); + + let produced = produced_payloads.lock().unwrap(); + + // Batch 1: 4 entries (all partitions have data, no heartbeats needed) + // Batch 2: 2 entries (partitions 0,1 data) + 2 heartbeats (partitions 2,3) + assert_eq!(produced.len(), 8); + + // Batch 1 entries + assert_eq!(produced[0].0, "test:0:group1"); + assert_eq!(produced[1].0, "test:1:group1"); + assert_eq!(produced[2].0, "test:2:group1"); + assert_eq!(produced[3].0, "test:3:group1"); + + // Batch 2: partitions 0 and 1 (data), then heartbeats for 2 and 3 + assert_eq!(produced[4].0, "test:0:group1"); + assert_eq!(produced[5].0, "test:1:group1"); + assert_eq!(produced[6].0, "test:2:group1"); + assert_eq!(produced[7].0, "test:3:group1"); + + // Verify heartbeat offsets match the last known offsets from batch 1 + let heartbeat_2: Commit = produced[6].1.clone().try_into().unwrap(); + assert_eq!(heartbeat_2.offset, 300); + + let heartbeat_3: Commit = produced[7].1.clone().try_into().unwrap(); + assert_eq!(heartbeat_3.offset, 400); + } } diff --git a/snuba/consumers/consumer.py b/snuba/consumers/consumer.py index 875efc5128f..c4da0d838ec 100644 --- a/snuba/consumers/consumer.py +++ b/snuba/consumers/consumer.py @@ -62,6 +62,28 @@ class CommitLogConfig(NamedTuple): group_id: str +class CommitLogHeartbeatState: + """ + Shared mutable state that persists across batch writer instances to enable + commit log heartbeats for idle partitions. + + When a partition has no messages in a batch window, the post-process-forwarder + stalls waiting for a commit log entry. This state tracks last known offsets so + we can re-emit them as heartbeats for idle partitions. + """ + + def __init__(self) -> None: + self.assigned_partitions: Set[Partition] = set() + self.last_produced_offsets: MutableMapping[Partition, Tuple[int, datetime]] = {} + + def update_partitions(self, partitions: Mapping[Partition, int]) -> None: + self.assigned_partitions = set(partitions.keys()) + # Clean up offsets for partitions we no longer own + for p in list(self.last_produced_offsets.keys()): + if p not in self.assigned_partitions: + del self.last_produced_offsets[p] + + class BytesInsertBatch(NamedTuple): rows: Sequence[bytes] @@ -239,10 +261,12 @@ def __init__( # upon closing each batch. commit_log_config: Optional[CommitLogConfig] = None, metrics: Optional[MetricsBackend] = None, + heartbeat_state: Optional[CommitLogHeartbeatState] = None, ) -> None: self.__insert_batch_writer = insert_batch_writer self.__replacement_batch_writer = replacement_batch_writer self.__commit_log_config = commit_log_config + self.__heartbeat_state = heartbeat_state self.__offsets_to_produce: MutableMapping[Partition, Tuple[int, datetime]] = {} self.__received_timestamps: MutableMapping[Partition, List[float]] = defaultdict(list) @@ -282,6 +306,32 @@ def __commit_message_delivery_callback( if error is not None: raise Exception(error.str()) + def __produce_commit_log_entry( + self, + partition: Partition, + offset: int, + timestamp: datetime, + received_p99: Optional[float], + ) -> None: + assert self.__commit_log_config is not None + payload = commit_codec.encode( + CommitLogCommit( + self.__commit_log_config.group_id, + partition, + offset, + datetime.timestamp(timestamp), + received_p99, + ) + ) + self.__commit_log_config.producer.produce( + self.__commit_log_config.topic.name, + key=payload.key, + value=payload.value, + headers=payload.headers, + on_delivery=self.__commit_message_delivery_callback, + ) + self.__commit_log_config.producer.poll(0.0) + def close(self) -> None: self.__closed = True @@ -302,23 +352,20 @@ def close(self) -> None: else: received_p99 = None - payload = commit_codec.encode( - CommitLogCommit( - self.__commit_log_config.group_id, - partition, - offset, - datetime.timestamp(timestamp), - received_p99, - ) - ) - self.__commit_log_config.producer.produce( - self.__commit_log_config.topic.name, - key=payload.key, - value=payload.value, - headers=payload.headers, - on_delivery=self.__commit_message_delivery_callback, - ) - self.__commit_log_config.producer.poll(0.0) + self.__produce_commit_log_entry(partition, offset, timestamp, received_p99) + + # Produce heartbeat entries for idle assigned partitions + if self.__heartbeat_state is not None: + # Update shared state with offsets from this batch + self.__heartbeat_state.last_produced_offsets.update(self.__offsets_to_produce) + # Emit heartbeats for assigned partitions that had no messages + for partition in self.__heartbeat_state.assigned_partitions: + if partition not in self.__offsets_to_produce: + last = self.__heartbeat_state.last_produced_offsets.get(partition) + if last is not None: + offset, timestamp = last + self.__produce_commit_log_entry(partition, offset, timestamp, None) + self.__offsets_to_produce.clear() self.__received_timestamps.clear() @@ -345,6 +392,7 @@ def build_batch_writer( replacements_topic: Optional[Topic] = None, commit_log_config: Optional[CommitLogConfig] = None, slice_id: Optional[int] = None, + heartbeat_state: Optional[CommitLogHeartbeatState] = None, ) -> Callable[[], ProcessedMessageBatchWriter]: assert not (replacements_producer is None) ^ (replacements_topic is None) supports_replacements = replacements_producer is not None @@ -373,6 +421,7 @@ def build_writer() -> ProcessedMessageBatchWriter: replacement_batch_writer, commit_log_config, metrics=insert_metrics, + heartbeat_state=heartbeat_state, ) return build_writer @@ -385,10 +434,12 @@ def __init__( # If passed, produces to the commit log after each batch is closed commit_log_config: Optional[CommitLogConfig], ignore_errors: Optional[Set[StorageKey]] = None, + heartbeat_state: Optional[CommitLogHeartbeatState] = None, ): self.__steps = steps self.__closed = False self.__commit_log_config = commit_log_config + self.__heartbeat_state = heartbeat_state self.__messages: MutableMapping[ StorageKey, List[Message[Tuple[StorageKey, Union[None, BytesInsertBatch, ReplacementBatch]]]], @@ -420,6 +471,31 @@ def submit( message.value.timestamp, ) + def __produce_commit_log_entry( + self, + partition: Partition, + offset: int, + timestamp: datetime, + ) -> None: + assert self.__commit_log_config is not None + payload = commit_codec.encode( + CommitLogCommit( + self.__commit_log_config.group_id, + partition, + offset, + datetime.timestamp(timestamp), + None, + ) + ) + self.__commit_log_config.producer.produce( + self.__commit_log_config.topic.name, + key=payload.key, + value=payload.value, + headers=payload.headers, + on_delivery=self.__commit_message_delivery_callback, + ) + self.__commit_log_config.producer.poll(0.0) + def close(self) -> None: self.__closed = True @@ -428,23 +504,17 @@ def close(self) -> None: if self.__commit_log_config is not None: for partition, (offset, timestamp) in self.__offsets_to_produce.items(): - payload = commit_codec.encode( - CommitLogCommit( - self.__commit_log_config.group_id, - partition, - offset, - datetime.timestamp(timestamp), - None, - ) - ) - self.__commit_log_config.producer.produce( - self.__commit_log_config.topic.name, - key=payload.key, - value=payload.value, - headers=payload.headers, - on_delivery=self.__commit_message_delivery_callback, - ) - self.__commit_log_config.producer.poll(0.0) + self.__produce_commit_log_entry(partition, offset, timestamp) + + # Produce heartbeat entries for idle assigned partitions + if self.__heartbeat_state is not None: + self.__heartbeat_state.last_produced_offsets.update(self.__offsets_to_produce) + for partition in self.__heartbeat_state.assigned_partitions: + if partition not in self.__offsets_to_produce: + last = self.__heartbeat_state.last_produced_offsets.get(partition) + if last is not None: + offset, timestamp = last + self.__produce_commit_log_entry(partition, offset, timestamp) self.__commit_log_config.producer.flush() diff --git a/snuba/consumers/consumer_builder.py b/snuba/consumers/consumer_builder.py index 4b95a53f166..6080408c19e 100644 --- a/snuba/consumers/consumer_builder.py +++ b/snuba/consumers/consumer_builder.py @@ -21,6 +21,7 @@ from snuba.consumers.consumer import ( CommitLogConfig, + CommitLogHeartbeatState, build_batch_writer, process_message, ) @@ -218,8 +219,10 @@ def build_streaming_strategy_factory( commit_log_config = CommitLogConfig( self.commit_log_producer, self.commit_log_topic, self.group_id ) + heartbeat_state: Optional[CommitLogHeartbeatState] = CommitLogHeartbeatState() else: commit_log_config = None + heartbeat_state = None strategy_factory: ProcessingStrategyFactory[KafkaPayload] = KafkaConsumerStrategyFactory( prefilter=stream_loader.get_pre_filter(), @@ -237,6 +240,7 @@ def build_streaming_strategy_factory( replacements_topic=self.replacements_topic, slice_id=self.slice_id, commit_log_config=commit_log_config, + heartbeat_state=heartbeat_state, ), max_batch_size=self.max_batch_size, max_batch_time=self.max_batch_time_ms / 1000.0, @@ -249,6 +253,7 @@ def build_streaming_strategy_factory( initialize_parallel_transform=setup_sentry, health_check_file=self.health_check_file, metrics_tags=self.metrics_tags, + heartbeat_state=heartbeat_state, ) if self.__profile_path is not None: diff --git a/snuba/consumers/strategy_factory.py b/snuba/consumers/strategy_factory.py index 6b16f9d4e89..40789434e69 100644 --- a/snuba/consumers/strategy_factory.py +++ b/snuba/consumers/strategy_factory.py @@ -19,7 +19,11 @@ ) from arroyo.types import BaseValue, Commit, FilteredPayload, Message, Partition -from snuba.consumers.consumer import BytesInsertBatch, ProcessedMessageBatchWriter +from snuba.consumers.consumer import ( + BytesInsertBatch, + CommitLogHeartbeatState, + ProcessedMessageBatchWriter, +) from snuba.consumers.dlq import ExitAfterNMessages from snuba.processor import ReplacementBatch @@ -76,11 +80,13 @@ def __init__( max_messages_to_process: Optional[int] = None, initialize_parallel_transform: Optional[Callable[[], None]] = None, health_check_file: Optional[str] = None, + heartbeat_state: Optional[CommitLogHeartbeatState] = None, ) -> None: self.__prefilter = prefilter self.__process_message = process_message self.__collector = collector self.__max_messages_to_process = max_messages_to_process + self.__heartbeat_state = heartbeat_state self.__max_batch_size = max_batch_size self.__max_batch_time = max_batch_time @@ -117,6 +123,9 @@ def create_with_partitions( else: self.__metrics_tags.pop("min_partition", None) + if self.__heartbeat_state is not None: + self.__heartbeat_state.update_partitions(partitions) + def accumulator( batch_writer: ProcessedMessageBatchWriter, message: BaseValue[ProcessedMessage], diff --git a/tests/test_consumer.py b/tests/test_consumer.py index b6519c95fe6..dc4841d25ce 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -15,8 +15,11 @@ from snuba.clusters.cluster import ClickhouseClientSettings from snuba.consumers.consumer import ( BytesInsertBatch, + CommitLogConfig, + CommitLogHeartbeatState, InsertBatchWriter, LatencyRecorder, + MultistorageCollector, ProcessedMessageBatchWriter, ReplacementBatchWriter, build_batch_writer, @@ -239,3 +242,473 @@ def test_latency_recorder() -> None: assert recorder.max_ms == 1200.0 # (2.7 / 3) * 1000 == 900 assert recorder.avg_ms == 900.0 + + +def test_commit_log_heartbeat_for_idle_partitions() -> None: + """ + Verify that when a batch is flushed containing messages only from + partitions 0 and 1, but the consumer is assigned partitions 0-3, + commit log entries are still produced for partitions 2 and 3 using + their last known offsets. + """ + commit_log_producer = Mock() + commit_log_topic = Topic("snuba-commit-log") + commit_log_config = CommitLogConfig( + producer=commit_log_producer, + topic=commit_log_topic, + group_id="test-group", + ) + + heartbeat_state = CommitLogHeartbeatState() + # Simulate assigned partitions 0-3 + assigned = {Partition(Topic("events"), i): 0 for i in range(4)} + heartbeat_state.update_partitions(assigned) + + writer = Mock() + metrics = TestingMetricsBackend() + + # --- Batch 1: messages on all 4 partitions (establishes last known offsets) --- + batch_writer_1 = ProcessedMessageBatchWriter( + insert_batch_writer=InsertBatchWriter(writer, MetricsWrapper(metrics, "insertions")), + commit_log_config=commit_log_config, + heartbeat_state=heartbeat_state, + ) + + now = datetime.now() + for i in range(4): + msg = Message( + BrokerValue( + KafkaPayload(None, b'{"col1": "val1"}', []), + Partition(Topic("events"), i), + 100 + i, # offset + now, + ) + ) + batch_writer_1.submit(msg.replace(BytesInsertBatch([b'{"col1": "val1"}'], now))) + + batch_writer_1.close() + + # After batch 1: should have 4 commit log produces (one per partition) + # plus heartbeats - but all partitions had data, so no heartbeats + assert commit_log_producer.produce.call_count == 4 + + commit_log_producer.reset_mock() + + # --- Batch 2: messages only on partitions 0 and 1 --- + batch_writer_2 = ProcessedMessageBatchWriter( + insert_batch_writer=InsertBatchWriter(writer, MetricsWrapper(metrics, "insertions")), + commit_log_config=commit_log_config, + heartbeat_state=heartbeat_state, + ) + + for i in range(2): # Only partitions 0 and 1 + msg = Message( + BrokerValue( + KafkaPayload(None, b'{"col1": "val1"}', []), + Partition(Topic("events"), i), + 200 + i, # offset + now, + ) + ) + batch_writer_2.submit(msg.replace(BytesInsertBatch([b'{"col1": "val1"}'], now))) + + batch_writer_2.close() + + # Should have 4 total produces: + # - 2 for partitions 0 and 1 (actual data) + # - 2 heartbeats for partitions 2 and 3 (idle) + assert commit_log_producer.produce.call_count == 4 + + # Verify the produce calls include the right partitions + produced_keys = [] + for c in commit_log_producer.produce.call_args_list: + # The key encodes the partition info + produced_keys.append( + c.kwargs.get("key") or c.args[1] if len(c.args) > 1 else c.kwargs.get("key") + ) + + # All 4 partitions should have commit log entries + # (keys are in format "topic:partition:group" encoded bytes) + assert commit_log_producer.produce.call_count == 4 + + +def test_commit_log_heartbeat_no_heartbeat_without_prior_offset() -> None: + """ + Verify that no heartbeat is produced for a partition that has never + had any messages (no last known offset). + """ + commit_log_producer = Mock() + commit_log_topic = Topic("snuba-commit-log") + commit_log_config = CommitLogConfig( + producer=commit_log_producer, + topic=commit_log_topic, + group_id="test-group", + ) + + heartbeat_state = CommitLogHeartbeatState() + assigned = {Partition(Topic("events"), i): 0 for i in range(4)} + heartbeat_state.update_partitions(assigned) + + writer = Mock() + metrics = TestingMetricsBackend() + + # Batch with messages only on partition 0 (no prior offsets for 1-3) + batch_writer = ProcessedMessageBatchWriter( + insert_batch_writer=InsertBatchWriter(writer, MetricsWrapper(metrics, "insertions")), + commit_log_config=commit_log_config, + heartbeat_state=heartbeat_state, + ) + + now = datetime.now() + msg = Message( + BrokerValue( + KafkaPayload(None, b'{"col1": "val1"}', []), + Partition(Topic("events"), 0), + 100, + now, + ) + ) + batch_writer.submit(msg.replace(BytesInsertBatch([b'{"col1": "val1"}'], now))) + batch_writer.close() + + # Only 1 produce: partition 0 has data, partitions 1-3 have no prior offset + assert commit_log_producer.produce.call_count == 1 + + +def test_commit_log_heartbeat_state_rebalance() -> None: + """ + Verify that when partitions are reassigned, old partition offsets are cleaned up. + """ + heartbeat_state = CommitLogHeartbeatState() + + # Initially assigned partitions 0-3 + assigned = {Partition(Topic("events"), i): 0 for i in range(4)} + heartbeat_state.update_partitions(assigned) + + # Simulate some last produced offsets + now = datetime.now() + for i in range(4): + heartbeat_state.last_produced_offsets[Partition(Topic("events"), i)] = (100 + i, now) + + assert len(heartbeat_state.last_produced_offsets) == 4 + + # Rebalance: now only assigned partitions 0-1 + new_assigned = {Partition(Topic("events"), i): 0 for i in range(2)} + heartbeat_state.update_partitions(new_assigned) + + # Offsets for partitions 2 and 3 should be cleaned up + assert len(heartbeat_state.last_produced_offsets) == 2 + assert Partition(Topic("events"), 0) in heartbeat_state.last_produced_offsets + assert Partition(Topic("events"), 1) in heartbeat_state.last_produced_offsets + assert Partition(Topic("events"), 2) not in heartbeat_state.last_produced_offsets + assert Partition(Topic("events"), 3) not in heartbeat_state.last_produced_offsets + + +def test_commit_log_heartbeat_offset_evolution() -> None: + """ + Verify that heartbeat offsets evolve as new batches arrive. After batch 2 + with partition 0 at offset 500, the heartbeat for partition 0 in batch 3 + (where partition 0 is idle) should use offset 500, not the old offset 100. + """ + commit_log_producer = Mock() + commit_log_config = CommitLogConfig( + producer=commit_log_producer, + topic=Topic("snuba-commit-log"), + group_id="test-group", + ) + + heartbeat_state = CommitLogHeartbeatState() + assigned = {Partition(Topic("events"), i): 0 for i in range(2)} + heartbeat_state.update_partitions(assigned) + + writer = Mock() + metrics = TestingMetricsBackend() + now = datetime.now() + + def make_batch_writer() -> ProcessedMessageBatchWriter: + return ProcessedMessageBatchWriter( + insert_batch_writer=InsertBatchWriter(writer, MetricsWrapper(metrics, "insertions")), + commit_log_config=commit_log_config, + heartbeat_state=heartbeat_state, + ) + + def submit_message(bw: ProcessedMessageBatchWriter, partition: int, offset: int) -> None: + msg = Message( + BrokerValue( + KafkaPayload(None, b'{"col1": "val1"}', []), + Partition(Topic("events"), partition), + offset, + now, + ) + ) + bw.submit(msg.replace(BytesInsertBatch([b'{"col1": "val1"}'], now))) + + # Batch 1: both partitions, establishing initial offsets + bw1 = make_batch_writer() + submit_message(bw1, 0, 100) + submit_message(bw1, 1, 200) + bw1.close() + commit_log_producer.reset_mock() + + # Batch 2: only partition 0 at a higher offset + bw2 = make_batch_writer() + submit_message(bw2, 0, 500) + bw2.close() + + # 1 data produce (partition 0) + 1 heartbeat (partition 1 at offset 200) + assert commit_log_producer.produce.call_count == 2 + commit_log_producer.reset_mock() + + # Batch 3: only partition 1 at a higher offset + bw3 = make_batch_writer() + submit_message(bw3, 1, 700) + bw3.close() + + # 1 data produce (partition 1) + 1 heartbeat (partition 0) + assert commit_log_producer.produce.call_count == 2 + + # The heartbeat for partition 0 should use next_offset 501 (from batch 2), + # not the original 101 (from batch 1). Check the shared state. + # Note: BrokerValue stores next_offset = offset + 1 + assert heartbeat_state.last_produced_offsets[Partition(Topic("events"), 0)] == ( + 501, + now, + ) + assert heartbeat_state.last_produced_offsets[Partition(Topic("events"), 1)] == ( + 701, + now, + ) + + +def test_commit_log_heartbeat_no_state_means_no_heartbeats() -> None: + """ + Verify that when heartbeat_state is None (e.g., commit log disabled or DLQ + consumer), no heartbeat logic is triggered — only data offsets are produced. + """ + commit_log_producer = Mock() + commit_log_config = CommitLogConfig( + producer=commit_log_producer, + topic=Topic("snuba-commit-log"), + group_id="test-group", + ) + + writer = Mock() + metrics = TestingMetricsBackend() + now = datetime.now() + + # No heartbeat_state + bw = ProcessedMessageBatchWriter( + insert_batch_writer=InsertBatchWriter(writer, MetricsWrapper(metrics, "insertions")), + commit_log_config=commit_log_config, + heartbeat_state=None, + ) + + msg = Message( + BrokerValue( + KafkaPayload(None, b'{"col1": "val1"}', []), + Partition(Topic("events"), 0), + 100, + now, + ) + ) + bw.submit(msg.replace(BytesInsertBatch([b'{"col1": "val1"}'], now))) + bw.close() + + # Only the 1 data produce — no heartbeats + assert commit_log_producer.produce.call_count == 1 + + +def test_multistorage_collector_heartbeat_for_idle_partitions() -> None: + """ + Verify MultistorageCollector produces heartbeat commit log entries + for idle assigned partitions. + """ + commit_log_producer = Mock() + commit_log_config = CommitLogConfig( + producer=commit_log_producer, + topic=Topic("snuba-commit-log"), + group_id="test-group", + ) + + heartbeat_state = CommitLogHeartbeatState() + assigned = {Partition(Topic("events"), i): 0 for i in range(4)} + heartbeat_state.update_partitions(assigned) + + # Create mock ProcessedMessageBatchWriter steps for a fake storage key + storage_key = StorageKey("errors") + writer = Mock() + metrics = TestingMetricsBackend() + step = ProcessedMessageBatchWriter( + insert_batch_writer=InsertBatchWriter(writer, MetricsWrapper(metrics, "insertions")), + ) + + now = datetime.now() + + # Batch 1: establish offsets for all 4 partitions via the collector + collector1 = MultistorageCollector( + steps={storage_key: step}, + commit_log_config=commit_log_config, + heartbeat_state=heartbeat_state, + ) + + for i in range(4): + msg = Message( + BrokerValue( + KafkaPayload(None, b'{"col1": "val1"}', []), + Partition(Topic("events"), i), + 100 + i, + now, + ) + ) + collector1.submit( + msg.replace([(storage_key, BytesInsertBatch([b'{"col1": "val1"}'], now))]) + ) + collector1.close() + + # 4 produces for data partitions, no heartbeats + assert commit_log_producer.produce.call_count == 4 + commit_log_producer.reset_mock() + + # Batch 2: only partitions 0 and 1 + step2 = ProcessedMessageBatchWriter( + insert_batch_writer=InsertBatchWriter(writer, MetricsWrapper(metrics, "insertions")), + ) + + collector2 = MultistorageCollector( + steps={storage_key: step2}, + commit_log_config=commit_log_config, + heartbeat_state=heartbeat_state, + ) + + for i in range(2): + msg = Message( + BrokerValue( + KafkaPayload(None, b'{"col1": "val1"}', []), + Partition(Topic("events"), i), + 200 + i, + now, + ) + ) + collector2.submit( + msg.replace([(storage_key, BytesInsertBatch([b'{"col1": "val1"}'], now))]) + ) + collector2.close() + + # 2 data produces + 2 heartbeats for idle partitions 2 and 3 + assert commit_log_producer.produce.call_count == 4 + + +def test_build_batch_writer_passes_heartbeat_state() -> None: + """ + Verify that build_batch_writer correctly passes heartbeat_state + to each ProcessedMessageBatchWriter it creates. + """ + commit_log_producer = Mock() + commit_log_config = CommitLogConfig( + producer=commit_log_producer, + topic=Topic("snuba-commit-log"), + group_id="test-group", + ) + + heartbeat_state = CommitLogHeartbeatState() + assigned = {Partition(Topic("events"), i): 0 for i in range(2)} + heartbeat_state.update_partitions(assigned) + + writer_mock = Mock() + table_writer = Mock() + table_writer.get_batch_writer.return_value = writer_mock + + metrics = TestingMetricsBackend() + + factory = build_batch_writer( + table_writer, + metrics=metrics, + commit_log_config=commit_log_config, + heartbeat_state=heartbeat_state, + ) + + # Create two batch writers from the factory — they should share heartbeat_state + bw1 = factory() + bw2 = factory() + + # Both should be ProcessedMessageBatchWriter instances + assert isinstance(bw1, ProcessedMessageBatchWriter) + assert isinstance(bw2, ProcessedMessageBatchWriter) + + now = datetime.now() + + # Submit a message to bw1 to establish an offset for partition 0 + msg = Message( + BrokerValue( + KafkaPayload(None, b'{"col1": "val1"}', []), + Partition(Topic("events"), 0), + 100, + now, + ) + ) + bw1.submit(msg.replace(BytesInsertBatch([b'{"col1": "val1"}'], now))) + bw1.close() + commit_log_producer.reset_mock() + + # bw2 gets a message only on partition 1 — partition 0 should get a heartbeat + # because the shared state knows about partition 0's offset from bw1 + msg2 = Message( + BrokerValue( + KafkaPayload(None, b'{"col1": "val1"}', []), + Partition(Topic("events"), 1), + 200, + now, + ) + ) + bw2.submit(msg2.replace(BytesInsertBatch([b'{"col1": "val1"}'], now))) + bw2.close() + + # 1 data produce (partition 1) + 1 heartbeat (partition 0) + assert commit_log_producer.produce.call_count == 2 + + +def test_strategy_factory_updates_heartbeat_state_on_rebalance() -> None: + """ + Verify that KafkaConsumerStrategyFactory.create_with_partitions() + updates the heartbeat state with the new partition assignments. + """ + heartbeat_state = CommitLogHeartbeatState() + + processor = Mock() + processor.process_message.return_value = None + + writer = Mock() + metrics = TestingMetricsBackend() + + def write_step() -> ProcessedMessageBatchWriter: + return ProcessedMessageBatchWriter( + insert_batch_writer=InsertBatchWriter(writer, MetricsWrapper(metrics, "insertions")), + heartbeat_state=heartbeat_state, + ) + + factory = KafkaConsumerStrategyFactory( + None, + functools.partial(process_message, processor, "consumer_group", SnubaTopic.EVENTS, True), + write_step, + max_batch_size=10, + max_batch_time=60, + max_insert_batch_size=None, + max_insert_batch_time=None, + processes=None, + input_block_size=None, + output_block_size=None, + health_check_file=None, + metrics_tags={}, + heartbeat_state=heartbeat_state, + ) + + # Initial assignment: partitions 0-3 + partitions = {Partition(Topic("events"), i): 0 for i in range(4)} + commit = Mock() + factory.create_with_partitions(commit, partitions) + assert heartbeat_state.assigned_partitions == set(partitions.keys()) + + # Rebalance: now only partitions 0-1 + new_partitions = {Partition(Topic("events"), i): 0 for i in range(2)} + factory.create_with_partitions(commit, new_partitions) + assert heartbeat_state.assigned_partitions == set(new_partitions.keys())