Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions compiler/cpp/src/thrift/generate/t_dart_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,9 @@ void t_dart_generator::generate_dart_struct_reader(ostream& out, t_struct* tstru

// Declare stack tmp variables and read struct header
indent(out) << "TField field;" << '\n';
indent(out) << "iprot.incrementRecursionDepth();" << '\n';
indent(out) << "try";
scope_up(out);
indent(out) << "iprot.readStructBegin();" << '\n';

// Loop over reading in fields
Expand Down Expand Up @@ -963,7 +966,12 @@ void t_dart_generator::generate_dart_struct_reader(ostream& out, t_struct* tstru
// performs various checks (e.g. check that all required fields are set)
indent(out) << "validate();" << '\n';

scope_down(out, "\n\n");
scope_down(out, " finally"); // close try, begin finally
scope_up(out);
indent(out) << "iprot.decrementRecursionDepth();" << '\n';
scope_down(out); // close finally

scope_down(out, "\n\n"); // close read() function
}

// generates dart method to perform various checks
Expand Down Expand Up @@ -1029,6 +1037,9 @@ void t_dart_generator::generate_dart_struct_writer(ostream& out, t_struct* tstru
// performs various checks (e.g. check that all required fields are set)
indent(out) << "validate();" << '\n' << '\n';

indent(out) << "oprot.incrementRecursionDepth();" << '\n';
indent(out) << "try";
scope_up(out);
indent(out) << "oprot.writeStructBegin(_STRUCT_DESC);" << '\n';

for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
Expand Down Expand Up @@ -1064,7 +1075,12 @@ void t_dart_generator::generate_dart_struct_writer(ostream& out, t_struct* tstru
indent(out) << "oprot.writeFieldStop();" << '\n' << indent() << "oprot.writeStructEnd();"
<< '\n';

scope_down(out, "\n\n");
scope_down(out, " finally"); // close try, begin finally
scope_up(out);
indent(out) << "oprot.decrementRecursionDepth();" << '\n';
scope_down(out); // close finally

scope_down(out, "\n\n"); // close write() function
}

/**
Expand All @@ -1082,6 +1098,9 @@ void t_dart_generator::generate_dart_struct_result_writer(ostream& out, t_struct
const vector<t_field*>& fields = tstruct->get_sorted_members();
vector<t_field*>::const_iterator f_iter;

indent(out) << "oprot.incrementRecursionDepth();" << '\n';
indent(out) << "try";
scope_up(out);
indent(out) << "oprot.writeStructBegin(_STRUCT_DESC);" << '\n' << '\n';

bool first = true;
Expand Down Expand Up @@ -1113,7 +1132,12 @@ void t_dart_generator::generate_dart_struct_result_writer(ostream& out, t_struct
indent(out) << "oprot.writeFieldStop();" << '\n' << indent()
<< "oprot.writeStructEnd();" << '\n';

scope_down(out, "\n\n");
scope_down(out, " finally"); // close try, begin finally
scope_up(out);
indent(out) << "oprot.decrementRecursionDepth();" << '\n';
scope_down(out); // close finally

scope_down(out, "\n\n"); // close write() function
}

void t_dart_generator::generate_generic_field_getters(std::ostream& out,
Expand Down
15 changes: 15 additions & 0 deletions lib/dart/lib/src/protocol/t_protocol.dart
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,23 @@ part of thrift;
abstract class TProtocol {
final TTransport transport;

int _recursionDepth = 0;
static const int _defaultRecursionDepth = 64;

TProtocol(this.transport);

void incrementRecursionDepth() {
if (_recursionDepth >= _defaultRecursionDepth) {
throw TProtocolError(
TProtocolErrorType.DEPTH_LIMIT, "Maximum recursion depth exceeded");
}
_recursionDepth++;
}

void decrementRecursionDepth() {
_recursionDepth--;
}

/// Write
void writeMessageBegin(TMessage message);
void writeMessageEnd();
Expand Down
16 changes: 15 additions & 1 deletion test/dart/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ precross: stubs

check: stubs

# Recursion-depth regression test (THRIFT-6056). Kept out of the default
# cross-test chain because it needs a null-safe Dart SDK (>= 2.12); run it
# explicitly with: make recursion-test
gen-dart/Recursive/lib/Recursive.dart: ../Recursive.thrift
$(THRIFT) --gen dart ../Recursive.thrift

recursion-stubs: gen-dart/Recursive/lib/Recursive.dart
cd gen-dart/Recursive; ${DARTPUB} get
cd recursion_depth_test; ${DARTPUB} get

recursion-test: recursion-stubs
cd recursion_depth_test; ${DART} test

clean-local:
$(RM) -r gen-dart/ test_client/.pub
find . -type d -name ".dart_tool" | xargs $(RM) -r
Expand All @@ -55,4 +68,5 @@ distdir:
$(MAKE) $(AM_MAKEFLAGS) distdir-am

EXTRA_DIST = \
test_client
test_client \
recursion_depth_test
34 changes: 34 additions & 0 deletions test/dart/recursion_depth_test/pubspec.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# 'License'); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

name: recursion_depth_test
version: 0.24.0
description: Recursion-depth limit regression test for the Dart Thrift library
author: Apache Thrift Developers <dev@thrift.apache.org>
homepage: http://thrift.apache.org

environment:
sdk: ">=2.12.0 <4.0.0"

dependencies:
thrift:
path: ../../../lib/dart
Recursive:
path: ../gen-dart/Recursive

dev_dependencies:
test: ">=0.12.30 <2.0.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Exercises the recursion-depth limit through the *generated* struct read/write
// path (TBase.write / TBase.read) -- the real code path that deeply nested
// input exercises -- rather than calling TProtocol.incrementRecursionDepth /
// decrementRecursionDepth in isolation.
//
// The recursive IDL types come from test/Recursive.thrift, generated into the
// `Recursive` package: CoRec <-> CoRec2 form a mutually recursive chain and
// RecTree is a wide tree of nested structs.
//
// Dart hard-codes the limit at TProtocol's _defaultRecursionDepth (64), so the
// boundary is: a chain of 64 structs round-trips, 65 is rejected with
// DEPTH_LIMIT.

library thrift.test.recursion_depth_test;

import 'dart:typed_data' show Uint8List;

import 'package:test/test.dart';
import 'package:thrift/thrift.dart';
import 'package:Recursive/Recursive.dart';

// The hard recursion limit baked into TProtocol (_defaultRecursionDepth).
const int kRecursionLimit = 64;

// Build a CoRec/CoRec2 chain that is exactly [depth] structs deep.
CoRec? makeNestedRecs(int depth) =>
depth <= 0 ? null : (CoRec()..other = makeNestedCoRec2(depth - 1));

CoRec2? makeNestedCoRec2(int depth) =>
depth <= 0 ? null : (CoRec2()..other = makeNestedRecs(depth - 1));

// Build a CoError/CoError2 exception chain that is exactly [depth] structs deep.
// Exceptions are read/written through the same generated path as structs, so
// the same bound applies. The chain terminates in a null field, which the
// generated write() skips, so [depth] structs nest exactly [depth] levels.
CoError? makeNestedError(int depth) =>
depth <= 0 ? null : (CoError()..other = makeNestedError2(depth - 1));

CoError2? makeNestedError2(int depth) =>
depth <= 0 ? null : (CoError2()..other = makeNestedError(depth - 1));

// Serialize via the generated write() over a fresh protocol of the given kind.
Uint8List writeWith(TBase obj, TProtocolFactory factory) =>
TSerializer(protocolFactory: factory).write(obj);

// Deserialize via the generated read() over a fresh protocol of the given kind.
void readWith(TBase into, Uint8List bytes, TProtocolFactory factory) =>
TDeserializer(protocolFactory: factory).read(into, bytes);

// Craft a [depth]-deep nested-struct payload with raw protocol primitives,
// bypassing the recursion counter (which lives in the generated write()). This
// is the only way to obtain an over-limit payload, since a normal write() of
// such a chain would itself be rejected at the limit.
Uint8List craftDeepChain(TProtocolFactory factory, int depth) {
final transport = TBufferedTransport();
final protocol = factory.getProtocol(transport);

void emit(int d) {
protocol.writeStructBegin(TStruct('CoRec'));
if (d > 1) {
protocol.writeFieldBegin(TField('other', TType.STRUCT, 1));
emit(d - 1);
protocol.writeFieldEnd();
}
protocol.writeFieldStop();
protocol.writeStructEnd();
}

emit(depth);
return transport.consumeWriteBuffer();
}

// Craft a [depth]-deep nested CoError payload with raw protocol primitives.
// Uses the real recursive field (id 1, type STRUCT) so the reader recurses
// through the guarded generated read(), not skip().
Uint8List craftDeepErrorChain(TProtocolFactory factory, int depth) {
final transport = TBufferedTransport();
final protocol = factory.getProtocol(transport);

void emit(int d) {
protocol.writeStructBegin(TStruct('CoError'));
if (d > 1) {
protocol.writeFieldBegin(TField('other', TType.STRUCT, 1));
emit(d - 1);
protocol.writeFieldEnd();
}
protocol.writeFieldStop();
protocol.writeStructEnd();
}

emit(depth);
return transport.consumeWriteBuffer();
}

final Matcher throwsDepthLimit = throwsA(predicate(
(e) => e is TProtocolError && e.type == TProtocolErrorType.DEPTH_LIMIT));

void main() {
final factories = <String, TProtocolFactory>{
'binary': TBinaryProtocolFactory(),
'compact': TCompactProtocolFactory(),
'json': TJsonProtocolFactory(),
};

factories.forEach((name, factory) {
group('$name protocol', () {
// A chain one level below the limit must round-trip cleanly.
test('round-trips a chain one below the limit', () {
final data = makeNestedRecs(kRecursionLimit - 1)!;
final bytes = writeWith(data, factory);
readWith(CoRec(), bytes, factory);
});

// A chain exactly at the limit must still round-trip (off-by-one guard).
test('round-trips a chain exactly at the limit', () {
final data = makeNestedRecs(kRecursionLimit)!;
final bytes = writeWith(data, factory);
readWith(CoRec(), bytes, factory);
});

// Writing a chain one level over the limit must be rejected.
test('rejects writing a chain above the limit', () {
final data = makeNestedRecs(kRecursionLimit + 1)!;
expect(() => writeWith(data, factory), throwsDepthLimit);
});

// Reading a too-deep payload must be rejected.
test('rejects reading a payload above the limit', () {
final bytes = craftDeepChain(factory, kRecursionLimit + 1);
expect(() => readWith(CoRec(), bytes, factory), throwsDepthLimit);
});

// Decrement regression guard: a wide (shallow) tree whose total number of
// struct-begins far exceeds the limit must still round-trip. This only
// holds if decrementRecursionDepth() unwinds each sibling back to depth 1.
test('round-trips a wide (shallow) structure', () {
final tree = RecTree()
..item = 0
..children = <RecTree>[];
for (var i = 0; i < (kRecursionLimit * 3); i++) {
tree.children!.add(RecTree()
..item = i
..children = <RecTree>[]);
}
final bytes = writeWith(tree, factory);
readWith(RecTree(), bytes, factory);
});

// A cyclic object graph would recurse forever without the limit; it must
// instead fail with DEPTH_LIMIT.
test('rejects a cyclic object graph', () {
final data = makeNestedRecs(2)!; // CoRec -> CoRec2 -> null
data.other!.other = data; // close the loop: CoRec2.other -> CoRec
expect(() => writeWith(data, factory), throwsDepthLimit);
});

// The same bound must apply to recursive exceptions.
test('round-trips an exception exactly at the limit', () {
final data = makeNestedError(kRecursionLimit)!;
final bytes = writeWith(data, factory);
readWith(CoError(), bytes, factory);
});

test('rejects writing an exception above the limit', () {
final data = makeNestedError(kRecursionLimit + 1)!;
expect(() => writeWith(data, factory), throwsDepthLimit);
});

test('rejects reading an exception payload above the limit', () {
final bytes = craftDeepErrorChain(factory, kRecursionLimit + 1);
expect(() => readWith(CoError(), bytes, factory), throwsDepthLimit);
});
});
});
}
Loading