Skip to content

Commit 623000c

Browse files
committed
fix review comments
1 parent 9abb7fe commit 623000c

File tree

2 files changed

+77
-1
lines changed

2 files changed

+77
-1
lines changed

src/agents/run_internal/guardrails.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ async def run_input_guardrails_with_queue(
7171
for done in asyncio.as_completed(guardrail_tasks):
7272
result = await done
7373
if result.output.tripwire_triggered:
74+
streamed_result._triggered_input_guardrail_result = result
7475
for t in guardrail_tasks:
7576
t.cancel()
7677
await asyncio.gather(*guardrail_tasks, return_exceptions=True)
@@ -84,7 +85,6 @@ async def run_input_guardrails_with_queue(
8485
},
8586
),
8687
)
87-
streamed_result._triggered_input_guardrail_result = result
8888
queue.put_nowait(result)
8989
guardrail_results.append(result)
9090
break

tests/test_guardrails.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,82 @@ async def delayed_stream_response(*args, **kwargs):
714714
assert model.first_turn_args is not None, "Model should have been called in parallel mode"
715715

716716

717+
@pytest.mark.asyncio
718+
async def test_parallel_guardrail_trip_with_slow_cancel_sibling_stops_streaming_turn():
719+
tool_was_executed = False
720+
model_started = asyncio.Event()
721+
guardrail_tripped = asyncio.Event()
722+
slow_cancel_started = asyncio.Event()
723+
slow_cancel_finished = asyncio.Event()
724+
725+
@function_tool
726+
def dangerous_tool() -> str:
727+
nonlocal tool_was_executed
728+
tool_was_executed = True
729+
return "tool_executed"
730+
731+
@input_guardrail(run_in_parallel=True)
732+
async def tripwire_before_tool_execution(
733+
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
734+
) -> GuardrailFunctionOutput:
735+
await asyncio.wait_for(model_started.wait(), timeout=1)
736+
guardrail_tripped.set()
737+
return GuardrailFunctionOutput(
738+
output_info="parallel_trip_before_tool_execution_with_slow_cancel",
739+
tripwire_triggered=True,
740+
)
741+
742+
@input_guardrail(run_in_parallel=True)
743+
async def slow_to_cancel_guardrail(
744+
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
745+
) -> GuardrailFunctionOutput:
746+
try:
747+
await asyncio.Event().wait()
748+
return GuardrailFunctionOutput(
749+
output_info="slow_to_cancel_guardrail_completed",
750+
tripwire_triggered=False,
751+
)
752+
except asyncio.CancelledError:
753+
slow_cancel_started.set()
754+
await asyncio.sleep(SHORT_DELAY)
755+
slow_cancel_finished.set()
756+
raise
757+
758+
model = FakeModel()
759+
original_stream_response = model.stream_response
760+
761+
async def delayed_stream_response(*args, **kwargs):
762+
model_started.set()
763+
await asyncio.wait_for(guardrail_tripped.wait(), timeout=1)
764+
await asyncio.wait_for(slow_cancel_started.wait(), timeout=1)
765+
async for event in original_stream_response(*args, **kwargs):
766+
yield event
767+
768+
agent = Agent(
769+
name="streaming_guardrail_slow_cancel_agent",
770+
instructions="Call the dangerous_tool immediately",
771+
tools=[dangerous_tool],
772+
input_guardrails=[tripwire_before_tool_execution, slow_to_cancel_guardrail],
773+
model=model,
774+
)
775+
model.set_next_output([get_function_tool_call("dangerous_tool", arguments="{}")])
776+
model.set_next_output([get_text_message("done")])
777+
778+
with patch.object(model, "stream_response", side_effect=delayed_stream_response):
779+
result = Runner.run_streamed(agent, "trigger guardrail")
780+
781+
with pytest.raises(InputGuardrailTripwireTriggered):
782+
async for _event in result.stream_events():
783+
pass
784+
785+
assert model_started.is_set() is True
786+
assert guardrail_tripped.is_set() is True
787+
assert slow_cancel_started.is_set() is True
788+
assert slow_cancel_finished.is_set() is True
789+
assert tool_was_executed is False
790+
assert model.first_turn_args is not None, "Model should have been called in parallel mode"
791+
792+
717793
@pytest.mark.asyncio
718794
async def test_blocking_guardrail_prevents_tool_execution():
719795
tool_was_executed = False

0 commit comments

Comments
 (0)