@@ -714,6 +714,82 @@ async def delayed_stream_response(*args, **kwargs):
714714 assert model .first_turn_args is not None , "Model should have been called in parallel mode"
715715
716716
717+ @pytest .mark .asyncio
718+ async def test_parallel_guardrail_trip_with_slow_cancel_sibling_stops_streaming_turn ():
719+ tool_was_executed = False
720+ model_started = asyncio .Event ()
721+ guardrail_tripped = asyncio .Event ()
722+ slow_cancel_started = asyncio .Event ()
723+ slow_cancel_finished = asyncio .Event ()
724+
725+ @function_tool
726+ def dangerous_tool () -> str :
727+ nonlocal tool_was_executed
728+ tool_was_executed = True
729+ return "tool_executed"
730+
731+ @input_guardrail (run_in_parallel = True )
732+ async def tripwire_before_tool_execution (
733+ ctx : RunContextWrapper [Any ], agent : Agent [Any ], input : str | list [TResponseInputItem ]
734+ ) -> GuardrailFunctionOutput :
735+ await asyncio .wait_for (model_started .wait (), timeout = 1 )
736+ guardrail_tripped .set ()
737+ return GuardrailFunctionOutput (
738+ output_info = "parallel_trip_before_tool_execution_with_slow_cancel" ,
739+ tripwire_triggered = True ,
740+ )
741+
742+ @input_guardrail (run_in_parallel = True )
743+ async def slow_to_cancel_guardrail (
744+ ctx : RunContextWrapper [Any ], agent : Agent [Any ], input : str | list [TResponseInputItem ]
745+ ) -> GuardrailFunctionOutput :
746+ try :
747+ await asyncio .Event ().wait ()
748+ return GuardrailFunctionOutput (
749+ output_info = "slow_to_cancel_guardrail_completed" ,
750+ tripwire_triggered = False ,
751+ )
752+ except asyncio .CancelledError :
753+ slow_cancel_started .set ()
754+ await asyncio .sleep (SHORT_DELAY )
755+ slow_cancel_finished .set ()
756+ raise
757+
758+ model = FakeModel ()
759+ original_stream_response = model .stream_response
760+
761+ async def delayed_stream_response (* args , ** kwargs ):
762+ model_started .set ()
763+ await asyncio .wait_for (guardrail_tripped .wait (), timeout = 1 )
764+ await asyncio .wait_for (slow_cancel_started .wait (), timeout = 1 )
765+ async for event in original_stream_response (* args , ** kwargs ):
766+ yield event
767+
768+ agent = Agent (
769+ name = "streaming_guardrail_slow_cancel_agent" ,
770+ instructions = "Call the dangerous_tool immediately" ,
771+ tools = [dangerous_tool ],
772+ input_guardrails = [tripwire_before_tool_execution , slow_to_cancel_guardrail ],
773+ model = model ,
774+ )
775+ model .set_next_output ([get_function_tool_call ("dangerous_tool" , arguments = "{}" )])
776+ model .set_next_output ([get_text_message ("done" )])
777+
778+ with patch .object (model , "stream_response" , side_effect = delayed_stream_response ):
779+ result = Runner .run_streamed (agent , "trigger guardrail" )
780+
781+ with pytest .raises (InputGuardrailTripwireTriggered ):
782+ async for _event in result .stream_events ():
783+ pass
784+
785+ assert model_started .is_set () is True
786+ assert guardrail_tripped .is_set () is True
787+ assert slow_cancel_started .is_set () is True
788+ assert slow_cancel_finished .is_set () is True
789+ assert tool_was_executed is False
790+ assert model .first_turn_args is not None , "Model should have been called in parallel mode"
791+
792+
717793@pytest .mark .asyncio
718794async def test_blocking_guardrail_prevents_tool_execution ():
719795 tool_was_executed = False
0 commit comments