diff --git a/tests/test_protocol_binary.py b/tests/test_protocol_binary.py
index d4cf41f2..89020c58 100644
--- a/tests/test_protocol_binary.py
+++ b/tests/test_protocol_binary.py
@@ -9,6 +9,7 @@ from thriftpy2.protocol import binary as proto
 from thriftpy2.thrift import TPayload, TType
 from thriftpy2.utils import hexlify, serialize
+from thriftpy2.transport.memory import TMemoryBuffer
 
 
 class TItem(TPayload):
@@ -173,6 +174,59 @@ def test_write_huge_struct():
     proto.TBinaryProtocol(b).write_struct(item)
 
 
+@pytest.fixture
+def buffer_supports_non_contiguous():
+    """Report whether BytesIO.write() accepts a non-contiguous buffer.
+
+    PyPy 3.9 and 3.10 accept such input; when BufferError is raised
+    instead, the fixture value is False.
+    """
+    b = BytesIO()
+    try:
+        b.write(memoryview(b"abcd")[::-1])
+    except BufferError:
+        return False
+    return True
+
+
+def test_write_memoryview(buffer_supports_non_contiguous):
+    # contiguous 8-bit items
+    b = TMemoryBuffer()
+    data = memoryview(b"hello world!\x01")
+    proto.write_val(b, TType.BINARY, data)
+    b.flush()
+    assert "00 00 00 0d 68 65 6c 6c 6f 20 77 6f 72 6c 64 21 01" == \
+        hexlify(b.getvalue())
+
+    # not 8-bit items
+    b = TMemoryBuffer()
+    data = memoryview(b"0000111122223333").cast("h")
+    proto.write_val(b, TType.BINARY, data)
+    b.flush()
+    assert "00 00 00 10 30 30 30 30 31 31 31 31 32 32 32 32 33 33 33 33" == \
+        hexlify(b.getvalue())
+
+    # not contiguous
+    b = TMemoryBuffer()
+    data = memoryview(b"0123")[::-1]
+    if not buffer_supports_non_contiguous:
+        with pytest.raises(BufferError, match="contiguous"):
+            proto.write_val(b, TType.BINARY, data)
+    else:
+        proto.write_val(b, TType.BINARY, data)
+        b.flush()
+        assert "00 00 00 04 33 32 31 30" == \
+            hexlify(b.getvalue())
+
+
+def test_write_bytearray():
+    b = TMemoryBuffer()
+    proto.write_val(b, TType.BINARY, bytearray("hello world!", "utf-8"))
+    b.flush()
+    assert "00 00 00 0c 68 65 6c 6c 6f 20 77 6f 72 6c 64 21" == \
+        hexlify(b.getvalue())
+
+
 @pytest.mark.skipif(not _compat.CYTHON, reason="cybin required")
 def test_string_binary_equivalency():
     from thriftpy2.protocol.binary import TBinaryProtocolFactory
diff --git a/tests/test_protocol_cybinary.py b/tests/test_protocol_cybinary.py
index 4bc8f7e6..5a09a8f8 100644
--- a/tests/test_protocol_cybinary.py
+++ b/tests/test_protocol_cybinary.py
@@ -133,6 +133,38 @@ def test_write_string():
         hexlify(b.getvalue())
 
 
+def test_write_memoryview():
+    # contiguous 8-bit items
+    b = TCyMemoryBuffer()
+    data = memoryview(b"hello world!\x01")
+    proto.write_val(b, TType.BINARY, data)
+    b.flush()
+    assert "00 00 00 0d 68 65 6c 6c 6f 20 77 6f 72 6c 64 21 01" == \
+        hexlify(b.getvalue())
+
+    # not 8-bit items
+    b = TCyMemoryBuffer()
+    data = memoryview(b"0000111122223333").cast("h")
+    proto.write_val(b, TType.BINARY, data)
+    b.flush()
+    assert "00 00 00 10 30 30 30 30 31 31 31 31 32 32 32 32 33 33 33 33" == \
+        hexlify(b.getvalue())
+
+    # not contiguous: only the write itself is expected to raise
+    b = TCyMemoryBuffer()
+    data = memoryview(b"0123")[::-1]
+    with pytest.raises(BufferError, match="contiguous"):
+        proto.write_val(b, TType.BINARY, data)
+
+
+def test_write_bytearray():
+    b = TCyMemoryBuffer()
+    proto.write_val(b, TType.BINARY, bytearray("hello world!", "utf-8"))
+    b.flush()
+    assert "00 00 00 0c 68 65 6c 6c 6f 20 77 6f 72 6c 64 21" == \
+        hexlify(b.getvalue())
+
+
 def test_read_string():
     b = TCyMemoryBuffer(b"\x00\x00\x00\x0c"
                         b"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c")
diff --git a/thriftpy2/protocol/binary.py b/thriftpy2/protocol/binary.py
index fbffa291..55073e3b 100644
--- a/thriftpy2/protocol/binary.py
+++ b/thriftpy2/protocol/binary.py
@@ -119,9 +119,13 @@ def write_val(outbuf, ttype, val, spec=None):
         outbuf.write(pack_double(val))
 
     elif ttype in BIN_TYPES:
-        if not isinstance(val, bytes):
+        if isinstance(val, str):
             val = val.encode('utf-8')
-        outbuf.write(pack_string(val))
+        # Wrap in a memoryview so bytes, bytearray and memoryview inputs share
+        # one path; nbytes is the byte length even for multi-byte item formats.
+        val = memoryview(val)
+        outbuf.write(pack_i32(val.nbytes))
+        outbuf.write(val)
 
     elif ttype == TType.SET or ttype == TType.LIST:
         if isinstance(spec, tuple):
diff --git a/thriftpy2/protocol/cybin/cybin.pyx b/thriftpy2/protocol/cybin/cybin.pyx
index 19ccd317..48e97508 100644
--- a/thriftpy2/protocol/cybin/cybin.pyx
+++ b/thriftpy2/protocol/cybin/cybin.pyx
@@ -5,7 +5,7 @@ import sys
 from libc.stdlib cimport free, malloc
 from libc.stdint cimport int16_t, int32_t, int64_t
 from libc.string cimport memcpy
-from cpython cimport bool
+from cpython cimport bool, PyObject_GetBuffer, PyBuffer_Release, PyBUF_ANY_CONTIGUOUS, PyBUF_SIMPLE
 
 from thriftpy2.thrift import TDecodeException
 from thriftpy2.transport.cybase cimport CyTransportBase, STACK_STRING_LEN
@@ -135,6 +135,22 @@ cdef inline write_string(CyTransportBase buf, bytes val):
     buf.c_write(val, val_len)
 
 
+cdef inline write_buffer(CyTransportBase buf, val):
+    """Write ``val`` as an i32 byte length followed by its raw bytes.
+
+    ``val`` must expose a contiguous buffer; otherwise PyObject_GetBuffer
+    raises (BufferError for non-contiguous input).
+    """
+    cdef Py_buffer in_buffer
+    PyObject_GetBuffer(val, &in_buffer, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
+    try:
+        write_i32(buf, in_buffer.len)
+        buf.c_write(<const char*>in_buffer.buf, in_buffer.len)
+    finally:
+        # always release the buffer view, even if the write fails
+        PyBuffer_Release(&in_buffer)
+
+
 cdef inline write_dict(CyTransportBase buf, object val, spec):
     cdef int val_len
     cdef TType v_type, k_type
@@ -387,7 +403,7 @@ cdef c_write_val(CyTransportBase buf, TType ttype, val, spec=None):
     elif ttype == T_BINARY:
         if isinstance(val, str):
             val = val.encode()
-        write_string(buf, val)
+        write_buffer(buf, val)
 
     elif ttype == T_STRING:
         if not isinstance(val, bytes):
diff --git a/thriftpy2/transport/cybase.pxd b/thriftpy2/transport/cybase.pxd
index 81586668..bdddd2cb 100644
--- a/thriftpy2/transport/cybase.pxd
+++ b/thriftpy2/transport/cybase.pxd
@@ -20,7 +20,7 @@ cdef class CyTransportBase(object):
     cdef object trans
 
     cdef c_read(self, int sz, char* out)
-    cdef c_write(self, char* data, int sz)
+    cdef c_write(self, const char* data, int sz)
     cdef c_flush(self)
 
     cdef get_string(self, int sz)
diff --git a/thriftpy2/transport/cybase.pyx b/thriftpy2/transport/cybase.pyx
index 192ccabf..517a56ac 100644
--- a/thriftpy2/transport/cybase.pyx
+++ b/thriftpy2/transport/cybase.pyx
@@ -108,7 +108,7 @@ cdef class CyTransportBase(object):
     cdef c_read(self, int sz, char* out):
         pass
 
-    cdef c_write(self, char* data, int sz):
+    cdef c_write(self, const char* data, int sz):
         pass
 
     cdef c_flush(self):
diff --git a/thriftpy2/transport/memory/cymemory.pyx b/thriftpy2/transport/memory/cymemory.pyx
index 7dc8ae65..95d70e23 100644
--- a/thriftpy2/transport/memory/cymemory.pyx
+++ b/thriftpy2/transport/memory/cymemory.pyx
@@ -7,6 +7,7 @@ from thriftpy2.transport.cybase cimport (
     CyTransportBase,
     DEFAULT_BUFFER,
 )
+from cpython cimport bool, PyObject_GetBuffer, PyBuffer_Release, PyBUF_ANY_CONTIGUOUS, PyBUF_SIMPLE
 
 
 cdef class TCyMemoryBuffer(CyTransportBase):
@@ -37,6 +38,15 @@ cdef class TCyMemoryBuffer(CyTransportBase):
         if r == -1:
             raise MemoryError("Write to memory error")
 
+    def c_write_buffer(self, const unsigned char[::1] data):
+        """Pass the bytes of a contiguous unsigned-char buffer to c_write()."""
+        cdef Py_buffer in_buffer
+        PyObject_GetBuffer(data, &in_buffer, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
+        try:
+            self.c_write(<const char*>in_buffer.buf, in_buffer.len)
+        finally:
+            PyBuffer_Release(&in_buffer)
+
     cdef _getvalue(self):
         cdef char *out
         cdef int size = self.buf.data_size
@@ -62,8 +72,7 @@ cdef class TCyMemoryBuffer(CyTransportBase):
         if isinstance(data, unicode):
             data = (<unicode>data).encode('utf-8')
 
-        cdef int sz = len(data)
-        return self.c_write(data, sz)
+        return self.c_write_buffer(data)
 
     def is_open(self):
         return True