Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 90 additions & 7 deletions tests/test_apache_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import time
from multiprocessing import Process
from io import BytesIO

import pytest
import six
Expand All @@ -17,6 +18,7 @@
make_client as make_rpc_client
from thriftpy2.thrift import TProcessor, TType
from thriftpy2.transport import TMemoryBuffer
from thriftpy2.transport.base import TTransportBase, TTransportException
from thriftpy2.transport.buffered import TBufferedTransportFactory


Expand All @@ -35,11 +37,57 @@ def recursive_vars(obj):
return recursive_vars(vars(obj))


def test_thrift_transport():
test_thrift = thriftpy2.load(
class ChunkedMemoryTransport(TTransportBase):
    """In-memory transport that hands its data out in small pieces.

    Every ``_read`` call returns at most ``chunk_size`` bytes, which lets a
    test drive a protocol with short reads that split a message at
    arbitrary byte positions.

    :param value: initial bytes served by the transport
    :param chunk_size: upper bound on the bytes returned per ``_read``
    :raises ValueError: if ``chunk_size`` is smaller than 1
    """

    def __init__(self, value, chunk_size):
        # A size of 0 would make every read look like EOF, and a negative
        # size would make BytesIO.read() return everything that is left.
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")
        self._buffer = BytesIO(value)
        self._chunk_size = chunk_size

    def is_open(self):
        """Return True until ``close()`` has released the buffer."""
        return not self._buffer.closed

    def open(self):
        """Do nothing; the buffer is ready as soon as it is constructed."""
        pass

    def close(self):
        """Close the underlying ``BytesIO``."""
        self._buffer.close()

    def _read(self, sz):
        """Return up to ``min(sz, chunk_size)`` bytes; ``b""`` at the end."""
        return self._buffer.read(min(sz, self._chunk_size))

    def write(self, buf):
        """Write *buf* into the same buffer at its current position."""
        self._buffer.write(buf)

    def flush(self):
        """Do nothing; writes land in memory immediately."""
        pass


def load_test_thrift(module_name="test_thrift"):
    """Load ``apache_json_test.thrift`` and return the resulting module.

    :param module_name: name passed through to ``thriftpy2.load``; the
        tests in this file pass a distinct name per test case
    :return: whatever ``thriftpy2.load`` returns for the IDL file
    """
    return thriftpy2.load(
        "apache_json_test.thrift",
        module_name=module_name,
    )


def make_request_bytes(test_thrift, obj, seqid=0):
    """Serialize one ``test`` request that carries *obj* and return its bytes.

    :param test_thrift: loaded thrift module providing ``TestService``
    :param obj: value passed as the ``test`` argument of the call
    :param seqid: sequence id written into the message header
    :return: the bytes accumulated in the in-memory transport
    """
    payload = test_thrift.TestService.test_args(test=obj)
    memory = TMemoryBuffer()
    writer = TApacheJSONProtocolFactory().get_protocol(memory)
    # 1 is the message type field of the header (presumably CALL - confirm).
    writer.write_message_begin("test", 1, seqid)
    writer.write_struct(payload)
    serialized = memory.getvalue()
    return serialized


def read_request_args(test_thrift, data, chunk_size=4096):
    """Decode one ``test`` request from *data* and return its argument struct.

    The bytes are served through a ``ChunkedMemoryTransport`` so the protocol
    receives them in pieces of at most *chunk_size* bytes.

    :param test_thrift: loaded thrift module providing ``TestService``
    :param data: serialized request bytes
    :param chunk_size: largest piece the transport hands out per read
    :return: the populated ``test_args`` instance
    """
    reader = TApacheJSONProtocolFactory().get_protocol(
        ChunkedMemoryTransport(data, chunk_size)
    )
    header = reader.read_message_begin()
    assert header == ("test", 1, 0)
    decoded = reader.read_struct(test_thrift.TestService.test_args())
    reader.read_message_end()
    return decoded


def test_thrift_transport():
test_thrift = load_test_thrift()
Test = test_thrift.Test
Foo = test_thrift.Foo
test_object = Test(
Expand Down Expand Up @@ -131,10 +179,7 @@ def test(t):
@pytest.mark.parametrize('server_func', [(make_rpc_server, make_rpc_client),
(make_http_server, make_http_client)])
def test_client(server_func):
test_thrift = thriftpy2.load(
"apache_json_test.thrift",
module_name="test_thrift"
)
test_thrift = load_test_thrift()

class Handler:
@staticmethod
Expand Down Expand Up @@ -177,3 +222,41 @@ def run_server():
finally:
proc.terminate()
time.sleep(1)


def test_load_data_handles_chunked_messages():
    """Two back-to-back messages survive being read in 7-byte pieces.

    The stream starts with whitespace and the payloads contain quotes,
    backslashes, brackets and multi-byte characters, so chunk boundaries can
    fall inside escapes and inside UTF-8 sequences.
    """
    test_thrift = load_test_thrift(module_name="chunked_messages_test_thrift")
    expected = [
        test_thrift.Test(
            tstr='你好 \\\\ "quoted" [brackets] 😀🚀🎉',
            tlist_of_strings=['["nested"]', 'slash\\\\quote\\"', '中文😀🎉'],
        ),
        test_thrift.Test(tstr="第二条 😀🚀🎉"),
    ]

    stream = b" \n\t"
    for message in expected:
        stream = stream + make_request_bytes(test_thrift, message)
    transport = ChunkedMemoryTransport(stream, chunk_size=7)
    iprot = TApacheJSONProtocolFactory().get_protocol(transport)

    # Read every message before comparing anything, one after the other.
    decoded = []
    for _ in expected:
        assert iprot.read_message_begin() == ("test", 1, 0)
        decoded.append(iprot.read_struct(test_thrift.TestService.test_args()))
        iprot.read_message_end()

    for actual, wanted in zip(decoded, expected):
        assert recursive_vars(actual.test) == recursive_vars(wanted)


def test_load_data_raises_eof_for_truncated_payload():
    """A request missing its final byte is reported as END_OF_FILE."""
    test_thrift = load_test_thrift(module_name="truncated_test_thrift")
    complete = make_request_bytes(test_thrift, test_thrift.Test(tstr="truncated"))
    # Drop the last byte so the stream ends before the message is complete.
    request_data = complete[:-1]
    factory = TApacheJSONProtocolFactory()
    iprot = factory.get_protocol(
        ChunkedMemoryTransport(request_data, chunk_size=4)
    )

    with pytest.raises(TTransportException) as exc_info:
        iprot.read_message_begin()

    raised = exc_info.value
    assert raised.type == TTransportException.END_OF_FILE
101 changes: 72 additions & 29 deletions thriftpy2/protocol/apache_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
"""

from __future__ import absolute_import
import codecs
import json
import base64
import sys

import six

from thriftpy2.protocol import TProtocolBase
from thriftpy2.protocol.exc import TProtocolException
from thriftpy2.thrift import TType
from thriftpy2.transport.base import TTransportException


CTYPES = {
Expand Down Expand Up @@ -77,44 +80,82 @@ class TApacheJSONProtocol(TProtocolBase):
Protocol that implements the Apache JSON Protocol
"""

READ_CHUNK_SIZE = 4096

def __init__(self, trans):
    """
    Set up the per-connection parsing state.

    :param trans: transport handed to ``TProtocolBase.__init__``
    """
    TProtocolBase.__init__(self, trans)
    # Bytes are decoded incrementally so a chunk boundary may fall inside a
    # multi-byte character.
    make_utf8_decoder = codecs.getincrementaldecoder("utf-8")
    self._utf8_decoder = make_utf8_decoder()
    # Decoded text that has not been consumed by a parsed message yet.
    self._input_text = ""
    self._decoder = json.JSONDecoder()
    # Most recently parsed message, or None when nothing is pending.
    self._req = None

def _read_chunk(self, sz):
    """
    Fetch the next piece of raw input from the transport.

    :param sz: requested size, forwarded to the transport's ``_read``
    :return: whatever the transport call returns
    """
    transport = self.trans
    partial_read = getattr(transport, "_read", None)
    if partial_read is not None:
        return partial_read(sz)
    # read(sz) may block until full on buffered transports, so without a
    # ``_read`` hook only a single byte is requested.
    return transport.read(1)

def _decode_chunk(self, chunk, final=False):
    """
    Decode *chunk* with the instance's incremental UTF-8 decoder.

    :param chunk: raw bytes to decode
    :param final: forwarded to the decoder's ``decode`` call
    :return: the decoded text
    :raises TProtocolException: INVALID_DATA when the decoder raises
        ``UnicodeDecodeError``
    """
    try:
        text = self._utf8_decoder.decode(chunk, final=final)
    except UnicodeDecodeError as exc:
        detail = "Bad UTF-8 data in Apache JSON payload: {}".format(exc)
        raise TProtocolException(
            type=TProtocolException.INVALID_DATA,
            message=detail
        )
    return text

def _try_parse_buffer(self):
    """
    Try to parse one JSON value from the start of the buffered text.

    On success the value is stored in ``self._req`` and the buffer keeps
    only the text after it. On failure nothing is modified.

    :return: True if a value was parsed, False otherwise
    """
    pending = self._input_text.lstrip()
    if pending:
        try:
            parsed, consumed = self._decoder.raw_decode(pending)
        except ValueError:
            # No complete value available yet.
            return False
        self._req = parsed
        self._input_text = pending[consumed:]
        return True
    return False

def _raise_parse_error(self):
    """
    Report why the buffered text could not be parsed.

    If the buffer holds only whitespace, the pending request and the buffer
    are cleared and the method returns without raising.

    Otherwise the buffer is parsed once more and the failure is classified.
    Input that simply stops too early is reported as end of file. That
    includes failures located before the end of the buffer when the cause
    is a string literal that never closes or a bare literal that was cut
    off, such as ``tru`` or ``-``.

    :raises TTransportException: END_OF_FILE when the payload is truncated
    :raises TProtocolException: INVALID_DATA when the payload is malformed
    """
    stripped = self._input_text.lstrip()
    if not stripped:
        self._req = None
        self._input_text = ""
        return

    try:
        self._decoder.raw_decode(stripped)
    except ValueError as exc:
        # Plain ValueError has no ``pos``; it is treated as truncation.
        pos = getattr(exc, "pos", None)
        truncated = pos is None or pos >= len(stripped)
        if not truncated:
            # The decoder reports these two cases at the position where the
            # token starts, not at the end of the input.
            reason = getattr(exc, "msg", "")
            tail = stripped[pos:]
            literals = (
                "true", "false", "null", "NaN", "Infinity", "-Infinity"
            )
            truncated = (
                reason.startswith("Unterminated string")
                or any(literal.startswith(tail) for literal in literals)
            )
        if truncated:
            raise TTransportException(
                TTransportException.END_OF_FILE,
                "End of file reading Apache JSON payload"
            )
        raise TProtocolException(
            type=TProtocolException.INVALID_DATA,
            message="Invalid Apache JSON payload: {}".format(exc)
        )

def _load_data(self):
    """
    Read from the transport until one complete JSON message is buffered.

    Text left over from an earlier read is tried first. After that, chunks
    of up to ``READ_CHUNK_SIZE`` bytes are requested, decoded and appended
    until the buffer parses. The parsed message ends up in ``self._req``.

    An empty read marks the end of the stream. The UTF-8 decoder is then
    flushed and ``_raise_parse_error`` decides what to report.

    :raises TTransportException: from ``_raise_parse_error`` on truncation
    :raises TProtocolException: on malformed UTF-8 or malformed JSON
    """
    while True:
        if self._try_parse_buffer():
            return

        new_data = self._read_chunk(self.READ_CHUNK_SIZE)
        if not new_data:
            # final=True makes the decoder reject a dangling partial
            # multi-byte sequence instead of silently holding on to it.
            self._input_text += self._decode_chunk(b"", final=True)
            self._raise_parse_error()
            return

        self._input_text += self._decode_chunk(new_data)

def read_message_begin(self):
    """
    Return the header of the current message, loading one if needed.

    A new message is loaded only when no parsed request is pending.

    :return: elements 1 to 3 of the parsed message as a tuple; the tests
        in this change compare it with ``(name, message type, seqid)``
    """
    if not self._req:
        self._load_data()
    # A tuple, not a list slice, so the result compares equal to tuples.
    return tuple(self._req[1:4])

def read_message_end(self):
    """
    Do nothing; this method performs no work when a message ends.

    :return: None
    """
    return None
Expand Down Expand Up @@ -316,4 +357,6 @@ def read_struct(self, obj):
:param obj:
:return:
"""
return self._dict_to_thrift(self._req[4], obj)
result = self._dict_to_thrift(self._req[4], obj)
self._req = None
return result
Loading