Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/agents/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ class RunResultStreaming(RunResultBase):
# Store the asyncio tasks that we're waiting on
run_loop_task: asyncio.Task[Any] | None = field(default=None, repr=False)
_input_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
_triggered_input_guardrail_result: InputGuardrailResult | None = field(default=None, repr=False)
_output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
_stored_exception: Exception | None = field(default=None, repr=False)
_cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False)
Expand Down
1 change: 1 addition & 0 deletions src/agents/run_internal/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ async def run_input_guardrails_with_queue(
for done in asyncio.as_completed(guardrail_tasks):
result = await done
if result.output.tripwire_triggered:
streamed_result._triggered_input_guardrail_result = result
for t in guardrail_tasks:
t.cancel()
await asyncio.gather(*guardrail_tasks, return_exceptions=True)
Expand Down
19 changes: 19 additions & 0 deletions src/agents/run_internal/run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,24 @@ async def run_single_turn_streamed(
reasoning_item_id_policy: ReasoningItemIdPolicy | None = None,
) -> SingleStepResult:
"""Run a single streamed turn and emit events as results arrive."""

async def raise_if_input_guardrail_tripwire_known() -> None:
tripwire_result = streamed_result._triggered_input_guardrail_result
if tripwire_result is not None:
raise InputGuardrailTripwireTriggered(tripwire_result)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Append tripwire result before raising streamed guardrail exception

raise_if_input_guardrail_tripwire_known raises immediately when _triggered_input_guardrail_result is set, but run_input_guardrails_with_queue does not append that result to streamed_result.input_guardrail_results until after await asyncio.gather(...) finishes sibling cancellations. With a slow-cancel sibling, callers receive InputGuardrailTripwireTriggered whose run_data.input_guardrail_results is empty, losing guardrail context.

Useful? React with 👍 / 👎.


task = streamed_result._input_guardrails_task
if task is None or not task.done():
return

guardrail_exception = task.exception()
if guardrail_exception is not None:
raise guardrail_exception

tripwire_result = streamed_result._triggered_input_guardrail_result
if tripwire_result is not None:
raise InputGuardrailTripwireTriggered(tripwire_result)

emitted_tool_call_ids: set[str] = set()
emitted_reasoning_item_ids: set[str] = set()
emitted_tool_search_fingerprints: set[str] = set()
Expand Down Expand Up @@ -1450,6 +1468,7 @@ async def rewind_model_request() -> None:
run_config=run_config,
tool_use_tracker=tool_use_tracker,
event_queue=streamed_result._event_queue,
before_side_effects=raise_if_input_guardrail_tripwire_known,
)

items_to_filter = session_items_for_turn(single_step_result)
Expand Down
4 changes: 4 additions & 0 deletions src/agents/run_internal/turn_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,6 +1696,7 @@ async def get_single_step_result_from_response(
run_config: RunConfig,
tool_use_tracker,
event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None,
before_side_effects: Callable[[], Awaitable[None]] | None = None,
) -> SingleStepResult:
processed_response = process_model_response(
agent=agent,
Expand All @@ -1706,6 +1707,9 @@ async def get_single_step_result_from_response(
existing_items=pre_step_items,
)

if before_side_effects is not None:
await before_side_effects()

tool_use_tracker.record_processed_response(agent, processed_response)

if event_queue is not None and processed_response.new_items:
Expand Down
132 changes: 132 additions & 0 deletions tests/test_guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,138 @@ async def slow_parallel_check(
assert model.first_turn_args is not None, "Model should have been called in parallel mode"


@pytest.mark.asyncio
async def test_parallel_guardrail_trip_before_tool_execution_stops_streaming_turn():
tool_was_executed = False
model_started = asyncio.Event()
guardrail_tripped = asyncio.Event()

@function_tool
def dangerous_tool() -> str:
nonlocal tool_was_executed
tool_was_executed = True
return "tool_executed"

@input_guardrail(run_in_parallel=True)
async def tripwire_before_tool_execution(
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
) -> GuardrailFunctionOutput:
await asyncio.wait_for(model_started.wait(), timeout=1)
guardrail_tripped.set()
return GuardrailFunctionOutput(
output_info="parallel_trip_before_tool_execution",
tripwire_triggered=True,
)

model = FakeModel()
original_stream_response = model.stream_response

async def delayed_stream_response(*args, **kwargs):
model_started.set()
await asyncio.wait_for(guardrail_tripped.wait(), timeout=1)
await asyncio.sleep(SHORT_DELAY)
async for event in original_stream_response(*args, **kwargs):
yield event

agent = Agent(
name="streaming_guardrail_hardening_agent",
instructions="Call the dangerous_tool immediately",
tools=[dangerous_tool],
input_guardrails=[tripwire_before_tool_execution],
model=model,
)
model.set_next_output([get_function_tool_call("dangerous_tool", arguments="{}")])
model.set_next_output([get_text_message("done")])

with patch.object(model, "stream_response", side_effect=delayed_stream_response):
result = Runner.run_streamed(agent, "trigger guardrail")

with pytest.raises(InputGuardrailTripwireTriggered):
async for _event in result.stream_events():
pass

assert model_started.is_set() is True
assert guardrail_tripped.is_set() is True
assert tool_was_executed is False
assert model.first_turn_args is not None, "Model should have been called in parallel mode"


@pytest.mark.asyncio
async def test_parallel_guardrail_trip_with_slow_cancel_sibling_stops_streaming_turn():
tool_was_executed = False
model_started = asyncio.Event()
guardrail_tripped = asyncio.Event()
slow_cancel_started = asyncio.Event()
slow_cancel_finished = asyncio.Event()

@function_tool
def dangerous_tool() -> str:
nonlocal tool_was_executed
tool_was_executed = True
return "tool_executed"

@input_guardrail(run_in_parallel=True)
async def tripwire_before_tool_execution(
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
) -> GuardrailFunctionOutput:
await asyncio.wait_for(model_started.wait(), timeout=1)
guardrail_tripped.set()
return GuardrailFunctionOutput(
output_info="parallel_trip_before_tool_execution_with_slow_cancel",
tripwire_triggered=True,
)

@input_guardrail(run_in_parallel=True)
async def slow_to_cancel_guardrail(
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
) -> GuardrailFunctionOutput:
try:
await asyncio.Event().wait()
return GuardrailFunctionOutput(
output_info="slow_to_cancel_guardrail_completed",
tripwire_triggered=False,
)
except asyncio.CancelledError:
slow_cancel_started.set()
await asyncio.sleep(SHORT_DELAY)
slow_cancel_finished.set()
raise

model = FakeModel()
original_stream_response = model.stream_response

async def delayed_stream_response(*args, **kwargs):
model_started.set()
await asyncio.wait_for(guardrail_tripped.wait(), timeout=1)
await asyncio.wait_for(slow_cancel_started.wait(), timeout=1)
async for event in original_stream_response(*args, **kwargs):
yield event

agent = Agent(
name="streaming_guardrail_slow_cancel_agent",
instructions="Call the dangerous_tool immediately",
tools=[dangerous_tool],
input_guardrails=[tripwire_before_tool_execution, slow_to_cancel_guardrail],
model=model,
)
model.set_next_output([get_function_tool_call("dangerous_tool", arguments="{}")])
model.set_next_output([get_text_message("done")])

with patch.object(model, "stream_response", side_effect=delayed_stream_response):
result = Runner.run_streamed(agent, "trigger guardrail")

with pytest.raises(InputGuardrailTripwireTriggered):
async for _event in result.stream_events():
pass

assert model_started.is_set() is True
assert guardrail_tripped.is_set() is True
assert slow_cancel_started.is_set() is True
assert slow_cancel_finished.is_set() is True
assert tool_was_executed is False
assert model.first_turn_args is not None, "Model should have been called in parallel mode"


@pytest.mark.asyncio
async def test_blocking_guardrail_prevents_tool_execution():
tool_was_executed = False
Expand Down
Loading