diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 6d1bf0c2937f5..b265c3fc276f7 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -972,7 +972,13 @@ boolean runLoop() {
                     // check if any active task got corrupted. We will trigger a rebalance in that case.
                     // once the task corruptions have been handled
                     final boolean enforceRebalance = taskManager.handleCorruption(e.corruptedTasks());
-                    if (enforceRebalance && eosEnabled) {
+                    // The corrupted tasks have already been recovered locally (closed dirty, revived and
+                    // scheduled for re-initialization with their input offsets reset). Under the classic
+                    // protocol we additionally enforce a rebalance so the assignor can temporarily move the
+                    // task to a standby while this client restores its state from scratch (KAFKA-12486).
+                    // Under the Streams group protocol (KIP-1071), i.e. when streamsRebalanceData is present,
+                    // assignment and warmup are driven by the broker and enforceRebalance is unsupported (it would only log a warning).
+                    if (enforceRebalance && eosEnabled && streamsRebalanceData.isEmpty()) {
                         log.info("Active task(s) got corrupted. Triggering a rebalance.");
                         mainConsumer.enforceRebalance("Active tasks corrupted");
                     }
@@ -1077,7 +1083,14 @@ public void maybeSendShutdown() {
                         "All clients in this app will now begin to shutdown");
                 }
             }
-            mainConsumer.enforceRebalance("Shutdown requested");
+            // Under the classic protocol (streamsRebalanceData is empty) the shutdown request is
+            // propagated to the rest of the group by the assignor during a rebalance, so we need to
+            // enforce one. Under the Streams group protocol (KIP-1071) the request is propagated
+            // through the group heartbeat (see sendShutdownRequest), and enforceRebalance is not
+            // supported by the consumer (it would only log a warning), so we skip it.
+            if (streamsRebalanceData.isEmpty()) {
+                mainConsumer.enforceRebalance("Shutdown requested");
+            }
         }
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 01aa4279c431d..fb8bb26c0be12 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -2788,6 +2788,81 @@ void runOnceWithoutProcessingThreads() {
         thread.runLoop();
 
         verify(consumer).subscribe((Collection) any(), any());
+        verify(consumer).enforceRebalance("Active tasks corrupted");
+    }
+
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    @SuppressWarnings("unchecked")
+    public void shouldNotEnforceRebalanceOnTaskCorruptedExceptionUnderStreamsProtocol(final boolean processingThreadsEnabled) {
+        final StreamsConfig config = new StreamsConfig(configProps(true, processingThreadsEnabled));
+        final TaskManager taskManager = mock(TaskManager.class);
+        // Mock an AsyncKafkaConsumer: the Streams group protocol requires that main-consumer type (see subscribeConsumer).
+        final Consumer consumer = mock(AsyncKafkaConsumer.class);
+        final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class);
+        when(consumer.groupMetadata()).thenReturn(consumerGroupMetadata);
+        when(consumerGroupMetadata.groupInstanceId()).thenReturn(Optional.empty());
+
+        final TaskId taskId1 = new TaskId(0, 0);
+        final Set corruptedTasks = singleton(taskId1);
+        when(taskManager.handleCorruption(corruptedTasks)).thenReturn(true);
+
+        final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(
+            UUID.randomUUID(),
+            Optional.empty(),
+            Optional.empty(),
+            Map.of(),
+            Map.of()
+        );
+
+        final StreamsMetricsImpl streamsMetrics =
+            new StreamsMetricsImpl(metrics, CLIENT_ID, mockTime);
+        final TopologyMetadata topologyMetadata = new TopologyMetadata(internalTopologyBuilder, config);
+        topologyMetadata.buildAndRewriteTopology();
+        thread = new StreamThread(
+            mockTime,
+            config,
+            null,
+            consumer,
+            consumer,
+            changelogReader,
+            null,
+            taskManager,
+            null,
+            streamsMetrics,
+            topologyMetadata,
+            PROCESS_ID,
+            CLIENT_ID,
+            new LogContext(""),
+            new AtomicInteger(),
+            new AtomicLong(Long.MAX_VALUE),
+            new LinkedList<>(),
+            null,
+            HANDLER,
+            null,
+            Optional.of(streamsRebalanceData),
+            null,
+            null
+        ) {
+            @Override
+            void runOnceWithProcessingThreads() {
+                setState(State.PENDING_SHUTDOWN);
+                throw new TaskCorruptedException(corruptedTasks);
+            }
+            @Override
+            void runOnceWithoutProcessingThreads() {
+                setState(State.PENDING_SHUTDOWN);
+                throw new TaskCorruptedException(corruptedTasks);
+            }
+        }.updateThreadMetadata(adminClientId(CLIENT_ID));
+
+        thread.setState(StreamThread.State.STARTING);
+        thread.runLoop();
+
+        // handleCorruption must still run so the corrupted task is recovered locally, but under the Streams group
+        // protocol any HA failover is driven by the broker, so the client must not enforce a rebalance (unsupported).
+        verify(taskManager).handleCorruption(corruptedTasks);
+        verify(consumer, never()).enforceRebalance(anyString());
     }
 
     @ParameterizedTest
@@ -3731,6 +3806,105 @@ public void testStreamsProtocolRunOnceWithoutProcessingThreads() {
         verify(shutdownErrorHook).run();
     }
 
+    @Test
+    public void shouldNotEnforceRebalanceOnShutdownRequestUnderStreamsProtocol() {
+        final ConsumerGroupMetadata consumerGroupMetadata = Mockito.mock(ConsumerGroupMetadata.class);
+        when(consumerGroupMetadata.groupInstanceId()).thenReturn(Optional.empty());
+        when(mainConsumer.groupMetadata()).thenReturn(consumerGroupMetadata);
+        final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(
+            UUID.randomUUID(),
+            Optional.empty(),
+            Optional.empty(),
+            Map.of(),
+            Map.of()
+        );
+
+        final Properties props = configProps(false, false);
+        final StreamsMetadataState streamsMetadataState = new StreamsMetadataState(
+            new TopologyMetadata(internalTopologyBuilder, new StreamsConfig(props)),
+            StreamsMetadataState.UNKNOWN_HOST,
+            new LogContext(String.format("stream-client [%s] ", CLIENT_ID))
+        );
+        final StreamsConfig config = new StreamsConfig(props);
+        thread = new StreamThread(
+            new MockTime(1),
+            config,
+            null,
+            mainConsumer,
+            consumer,
+            changelogReader,
+            null,
+            mock(TaskManager.class),
+            null,
+            new StreamsMetricsImpl(metrics, CLIENT_ID, mockTime),
+            new TopologyMetadata(internalTopologyBuilder, config),
+            PROCESS_ID,
+            CLIENT_ID,
+            new LogContext(""),
+            new AtomicInteger(),
+            new AtomicLong(Long.MAX_VALUE),
+            new LinkedList<>(),
+            mock(Runnable.class),
+            HANDLER,
+            null,
+            Optional.of(streamsRebalanceData),
+            streamsMetadataState,
+            null
+        ).updateThreadMetadata(adminClientId(CLIENT_ID));
+
+        thread.sendShutdownRequest();
+        thread.maybeSendShutdown();
+
+        // Under the Streams group protocol the shutdown request is flagged on streamsRebalanceData and propagated
+        // through the group heartbeat, not via a client-enforced rebalance (which the consumer does not support).
+        assertTrue(streamsRebalanceData.shutdownRequested());
+        verify(mainConsumer, never()).enforceRebalance(anyString());
+    }
+
+    @Test
+    public void shouldEnforceRebalanceOnShutdownRequestUnderClassicProtocol() {
+        final ConsumerGroupMetadata consumerGroupMetadata = Mockito.mock(ConsumerGroupMetadata.class);
+        when(consumerGroupMetadata.groupInstanceId()).thenReturn(Optional.empty());
+        when(mainConsumer.groupMetadata()).thenReturn(consumerGroupMetadata);
+        final Properties props = configProps(false, false);
+        final StreamsMetadataState streamsMetadataState = new StreamsMetadataState(
+            new TopologyMetadata(internalTopologyBuilder, new StreamsConfig(props)),
+            StreamsMetadataState.UNKNOWN_HOST,
+            new LogContext(String.format("stream-client [%s] ", CLIENT_ID))
+        );
+        final StreamsConfig config = new StreamsConfig(props);
+        thread = new StreamThread(
+            new MockTime(1),
+            config,
+            null,
+            mainConsumer,
+            consumer,
+            changelogReader,
+            null,
+            mock(TaskManager.class),
+            null,
+            new StreamsMetricsImpl(metrics, CLIENT_ID, mockTime),
+            new TopologyMetadata(internalTopologyBuilder, config),
+            PROCESS_ID,
+            CLIENT_ID,
+            new LogContext(""),
+            new AtomicInteger(),
+            new AtomicLong(Long.MAX_VALUE),
+            new LinkedList<>(),
+            mock(Runnable.class),
+            HANDLER,
+            null,
+            Optional.empty(),
+            streamsMetadataState,
+            null
+        ).updateThreadMetadata(adminClientId(CLIENT_ID));
+
+        thread.sendShutdownRequest();
+        thread.maybeSendShutdown();
+
+        verify(mainConsumer).enforceRebalance("Shutdown requested");
+    }
+
     @Test
     public void testStreamsProtocolRunOnceWithoutProcessingThreadsMissingSourceTopic() {
         final ConsumerGroupMetadata consumerGroupMetadata = Mockito.mock(ConsumerGroupMetadata.class);