diff --git a/compiler/cpp/src/thrift/generate/t_dart_generator.cc b/compiler/cpp/src/thrift/generate/t_dart_generator.cc index 34f9d82e242..cb28bf54ae4 100644 --- a/compiler/cpp/src/thrift/generate/t_dart_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_dart_generator.cc @@ -889,6 +889,9 @@ void t_dart_generator::generate_dart_struct_reader(ostream& out, t_struct* tstru // Declare stack tmp variables and read struct header indent(out) << "TField field;" << '\n'; + indent(out) << "iprot.incrementRecursionDepth();" << '\n'; + indent(out) << "try"; + scope_up(out); indent(out) << "iprot.readStructBegin();" << '\n'; // Loop over reading in fields @@ -963,7 +966,12 @@ void t_dart_generator::generate_dart_struct_reader(ostream& out, t_struct* tstru // performs various checks (e.g. check that all required fields are set) indent(out) << "validate();" << '\n'; - scope_down(out, "\n\n"); + scope_down(out, " finally"); // close try, begin finally + scope_up(out); + indent(out) << "iprot.decrementRecursionDepth();" << '\n'; + scope_down(out); // close finally + + scope_down(out, "\n\n"); // close read() function } // generates dart method to perform various checks @@ -1029,6 +1037,9 @@ void t_dart_generator::generate_dart_struct_writer(ostream& out, t_struct* tstru // performs various checks (e.g. check that all required fields are set) indent(out) << "validate();" << '\n' << '\n'; + indent(out) << "oprot.incrementRecursionDepth();" << '\n'; + indent(out) << "try"; + scope_up(out); indent(out) << "oprot.writeStructBegin(_STRUCT_DESC);" << '\n'; for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { @@ -1064,7 +1075,12 @@ void t_dart_generator::generate_dart_struct_writer(ostream& out, t_struct* tstru indent(out) << "oprot.writeFieldStop();" << '\n' << indent() << "oprot.writeStructEnd();" << '\n'; - scope_down(out, "\n\n"); + scope_down(out, " finally"); // close try, begin finally + scope_up(out); + indent(out) << "oprot.decrementRecursionDepth();" << '\n'; + scope_down(out); // close finally + + scope_down(out, "\n\n"); // close write() function } /** @@ -1082,6 +1098,9 @@ void t_dart_generator::generate_dart_struct_result_writer(ostream& out, t_struct const vector& fields = tstruct->get_sorted_members(); vector::const_iterator f_iter; + indent(out) << "oprot.incrementRecursionDepth();" << '\n'; + indent(out) << "try"; + scope_up(out); indent(out) << "oprot.writeStructBegin(_STRUCT_DESC);" << '\n' << '\n'; bool first = true; @@ -1113,7 +1132,12 @@ void t_dart_generator::generate_dart_struct_result_writer(ostream& out, t_struct indent(out) << "oprot.writeFieldStop();" << '\n' << indent() << "oprot.writeStructEnd();" << '\n'; - scope_down(out, "\n\n"); + scope_down(out, " finally"); // close try, begin finally + scope_up(out); + indent(out) << "oprot.decrementRecursionDepth();" << '\n'; + scope_down(out); // close finally + + scope_down(out, "\n\n"); // close write() function } void t_dart_generator::generate_generic_field_getters(std::ostream& out, diff --git a/lib/dart/lib/src/protocol/t_protocol.dart b/lib/dart/lib/src/protocol/t_protocol.dart index f49c0321d76..7dfda870779 100644 --- a/lib/dart/lib/src/protocol/t_protocol.dart +++ b/lib/dart/lib/src/protocol/t_protocol.dart @@ -20,8 +20,23 @@ part of thrift; abstract class TProtocol { final TTransport transport; + int _recursionDepth = 0; + static const int _defaultRecursionDepth = 64; + TProtocol(this.transport); + void incrementRecursionDepth() { + if (_recursionDepth >= _defaultRecursionDepth) { + throw TProtocolError( + TProtocolErrorType.DEPTH_LIMIT, "Maximum recursion depth exceeded"); + } + _recursionDepth++; + } + + void decrementRecursionDepth() { + _recursionDepth--; + } + /// Write void writeMessageBegin(TMessage message); void writeMessageEnd(); diff --git a/test/dart/Makefile.am b/test/dart/Makefile.am index 835ac4a595e..cb00f5e1226 100644 --- a/test/dart/Makefile.am +++ b/test/dart/Makefile.am @@ -35,6 +35,19 @@ precross: stubs check: stubs +# Recursion-depth regression test (THRIFT-6056). Kept out of the default +# cross-test chain because it needs a null-safe Dart SDK (>= 2.12); run it +# explicitly with: make recursion-test +gen-dart/Recursive/lib/Recursive.dart: ../Recursive.thrift + $(THRIFT) --gen dart ../Recursive.thrift + +recursion-stubs: gen-dart/Recursive/lib/Recursive.dart + cd gen-dart/Recursive; ${DARTPUB} get + cd recursion_depth_test; ${DARTPUB} get + +recursion-test: recursion-stubs + cd recursion_depth_test; ${DART} test + clean-local: $(RM) -r gen-dart/ test_client/.pub find . -type d -name ".dart_tool" | xargs $(RM) -r @@ -55,4 +68,5 @@ distdir: $(MAKE) $(AM_MAKEFLAGS) distdir-am EXTRA_DIST = \ - test_client + test_client \ + recursion_depth_test diff --git a/test/dart/recursion_depth_test/pubspec.yaml b/test/dart/recursion_depth_test/pubspec.yaml new file mode 100644 index 00000000000..2148f8d529b --- /dev/null +++ b/test/dart/recursion_depth_test/pubspec.yaml @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: recursion_depth_test +version: 0.24.0 +description: Recursion-depth limit regression test for the Dart Thrift library +author: Apache Thrift Developers +homepage: http://thrift.apache.org + +environment: + sdk: ">=2.12.0 <4.0.0" + +dependencies: + thrift: + path: ../../../lib/dart + Recursive: + path: ../gen-dart/Recursive + +dev_dependencies: + test: ">=0.12.30 <2.0.0" diff --git a/test/dart/recursion_depth_test/test/t_protocol_recursion_depth_test.dart b/test/dart/recursion_depth_test/test/t_protocol_recursion_depth_test.dart new file mode 100644 index 00000000000..07bfe103fdb --- /dev/null +++ b/test/dart/recursion_depth_test/test/t_protocol_recursion_depth_test.dart @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Exercises the recursion-depth limit through the *generated* struct read/write +// path (TBase.write / TBase.read) -- the real code path that deeply nested +// input exercises -- rather than calling TProtocol.incrementRecursionDepth / +// decrementRecursionDepth in isolation. +// +// The recursive IDL types come from test/Recursive.thrift, generated into the +// `Recursive` package: CoRec <-> CoRec2 form a mutually recursive chain and +// RecTree is a wide tree of nested structs. +// +// Dart hard-codes the limit at TProtocol's _defaultRecursionDepth (64), so the +// boundary is: a chain of 64 structs round-trips, 65 is rejected with +// DEPTH_LIMIT. + +library thrift.test.recursion_depth_test; + +import 'dart:typed_data' show Uint8List; + +import 'package:test/test.dart'; +import 'package:thrift/thrift.dart'; +import 'package:Recursive/Recursive.dart'; + +// The hard recursion limit baked into TProtocol (_defaultRecursionDepth). +const int kRecursionLimit = 64; + +// Build a CoRec/CoRec2 chain that is exactly [depth] structs deep. +CoRec? makeNestedRecs(int depth) => + depth <= 0 ? null : (CoRec()..other = makeNestedCoRec2(depth - 1)); + +CoRec2? makeNestedCoRec2(int depth) => + depth <= 0 ? null : (CoRec2()..other = makeNestedRecs(depth - 1)); + +// Build a CoError/CoError2 exception chain that is exactly [depth] structs deep. +// Exceptions are read/written through the same generated path as structs, so +// the same bound applies. The chain terminates in a null field, which the +// generated write() skips, so [depth] structs nest exactly [depth] levels. +CoError? makeNestedError(int depth) => + depth <= 0 ? null : (CoError()..other = makeNestedError2(depth - 1)); + +CoError2? makeNestedError2(int depth) => + depth <= 0 ? null : (CoError2()..other = makeNestedError(depth - 1)); + +// Serialize via the generated write() over a fresh protocol of the given kind. +Uint8List writeWith(TBase obj, TProtocolFactory factory) => + TSerializer(protocolFactory: factory).write(obj); + +// Deserialize via the generated read() over a fresh protocol of the given kind. +void readWith(TBase into, Uint8List bytes, TProtocolFactory factory) => + TDeserializer(protocolFactory: factory).read(into, bytes); + +// Craft a [depth]-deep nested-struct payload with raw protocol primitives, +// bypassing the recursion counter (which lives in the generated write()). This +// is the only way to obtain an over-limit payload, since a normal write() of +// such a chain would itself be rejected at the limit. +Uint8List craftDeepChain(TProtocolFactory factory, int depth) { + final transport = TBufferedTransport(); + final protocol = factory.getProtocol(transport); + + void emit(int d) { + protocol.writeStructBegin(TStruct('CoRec')); + if (d > 1) { + protocol.writeFieldBegin(TField('other', TType.STRUCT, 1)); + emit(d - 1); + protocol.writeFieldEnd(); + } + protocol.writeFieldStop(); + protocol.writeStructEnd(); + } + + emit(depth); + return transport.consumeWriteBuffer(); +} + +// Craft a [depth]-deep nested CoError payload with raw protocol primitives. +// Uses the real recursive field (id 1, type STRUCT) so the reader recurses +// through the guarded generated read(), not skip(). +Uint8List craftDeepErrorChain(TProtocolFactory factory, int depth) { + final transport = TBufferedTransport(); + final protocol = factory.getProtocol(transport); + + void emit(int d) { + protocol.writeStructBegin(TStruct('CoError')); + if (d > 1) { + protocol.writeFieldBegin(TField('other', TType.STRUCT, 1)); + emit(d - 1); + protocol.writeFieldEnd(); + } + protocol.writeFieldStop(); + protocol.writeStructEnd(); + } + + emit(depth); + return transport.consumeWriteBuffer(); +} + +final Matcher throwsDepthLimit = throwsA(predicate( + (e) => e is TProtocolError && e.type == TProtocolErrorType.DEPTH_LIMIT)); + +void main() { + final factories = { + 'binary': TBinaryProtocolFactory(), + 'compact': TCompactProtocolFactory(), + 'json': TJsonProtocolFactory(), + }; + + factories.forEach((name, factory) { + group('$name protocol', () { + // A chain one level below the limit must round-trip cleanly. + test('round-trips a chain one below the limit', () { + final data = makeNestedRecs(kRecursionLimit - 1)!; + final bytes = writeWith(data, factory); + readWith(CoRec(), bytes, factory); + }); + + // A chain exactly at the limit must still round-trip (off-by-one guard). + test('round-trips a chain exactly at the limit', () { + final data = makeNestedRecs(kRecursionLimit)!; + final bytes = writeWith(data, factory); + readWith(CoRec(), bytes, factory); + }); + + // Writing a chain one level over the limit must be rejected. + test('rejects writing a chain above the limit', () { + final data = makeNestedRecs(kRecursionLimit + 1)!; + expect(() => writeWith(data, factory), throwsDepthLimit); + }); + + // Reading a too-deep payload must be rejected. + test('rejects reading a payload above the limit', () { + final bytes = craftDeepChain(factory, kRecursionLimit + 1); + expect(() => readWith(CoRec(), bytes, factory), throwsDepthLimit); + }); + + // Decrement regression guard: a wide (shallow) tree whose total number of + // struct-begins far exceeds the limit must still round-trip. This only + // holds if decrementRecursionDepth() unwinds each sibling back to depth 1. + test('round-trips a wide (shallow) structure', () { + final tree = RecTree() + ..item = 0 + ..children = []; + for (var i = 0; i < (kRecursionLimit * 3); i++) { + tree.children!.add(RecTree() + ..item = i + ..children = []); + } + final bytes = writeWith(tree, factory); + readWith(RecTree(), bytes, factory); + }); + + // A cyclic object graph would recurse forever without the limit; it must + // instead fail with DEPTH_LIMIT. + test('rejects a cyclic object graph', () { + final data = makeNestedRecs(2)!; // CoRec -> CoRec2 -> null + data.other!.other = data; // close the loop: CoRec2.other -> CoRec + expect(() => writeWith(data, factory), throwsDepthLimit); + }); + + // The same bound must apply to recursive exceptions. + test('round-trips an exception exactly at the limit', () { + final data = makeNestedError(kRecursionLimit)!; + final bytes = writeWith(data, factory); + readWith(CoError(), bytes, factory); + }); + + test('rejects writing an exception above the limit', () { + final data = makeNestedError(kRecursionLimit + 1)!; + expect(() => writeWith(data, factory), throwsDepthLimit); + }); + + test('rejects reading an exception payload above the limit', () { + final bytes = craftDeepErrorChain(factory, kRecursionLimit + 1); + expect(() => readWith(CoError(), bytes, factory), throwsDepthLimit); + }); + }); + }); +}