diff --git a/framework/py/flwr/supernode/runtime/run_clientapp.py b/framework/py/flwr/supernode/runtime/run_clientapp.py index b358f8e3e64f..dfa8679c8beb 100644 --- a/framework/py/flwr/supernode/runtime/run_clientapp.py +++ b/framework/py/flwr/supernode/runtime/run_clientapp.py @@ -251,9 +251,11 @@ def pull_task_input(stub: ClientAppIoStub) -> tuple[Message, Context, Run, Fab]: return_type=Message, ) - # Set the message ID - # The deflated message doesn't contain the message_id (its own object_id) - message.metadata.__dict__["_message_id"] = object_tree.object_id + # Set the message ID from the transport message so replies refer to the + # instruction message tracked by the SuperNode. + message.metadata.__dict__["_message_id"] = pull_msg_res.messages_list[ + 0 + ].metadata.message_id return message, context, run, fab except grpc.RpcError as e: log(ERROR, "[PullTaskInput] gRPC error occurred: %s", str(e)) diff --git a/framework/py/flwr/supernode/runtime/run_clientapp_test.py b/framework/py/flwr/supernode/runtime/run_clientapp_test.py index 7fa987188358..52ee573a83ed 100644 --- a/framework/py/flwr/supernode/runtime/run_clientapp_test.py +++ b/framework/py/flwr/supernode/runtime/run_clientapp_test.py @@ -16,17 +16,20 @@ import unittest -from unittest.mock import patch +from types import SimpleNamespace +from unittest.mock import Mock, patch import grpc +from flwr.app import Message, Metadata, RecordDict +from flwr.app.message import make_message from flwr.common.exit import ExitCode from flwr.supercore.interceptors import ( AppIoTokenClientInterceptor, RuntimeVersionClientInterceptor, ) -from .run_clientapp import run_clientapp +from .run_clientapp import pull_task_input, run_clientapp class TestRunClientApp(unittest.TestCase): @@ -72,3 +75,62 @@ def test_run_clientapp_exits_nonzero_on_grpc_error(self) -> None: flwr_exit.call_args.kwargs["code"], ExitCode.CLIENTAPP_COMMUNICATION_ERROR, ) + + def test_pull_task_input_preserves_transport_message_id(self) -> None: + """`pull_task_input` should preserve the SuperNode-tracked message ID.""" + stub = Mock() + context = SimpleNamespace(run_id=1, node_id=2) + run = object() + fab = object() + stub.PullTaskInput.return_value = SimpleNamespace( + context=object(), + run=object(), + fab=object(), + ) + stub.PullMessage.return_value = SimpleNamespace( + messages_list=[ + SimpleNamespace( + metadata=SimpleNamespace(message_id="instruction-message-id") + ) + ], + message_object_trees=[SimpleNamespace(object_id="object-tree-id")], + ) + inflated_message = make_message( + Metadata( + run_id=1, + message_id="object-tree-id", + src_node_id=0, + dst_node_id=2, + reply_to_message_id="", + group_id="", + created_at=0.0, + ttl=1.0, + message_type="train", + ), + RecordDict(), + ) + + with ( + patch( + "flwr.supernode.runtime.run_clientapp.context_from_proto", + return_value=context, + ), + patch( + "flwr.supernode.runtime.run_clientapp.run_from_proto", return_value=run + ), + patch( + "flwr.supernode.runtime.run_clientapp.fab_from_proto", return_value=fab + ), + patch( + "flwr.supernode.runtime.run_clientapp.pull_and_inflate_object_from_tree", + return_value=inflated_message, + ), + ): + message, _, _, _ = pull_task_input(stub) + + reply = Message(content=RecordDict(), reply_to=message) + self.assertEqual(message.metadata.message_id, "instruction-message-id") + self.assertEqual( + reply.metadata.reply_to_message_id, + "instruction-message-id", + )