Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/modelplane/evaluator/dag.py
Comment thread
superdosh marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _run_traced(self, ctx: EvalContext) -> tuple[DAGOutput, set[tuple[str, str]]
for node_name in self._ordered:
if node_name not in reachable:
continue
run_ctx = ctx.with_parent_outputs(
ctx = ctx.with_parent_outputs(
{
pred: node_outputs[pred]
for pred in self._predecessors[node_name]
Expand All @@ -190,14 +190,14 @@ def _run_traced(self, ctx: EvalContext) -> tuple[DAGOutput, set[tuple[str, str]]
)
node = self._nodes[node_name]
if isinstance(node, CacheableNodeMixin):
key = node.cache_key(run_ctx)
key = node.cache_key(ctx)
if key in self._node_caches[node.name]:
output = self._node_caches[node.name][key]
else:
output = node.run(run_ctx)
output = node.run(ctx)
self._node_caches[node.name][key] = output
else:
output = node.run(run_ctx)
output = node.run(ctx)
node_outputs[node_name] = output
total_cost += output.realized_cost
if isinstance(output.value, Verdict):
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/evaluator/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ def output_tokens(self, ctx: EvalContext) -> int:
return context_token_count(ctx)


class NoOpEnricher(Enricher):
    """Enricher mock that forwards the context it receives unchanged."""

    def run(self, ctx: EvalContext) -> NodeOutput:
        """Build this node's output with a ``None`` value for ``ctx``."""
        passthrough = self.build_output(None, ctx)
        return passthrough


class FixedScorer(Enricher):
"""Returns a fixed float score regardless of context."""

Expand Down
91 changes: 89 additions & 2 deletions tests/unit/evaluator/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,16 @@
from modelplane.evaluator.safety import Safety

from .conftest import skip_in_ci
from .mocks import AlwaysTrueCacheable

from .mocks import (
AlwaysSafe,
AlwaysTrue,
AlwaysTrueCacheable,
LowerCaser,
LowerCaseScorer,
NoOpEnricher,
ThresholdArbiter,
UpperCaser,
)

def test_dag_outputs(simple_dag):
assert simple_dag.verdict_type == Safety
Expand Down Expand Up @@ -182,6 +190,85 @@ def test_dag_run(simple_dag, sample_ctx):
assert dag_output.verdict.name == "UNSAFE"


def test_dag_passes_updated_context_to_downstream_nodes():
    """A context update must survive a pass-through node and reach the scorer."""
    initial_ctx = EvalContext(prompt="x", response="HELLO")

    # Linear chain on the true branch:
    # lower_caser -> noop -> lower_scorer -> threshold_arbiter.
    dag = Composer("ctx_update", verdict_type=Safety)
    dag = dag.add_node(
        AlwaysTrue(
            name="always_true",
            routes_true=["lower_caser"],
            routes_false=["always_safe"],
        )
    )
    dag = dag.add_node(AlwaysSafe(name="always_safe"))
    dag = dag.add_node(LowerCaser(name="lower_caser", routes=["noop"]))
    dag = dag.add_node(NoOpEnricher(name="noop", routes=["lower_scorer"]))
    dag = dag.add_node(
        LowerCaseScorer(name="lower_scorer", routes=["threshold_arbiter"])
    )
    dag = dag.add_node(ThresholdArbiter(name="threshold_arbiter", threshold=0.5))

    result = dag.run(initial_ctx)

    lowered_response = result.node_outputs["lower_caser"].updated_ctx.response
    assert lowered_response == "hello"
    # The scorer reads ctx.response, so it only yields 1.0 when the lowercased
    # update from lower_caser was propagated to it.
    score = result.node_outputs["lower_scorer"].value
    assert score == pytest.approx(1.0)


def test_dag_updated_context_not_passed_to_parallel_nodes():
    """Sibling nodes must not observe each other's context updates."""
    # noop and lower_caser are both routed to from always_true, i.e. they are
    # parallel nodes; noop must still see the original, un-lowercased context.
    initial_ctx = EvalContext(prompt="x", response="HELLO")

    dag = Composer("ctx_update", verdict_type=Safety)
    dag = dag.add_node(
        AlwaysTrue(
            name="always_true",
            routes_true=["lower_caser", "noop"],
            routes_false=["always_safe"],
        )
    )
    dag = dag.add_node(AlwaysSafe(name="always_safe"))
    dag = dag.add_node(LowerCaser(name="lower_caser", routes=["lower_scorer"]))
    dag = dag.add_node(NoOpEnricher(name="noop", routes=["lower_scorer"]))
    dag = dag.add_node(
        LowerCaseScorer(name="lower_scorer", routes=["threshold_arbiter"])
    )
    dag = dag.add_node(ThresholdArbiter(name="threshold_arbiter", threshold=0.5))

    result = dag.run(initial_ctx)

    lower_caser_out = result.node_outputs["lower_caser"]
    assert lower_caser_out.original_ctx.response == "HELLO"
    assert lower_caser_out.updated_ctx.response == "hello"

    noop_out = result.node_outputs["noop"]
    assert noop_out.original_ctx.response == "HELLO"
    assert noop_out.updated_ctx is None

    scorer_out = result.node_outputs["lower_scorer"]
    assert scorer_out.original_ctx.response == "hello"
    # The scorer reads ctx.response, so it only yields 1.0 when it received
    # the lowercased update produced by lower_caser.
    assert scorer_out.value == pytest.approx(1.0)


def test_dag_parallel_nodes_different_updated_contexts_raises_error():
    """Conflicting context updates from parallel parents must be rejected."""
    # upper_caser and lower_caser are parallel nodes that update the context
    # differently; feeding both into lower_scorer should raise an error.
    initial_ctx = EvalContext(prompt="x", response="HELLO")

    dag = Composer("ctx_update", verdict_type=Safety)
    dag = dag.add_node(
        AlwaysTrue(
            name="always_true",
            routes_true=["lower_caser", "upper_caser"],
            routes_false=["always_safe"],
        )
    )
    dag = dag.add_node(AlwaysSafe(name="always_safe"))
    dag = dag.add_node(LowerCaser(name="lower_caser", routes=["lower_scorer"]))
    dag = dag.add_node(UpperCaser(name="upper_caser", routes=["lower_scorer"]))
    dag = dag.add_node(
        LowerCaseScorer(name="lower_scorer", routes=["threshold_arbiter"])
    )
    dag = dag.add_node(ThresholdArbiter(name="threshold_arbiter", threshold=0.5))

    expected_message = "all parent outputs must have the same updated context"
    with pytest.raises(ValueError, match=expected_message):
        dag.run(initial_ctx)


def test_dag_run_with_dataframe(simple_dag, tmp_path):
# "hello world" (space lowers avg below threshold) → safe
# "helloworld" (no space, avg = 0.5 = threshold) → unsafe
Expand Down
Loading