Skip to content
1 change: 1 addition & 0 deletions rosbridge_server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ if(BUILD_TESTING)
test/websocket/transient_local_publisher.test.py
test/websocket/best_effort_publisher.test.py
test/websocket/multiple_subscribers_raw.test.py
test/websocket/tornado_settings.test.py
)

foreach(TEST_FILE ${TEST_FILES})
Expand Down
3 changes: 3 additions & 0 deletions rosbridge_server/scripts/rosbridge_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ def _handle_parameters(self) -> None:
self.tornado_settings["websocket_ping_timeout"] = (
self.get_parameter("websocket_ping_timeout").get_parameter_value().double_value
)
self.tornado_settings["websocket_max_message_size"] = (
self.get_parameter("max_message_size").get_parameter_value().integer_value
)

# WebSocket handler parameters
self.use_compression = (
Expand Down
13 changes: 11 additions & 2 deletions rosbridge_server/test/websocket/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,20 @@ class TestClientProtocol(WebSocketClientProtocol):
"""Set message_handler to handle messages received from the server."""

message_handler: Callable[[Any], None]
on_close_handler: Callable[[bool, int, str], None]

def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
self.connected_future: Future[None] = Future()
self.message_handler = lambda _: None
self.on_close_handler = lambda _, __, ___: None
super().__init__(*args, **kwargs)

def onOpen(self) -> None:
self.connected_future.set_result(None)

def onClose(self, wasClean: bool, code: int, reason: str) -> None:
self.on_close_handler(wasClean, code, reason)

def sendJson(self, msg_dict: dict[str, Any], *, times: int = 1) -> None:
msg = json.dumps(msg_dict).encode("utf-8")
for _ in range(times):
Expand Down Expand Up @@ -124,14 +129,18 @@ def run_websocket_test(
executor.add_node(node)

async def task() -> None:
await test_fn(node, lambda: connect_to_server(node))
reactor.callFromThread(reactor.stop) # type: ignore[attr-defined]
try:
await test_fn(node, lambda: connect_to_server(node))
finally:
reactor.stop()

future = executor.create_task(task)

reactor.callInThread(executor.spin_until_future_complete, future) # type: ignore[attr-defined]
reactor.run(installSignalHandlers=False) # type: ignore[attr-defined]

future.result()

executor.remove_node(node)
node.destroy_node()
rclpy.shutdown(context=context)
Expand Down
64 changes: 64 additions & 0 deletions rosbridge_server/test/websocket/tornado_settings.test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

import sys
import unittest
from pathlib import Path
from typing import TYPE_CHECKING

from twisted.python import log

sys.path.append(str(Path(__file__).parent)) # enable importing from common.py in this directory

import common
from common import sleep, websocket_test

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable

from common import TestClientProtocol
from rclpy.node import Node

log.startLogging(sys.stderr)

generate_test_description = common.generate_test_description


class TestTornadoSettings(unittest.TestCase):
@websocket_test
async def test_tornado_settings_fails(
self, node: Node, make_client: Callable[[], Awaitable[TestClientProtocol]]
) -> None:
failed_code = 0

def on_close_handler(_wasClean: bool, code: int, _reason: str) -> None:
nonlocal failed_code
failed_code = code

ws_client = await make_client()
ws_client.on_close_handler = on_close_handler

ws_client.sendJson(
{
"op": "call_service",
"type": "rosbridge_test_msgs/TestArrayRequest",
"service": "/test_service",
"args": {
"int_values": [0] * 330000
}, # max default size is 1000000, but because it's sent in uint32 the literal size is lower.
}
)

await sleep(node, 1.0)
self.assertEqual(failed_code, 0)

ws_client.sendJson(
{
"op": "call_service",
"type": "rosbridge_test_msgs/TestArrayRequest",
"service": "/test_service",
"args": {"int_values": [0] * 335000},
}
)

await sleep(node, 1.0)
self.assertEqual(failed_code, 1009)
Loading