diff --git a/tests/test_apache_json.py b/tests/test_apache_json.py
index 85844e1..2b47b81 100644
--- a/tests/test_apache_json.py
+++ b/tests/test_apache_json.py
@@ -5,6 +5,7 @@
 import sys
 import time
+from io import BytesIO
 from multiprocessing import Process
 
 import pytest
 import six
@@ -17,6 +18,7 @@
     make_client as make_rpc_client
 from thriftpy2.thrift import TProcessor, TType
 from thriftpy2.transport import TMemoryBuffer
+from thriftpy2.transport.base import TTransportBase, TTransportException
 from thriftpy2.transport.buffered import TBufferedTransportFactory
 
 
@@ -35,11 +37,66 @@ def recursive_vars(obj):
         return recursive_vars(vars(obj))
 
 
-def test_thrift_transport():
-    test_thrift = thriftpy2.load(
+class ChunkedMemoryTransport(TTransportBase):
+    """In-memory transport serving at most ``chunk_size`` bytes per ``_read``.
+
+    Short reads hand the protocol a message in several pieces instead of
+    all at once.
+    """
+
+    def __init__(self, value, chunk_size):
+        self._buffer = BytesIO(value)
+        self._chunk_size = chunk_size
+
+    def is_open(self):
+        return True
+
+    def open(self):
+        pass
+
+    def close(self):
+        self._buffer.close()
+
+    def _read(self, sz):
+        return self._buffer.read(min(sz, self._chunk_size))
+
+    def write(self, buf):
+        self._buffer.write(buf)
+
+    def flush(self):
+        pass
+
+
+def load_test_thrift(module_name="test_thrift"):
+    """Load the shared fixture IDL under ``module_name``."""
+    return thriftpy2.load(
         "apache_json_test.thrift",
-        module_name="test_thrift"
+        module_name=module_name
     )
+
+
+def make_request_bytes(test_thrift, obj, seqid=0):
+    """Serialize a ``test`` call carrying ``obj`` to Apache JSON bytes."""
+    args = test_thrift.TestService.test_args(test=obj)
+    buf = TMemoryBuffer()
+    oprot = TApacheJSONProtocolFactory().get_protocol(buf)
+    oprot.write_message_begin("test", 1, seqid)
+    oprot.write_struct(args)
+    return buf.getvalue()
+
+
+def read_request_args(test_thrift, data, chunk_size=4096, seqid=0):
+    """Decode one ``test`` call from ``data`` in ``chunk_size`` reads."""
+    trans = ChunkedMemoryTransport(data, chunk_size)
+    iprot = TApacheJSONProtocolFactory().get_protocol(trans)
+    assert iprot.read_message_begin() == ("test", 1, seqid)
+    args = iprot.read_struct(test_thrift.TestService.test_args())
+    iprot.read_message_end()
+    return args
+
+
+def test_thrift_transport():
+    test_thrift = load_test_thrift()
     Test = test_thrift.Test
     Foo = test_thrift.Foo
     test_object = Test(
@@ -131,10 +188,7 @@ def test(t):
 @pytest.mark.parametrize('server_func', [(make_rpc_server, make_rpc_client),
                                          (make_http_server, make_http_client)])
 def test_client(server_func):
-    test_thrift = thriftpy2.load(
-        "apache_json_test.thrift",
-        module_name="test_thrift"
-    )
+    test_thrift = load_test_thrift()
 
     class Handler:
         @staticmethod
@@ -177,3 +231,71 @@ def run_server():
     finally:
         proc.terminate()
     time.sleep(1)
+
+
+def test_load_data_handles_chunked_messages():
+    test_thrift = load_test_thrift(module_name="chunked_messages_test_thrift")
+    first = test_thrift.Test(
+        tstr='你好 \\\\ "quoted" [brackets] 😀🚀🎉',
+        tlist_of_strings=['["nested"]', 'slash\\\\quote\\"', '中文😀🎉'],
+    )
+    second = test_thrift.Test(tstr="第二条 😀🚀🎉")
+    payload = (
+        b" \n\t"
+        + make_request_bytes(test_thrift, first)
+        + make_request_bytes(test_thrift, second)
+    )
+    transport = ChunkedMemoryTransport(payload, chunk_size=7)
+    iprot = TApacheJSONProtocolFactory().get_protocol(transport)
+
+    assert iprot.read_message_begin() == ("test", 1, 0)
+    first_args = iprot.read_struct(test_thrift.TestService.test_args())
+    iprot.read_message_end()
+
+    assert iprot.read_message_begin() == ("test", 1, 0)
+    second_args = iprot.read_struct(test_thrift.TestService.test_args())
+    iprot.read_message_end()
+
+    assert recursive_vars(first_args.test) == recursive_vars(first)
+    assert recursive_vars(second_args.test) == recursive_vars(second)
+
+
+@pytest.mark.parametrize('chunk_size', [1, 2, 3, 5, 4096])
+def test_load_data_round_trips_at_any_chunk_size(chunk_size):
+    test_thrift = load_test_thrift(module_name="chunk_size_test_thrift")
+    obj = test_thrift.Test(tstr='中文 😀 "quoted" [brackets]')
+    data = make_request_bytes(test_thrift, obj, seqid=3)
+
+    args = read_request_args(test_thrift, data, chunk_size, seqid=3)
+
+    assert recursive_vars(args.test) == recursive_vars(obj)
+
+
+@pytest.mark.parametrize('truncate', [
+    lambda data: data[:-1],
+    lambda data: data[:data.index(b"truncated") + 4],
+], ids=['missing_closing_bracket', 'inside_string'])
+def test_load_data_raises_eof_for_truncated_payload(truncate):
+    test_thrift = load_test_thrift(module_name="truncated_test_thrift")
+    request_data = truncate(
+        make_request_bytes(test_thrift, test_thrift.Test(tstr="truncated"))
+    )
+    iprot = TApacheJSONProtocolFactory().get_protocol(
+        ChunkedMemoryTransport(request_data, chunk_size=4)
+    )
+
+    with pytest.raises(TTransportException) as exc_info:
+        iprot.read_message_begin()
+
+    assert exc_info.value.type == TTransportException.END_OF_FILE
+
+
+def test_read_message_begin_raises_eof_on_empty_transport():
+    iprot = TApacheJSONProtocolFactory().get_protocol(
+        ChunkedMemoryTransport(b" \n\t", chunk_size=4)
+    )
+
+    with pytest.raises(TTransportException) as exc_info:
+        iprot.read_message_begin()
+
+    assert exc_info.value.type == TTransportException.END_OF_FILE
diff --git a/thriftpy2/protocol/apache_json.py b/thriftpy2/protocol/apache_json.py
index 625a628..657f82d 100644
--- a/thriftpy2/protocol/apache_json.py
+++ b/thriftpy2/protocol/apache_json.py
@@ -6,6 +6,7 @@
 """
 
 from __future__ import absolute_import
+import codecs
 import json
 import base64
 import sys
@@ -13,7 +14,9 @@
 import six
 
 from thriftpy2.protocol import TProtocolBase
+from thriftpy2.protocol.exc import TProtocolException
 from thriftpy2.thrift import TType
+from thriftpy2.transport.base import TTransportException
 
 
 CTYPES = {
@@ -77,44 +80,125 @@ class TApacheJSONProtocol(TProtocolBase):
     Protocol that implements the Apache JSON Protocol
     """
 
+    #: Upper bound on the bytes requested from the transport per read.
+    READ_CHUNK_SIZE = 4096
+
     def __init__(self, trans):
         TProtocolBase.__init__(self, trans)
         self._req = None
+        self._decoder = json.JSONDecoder()
+        # Decoded text not yet consumed by a parsed message; it may already
+        # hold the start of the next message.
+        self._input_text = ""
+        # Incremental, so a character split across two reads still decodes.
+        self._utf8_decoder = codecs.getincrementaldecoder("utf-8")()
+
+    def _read_chunk(self, sz):
+        """Read the next piece of input, requesting up to ``sz`` bytes.
+
+        Uses the transport's ``_read`` hook when it has one; otherwise
+        falls back to single-byte ``read`` calls.
+        """
+        read = getattr(self.trans, "_read", None)
+        if read is None:
+            # read(sz) may block until full on buffered transports, so only
+            # probe one byte here.
+            return self.trans.read(1)
+        return read(sz)
+
+    def _decode_chunk(self, chunk, final=False):
+        """Decode ``chunk`` as UTF-8, buffering a trailing partial character.
+
+        :param chunk: bytes read from the transport
+        :param final: True once no more input will follow
+        :return: the decoded text
+        :raises TProtocolException: INVALID_DATA for malformed UTF-8
+        """
+        try:
+            return self._utf8_decoder.decode(chunk, final=final)
+        except UnicodeDecodeError as exc:
+            raise TProtocolException(
+                type=TProtocolException.INVALID_DATA,
+                message="Bad UTF-8 data in Apache JSON payload: {}".format(exc)
+            )
+
+    def _try_parse_buffer(self):
+        """Parse one message from the front of the buffered text.
+
+        :return: True if ``self._req`` now holds a message, else False
+        """
+        stripped = self._input_text.lstrip()
+        if not stripped:
+            return False
+
+        try:
+            self._req, end = self._decoder.raw_decode(stripped)
+        except ValueError:
+            # Incomplete and malformed text are told apart only at EOF.
+            return False
+
+        # Whatever follows belongs to the next message.
+        self._input_text = stripped[end:]
+        return True
+
+    def _raise_parse_error(self):
+        """Raise the error explaining why input ended without a message.
+
+        :raises TTransportException: END_OF_FILE when the input is empty
+            or stops part-way through a value
+        :raises TProtocolException: INVALID_DATA when the text is not JSON
+        """
+        stripped = self._input_text.lstrip()
+        eof = TTransportException(
+            TTransportException.END_OF_FILE,
+            "End of file reading Apache JSON payload"
+        )
+        if not stripped:
+            # Only whitespace arrived. Raising here keeps read_message_begin
+            # from indexing ``self._req`` while it is still None.
+            self._input_text = ""
+            raise eof
+
+        try:
+            self._decoder.raw_decode(stripped)
+        except ValueError as exc:
+            pos = getattr(exc, "pos", None)
+            msg = getattr(exc, "msg", "")
+            # A string cut short is reported at its opening quote, so the
+            # position check alone would misreport it as invalid data.
+            if (pos is None or pos >= len(stripped)
+                    or msg.startswith("Unterminated string")):
+                raise eof
+            raise TProtocolException(
+                type=TProtocolException.INVALID_DATA,
+                message="Invalid Apache JSON payload: {}".format(exc)
+            )
 
     def _load_data(self):
-        data = b""
-        l_braces = 0
-        in_string = False
+        """Read from the transport until ``self._req`` holds one message.
+
+        Everything buffered is re-parsed after each chunk, so a large
+        message arriving in small chunks is scanned repeatedly.
+        """
         while True:
-            # read(sz) will wait until it has read exactly sz bytes,
-            # so we must read until we get a balanced json list in absence of knowing
-            # how long the json string will be
-            if hasattr(self.trans, 'getvalue'):
-                try:
-                    data = self.trans.getvalue()
-                    break
-                except Exception:
-                    pass
-            new_data = self.trans.read(1)
-            data += new_data
-            if new_data == b'"' and not data.endswith(b'\\"'):
-                in_string = not in_string
-            if not in_string:
-                if new_data == b"[":
-                    l_braces += 1
-                elif new_data == b"]":
-                    l_braces -= 1
-                if l_braces == 0:
-                    break
-        if data:
-            self._req = json.loads(data.decode('utf8'))
-        else:
-            self._req = None
+            if self._try_parse_buffer():
+                return
+
+            new_data = self._read_chunk(self.READ_CHUNK_SIZE)
+            if not new_data:
+                # End of input: flush the decoder so a dangling partial
+                # character is reported, then raise for what is left.
+                self._input_text += self._decode_chunk(b"", final=True)
+                self._raise_parse_error()
+                return
+
+            self._input_text += self._decode_chunk(new_data)
 
     def read_message_begin(self):
+        """Return ``(name, type, seqid)`` of the next message as a tuple."""
         if not self._req:
             self._load_data()
-        return self._req[1:4]
+        return tuple(self._req[1:4])
 
     def read_message_end(self):
         pass
@@ -316,4 +400,9 @@ def read_struct(self, obj):
         :param obj:
         :return:
         """
-        return self._dict_to_thrift(self._req[4], obj)
+        try:
+            return self._dict_to_thrift(self._req[4], obj)
+        finally:
+            # Drop the message even if decoding fails, so the next
+            # read_message_begin loads fresh input instead of replaying it.
+            self._req = None