Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions tests/test_protocol_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from thriftpy2.protocol import binary as proto
from thriftpy2.thrift import TPayload, TType
from thriftpy2.utils import hexlify, serialize
from thriftpy2.transport.memory import TMemoryBuffer


class TItem(TPayload):
Expand Down Expand Up @@ -173,6 +174,56 @@ def test_write_huge_struct():
proto.TBinaryProtocol(b).write_struct(item)


@pytest.fixture
def buffer_supports_non_contiguous():
"""Pypy 3.9 and 3.10 feature BytesIO supporting non-contiguous input data."""
b = BytesIO()
try:
b.write(memoryview(b"abcd")[::-1])
except BufferError:
return False
return True


def test_write_memoryview(buffer_supports_non_contiguous):
# contiguous 8-bit items
b = TMemoryBuffer()
data = memoryview(b"hello world!\x01")
proto.write_val(b, TType.BINARY, data)
b.flush()
assert "00 00 00 0d 68 65 6c 6c 6f 20 77 6f 72 6c 64 21 01" == \
hexlify(b.getvalue())

# not 8-bit items
b = TMemoryBuffer()
data = memoryview(b"0000111122223333").cast("h")
proto.write_val(b, TType.BINARY, data)
b.flush()
assert "00 00 00 10 30 30 30 30 31 31 31 31 32 32 32 32 33 33 33 33" == \
hexlify(b.getvalue())

# not contiguous
b = TMemoryBuffer()
data = memoryview(b"0123")[::-1]
if not buffer_supports_non_contiguous:
with pytest.raises(BufferError, match="contiguous"):
proto.write_val(b, TType.BINARY, data)
else:
proto.write_val(b, TType.BINARY, data)
b.flush()
assert "00 00 00 04 33 32 31 30" == \
hexlify(b.getvalue())


def test_write_bytearray():
b = TMemoryBuffer()
proto.write_val(b, TType.BINARY, bytearray("hello world!", "utf-8"))
b.flush()
assert "00 00 00 0c 68 65 6c 6c 6f 20 77 6f 72 6c 64 21" == \
hexlify(b.getvalue())



@pytest.mark.skipif(not _compat.CYTHON, reason="cybin required")
def test_string_binary_equivalency():
from thriftpy2.protocol.binary import TBinaryProtocolFactory
Expand Down
32 changes: 32 additions & 0 deletions tests/test_protocol_cybinary.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,38 @@ def test_write_string():
hexlify(b.getvalue())


def test_write_memoryview():
# contiguous 8-bit items
b = TCyMemoryBuffer()
data = memoryview(b"hello world!\x01")
proto.write_val(b, TType.BINARY, data)
b.flush()
assert "00 00 00 0d 68 65 6c 6c 6f 20 77 6f 72 6c 64 21 01" == \
hexlify(b.getvalue())

# not 8-bit items
b = TCyMemoryBuffer()
data = memoryview(b"0000111122223333").cast("h")
proto.write_val(b, TType.BINARY, data)
b.flush()
assert "00 00 00 10 30 30 30 30 31 31 31 31 32 32 32 32 33 33 33 33" == \
hexlify(b.getvalue())

# not contiguous
with pytest.raises(BufferError, match="contiguous"):
b = TCyMemoryBuffer()
data = memoryview(b"0123")[::-1]
proto.write_val(b, TType.BINARY, data)


def test_write_bytearray():
b = TCyMemoryBuffer()
proto.write_val(b, TType.BINARY, bytearray("hello world!", "utf-8"))
b.flush()
assert "00 00 00 0c 68 65 6c 6c 6f 20 77 6f 72 6c 64 21" == \
hexlify(b.getvalue())


def test_read_string():
b = TCyMemoryBuffer(b"\x00\x00\x00\x0c"
b"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c")
Expand Down
6 changes: 4 additions & 2 deletions thriftpy2/protocol/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,11 @@ def write_val(outbuf, ttype, val, spec=None):
outbuf.write(pack_double(val))

elif ttype in BIN_TYPES:
if not isinstance(val, bytes):
if isinstance(val, str):
val = val.encode('utf-8')
outbuf.write(pack_string(val))
val = memoryview(val)
outbuf.write(pack_i32(val.nbytes))
outbuf.write(val)

elif ttype == TType.SET or ttype == TType.LIST:
if isinstance(spec, tuple):
Expand Down
14 changes: 12 additions & 2 deletions thriftpy2/protocol/cybin/cybin.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import sys
from libc.stdlib cimport free, malloc
from libc.stdint cimport int16_t, int32_t, int64_t
from libc.string cimport memcpy
from cpython cimport bool
from cpython cimport bool, PyObject_GetBuffer, PyBuffer_Release, PyBUF_ANY_CONTIGUOUS, PyBUF_SIMPLE

from thriftpy2.thrift import TDecodeException
from thriftpy2.transport.cybase cimport CyTransportBase, STACK_STRING_LEN
Expand Down Expand Up @@ -135,6 +135,16 @@ cdef inline write_string(CyTransportBase buf, bytes val):
buf.c_write(<char*>val, val_len)


cdef inline write_buffer(CyTransportBase buf, val):
cdef Py_buffer in_buffer
PyObject_GetBuffer(val, &in_buffer, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
try:
write_i32(buf, in_buffer.len)
buf.c_write(<char *>in_buffer.buf, in_buffer.len)
finally:
PyBuffer_Release(&in_buffer)


cdef inline write_dict(CyTransportBase buf, object val, spec):
cdef int val_len
cdef TType v_type, k_type
Expand Down Expand Up @@ -387,7 +397,7 @@ cdef c_write_val(CyTransportBase buf, TType ttype, val, spec=None):
elif ttype == T_BINARY:
if isinstance(val, str):
val = val.encode()
write_string(buf, val)
write_buffer(buf, val)

elif ttype == T_STRING:
if not isinstance(val, bytes):
Expand Down
2 changes: 1 addition & 1 deletion thriftpy2/transport/cybase.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ cdef class CyTransportBase(object):
cdef object trans

cdef c_read(self, int sz, char* out)
cdef c_write(self, char* data, int sz)
cdef c_write(self, const char* data, int sz)
cdef c_flush(self)

cdef get_string(self, int sz)
2 changes: 1 addition & 1 deletion thriftpy2/transport/cybase.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ cdef class CyTransportBase(object):
cdef c_read(self, int sz, char* out):
pass

cdef c_write(self, char* data, int sz):
cdef c_write(self, const char* data, int sz):
pass

cdef c_flush(self):
Expand Down
12 changes: 10 additions & 2 deletions thriftpy2/transport/memory/cymemory.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ from thriftpy2.transport.cybase cimport (
CyTransportBase,
DEFAULT_BUFFER,
)
from cpython cimport bool, PyObject_GetBuffer, PyBuffer_Release, PyBUF_ANY_CONTIGUOUS, PyBUF_SIMPLE


cdef class TCyMemoryBuffer(CyTransportBase):
Expand Down Expand Up @@ -37,6 +38,14 @@ cdef class TCyMemoryBuffer(CyTransportBase):
if r == -1:
raise MemoryError("Write to memory error")

def c_write_buffer(self, const unsigned char[::1] data):
cdef Py_buffer in_buffer
PyObject_GetBuffer(data, &in_buffer, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
try:
self.c_write(<char *>in_buffer.buf, in_buffer.len)
finally:
PyBuffer_Release(&in_buffer)

cdef _getvalue(self):
cdef char *out
cdef int size = self.buf.data_size
Expand All @@ -62,8 +71,7 @@ cdef class TCyMemoryBuffer(CyTransportBase):
if isinstance(data, unicode):
data = (<unicode>data).encode('utf-8')

cdef int sz = len(data)
return self.c_write(data, sz)
return self.c_write_buffer(data)

def is_open(self):
return True
Expand Down
Loading