diff --git a/lib/java/src/main/java/org/apache/thrift/TUnion.java b/lib/java/src/main/java/org/apache/thrift/TUnion.java index 65c92dc5672..2ff07cdfa96 100644 --- a/lib/java/src/main/java/org/apache/thrift/TUnion.java +++ b/lib/java/src/main/java/org/apache/thrift/TUnion.java @@ -231,21 +231,26 @@ public void read(TProtocol iprot, TUnion struct) throws TException { struct.setField_ = null; struct.value_ = null; - iprot.readStructBegin(); - - TField field = iprot.readFieldBegin(); - - struct.value_ = struct.standardSchemeReadValue(iprot, field); - if (struct.value_ != null) { - struct.setField_ = struct.enumForId(field.id); + iprot.incrementRecursionDepth(); + try { + iprot.readStructBegin(); + + TField field = iprot.readFieldBegin(); + + struct.value_ = struct.standardSchemeReadValue(iprot, field); + if (struct.value_ != null) { + struct.setField_ = struct.enumForId(field.id); + } + + iprot.readFieldEnd(); + // this is so that we will eat the stop byte. we could put a check here to + // make sure that it actually *is* the stop byte, but it's faster to do it + // this way. + iprot.readFieldBegin(); + iprot.readStructEnd(); + } finally { + iprot.decrementRecursionDepth(); } - - iprot.readFieldEnd(); - // this is so that we will eat the stop byte. we could put a check here to - // make sure that it actually *is* the stop byte, but it's faster to do it - // this way. - iprot.readFieldBegin(); - iprot.readStructEnd(); } @Override @@ -253,12 +258,17 @@ public void write(TProtocol oprot, TUnion struct) throws TException { if (struct.getSetField() == null || struct.getFieldValue() == null) { throw new TProtocolException("Cannot write a TUnion with no set value!"); } - oprot.writeStructBegin(struct.getStructDesc()); - oprot.writeFieldBegin(struct.getFieldDesc(struct.setField_)); - struct.standardSchemeWriteValue(oprot); - oprot.writeFieldEnd(); - oprot.writeFieldStop(); - oprot.writeStructEnd(); + oprot.incrementRecursionDepth(); + try { + oprot.writeStructBegin(struct.getStructDesc()); + oprot.writeFieldBegin(struct.getFieldDesc(struct.setField_)); + struct.standardSchemeWriteValue(oprot); + oprot.writeFieldEnd(); + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } finally { + oprot.decrementRecursionDepth(); + } } } @@ -274,10 +284,15 @@ private static class TUnionTupleScheme extends TupleScheme { public void read(TProtocol iprot, TUnion struct) throws TException { struct.setField_ = null; struct.value_ = null; - short fieldID = iprot.readI16(); - struct.value_ = struct.tupleSchemeReadValue(iprot, fieldID); - if (struct.value_ != null) { - struct.setField_ = struct.enumForId(fieldID); + iprot.incrementRecursionDepth(); + try { + short fieldID = iprot.readI16(); + struct.value_ = struct.tupleSchemeReadValue(iprot, fieldID); + if (struct.value_ != null) { + struct.setField_ = struct.enumForId(fieldID); + } + } finally { + iprot.decrementRecursionDepth(); } } @@ -286,8 +301,13 @@ public void write(TProtocol oprot, TUnion struct) throws TException { if (struct.getSetField() == null || struct.getFieldValue() == null) { throw new TProtocolException("Cannot write a TUnion with no set value!"); } - oprot.writeI16(struct.setField_.getThriftFieldId()); - struct.tupleSchemeWriteValue(oprot); + oprot.incrementRecursionDepth(); + try { + oprot.writeI16(struct.setField_.getThriftFieldId()); + struct.tupleSchemeWriteValue(oprot); + } finally { + oprot.decrementRecursionDepth(); + } } } } diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TProtocol.java b/lib/java/src/main/java/org/apache/thrift/protocol/TProtocol.java index b91c4cfb16c..d1a07683474 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TProtocol.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TProtocol.java @@ -153,9 +153,14 @@ public final void writeField(TField field, WriteCallback callback) throws } public final void writeStruct(TStruct struct, WriteCallback callback) throws TException { - writeStructBegin(struct); - callback.call(null); - writeStructEnd(); + incrementRecursionDepth(); + try { + writeStructBegin(struct); + callback.call(null); + writeStructEnd(); + } finally { + decrementRecursionDepth(); + } } public final void writeMessage(TMessage message, WriteCallback callback) throws TException { @@ -190,10 +195,15 @@ public final T readMessage(ReadCallback callback) throws TExcep * @throws TException when any sub-operation failed */ public final T readStruct(ReadCallback callback) throws TException { - TStruct tStruct = readStructBegin(); - T t = callback.accept(tStruct); - readStructEnd(); - return t; + incrementRecursionDepth(); + try { + TStruct tStruct = readStructBegin(); + T t = callback.accept(tStruct); + readStructEnd(); + return t; + } finally { + decrementRecursionDepth(); + } } /** diff --git a/lib/kotlin/build.gradle.kts b/lib/kotlin/build.gradle.kts index 81203af191c..d954704f7cb 100644 --- a/lib/kotlin/build.gradle.kts +++ b/lib/kotlin/build.gradle.kts @@ -79,7 +79,31 @@ tasks { group = LifecycleBasePlugin.BUILD_GROUP } - compileKotlin { dependsOn("compileThrift") } + task("compileThriftRecursion") { + val thriftBin = + if (hasProperty("thrift.compiler")) { + file(property("thrift.compiler")!!) + } else { + project.rootDir.resolve("../../compiler/cpp/thrift") + } + val outputDir = layout.buildDirectory.dir("generated-sources") + doFirst { mkdir(outputDir) } + commandLine = + listOf( + thriftBin.absolutePath, + "-gen", + "kotlin", + "-out", + outputDir.get().toString(), + layout.projectDirectory + .file("src/test/resources/RecursionDepthTest.thrift") + .asFile + .absolutePath, + ) + group = LifecycleBasePlugin.BUILD_GROUP + } + + compileKotlin { dependsOn("compileThrift", "compileThriftRecursion") } } sourceSets["main"].java { srcDir(layout.buildDirectory.dir("generated-sources")) } diff --git a/lib/kotlin/src/test/kotlin/org/apache/thrift/RecursionDepthTest.kt b/lib/kotlin/src/test/kotlin/org/apache/thrift/RecursionDepthTest.kt new file mode 100644 index 00000000000..7e9bad5911c --- /dev/null +++ b/lib/kotlin/src/test/kotlin/org/apache/thrift/RecursionDepthTest.kt @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift + +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import org.apache.thrift.protocol.TBinaryProtocol +import org.apache.thrift.protocol.TProtocolException +import org.apache.thrift.recursion.CoRec +import org.apache.thrift.recursion.CoRec2 +import org.apache.thrift.recursion.RecError +import org.apache.thrift.recursion.RecTree +import org.apache.thrift.recursion.RecUnion +import org.apache.thrift.transport.TMemoryBuffer +import org.apache.thrift.transport.TMemoryInputTransport +import org.junit.jupiter.api.Test + +/** + * Exercises the recursion-depth limit through the *generated* struct read/write path (TBase.read / + * TBase.write) -- the real code path that deeply nested input exercises -- rather than calling + * TProtocol.incrementRecursionDepth / decrementRecursionDepth in isolation. + * + * The recursive types CoRec / CoRec2 / RecTree / RecUnion / RecError are generated from + * src/test/resources/RecursionDepthTest.thrift (mirroring test/Recursive.thrift): CoRec <-> CoRec2 + * form a mutually recursive chain, RecTree is a wide tree, and RecUnion / RecError are + * self-recursive union and exception types. + * + * Struct and exception serialization is routed through TProtocol's readStruct {} / writeStruct {} + * helpers, which is where the bound lives; recursive unions are read/written by TUnion, which + * bounds them the same way. The limit is taken from TConfiguration.getRecursionLimit(), so each + * test uses a small custom limit for clarity. + */ +internal class RecursionDepthTest { + + private val limit = 8 + + private fun config(recursionLimit: Int) = + TConfiguration.custom().setRecursionLimit(recursionLimit).build() + + // Build a CoRec/CoRec2 chain that is exactly 'depth' structs deep. + private fun makeNestedRecs(depth: Int): CoRec? = + if (depth <= 0) null else CoRec(makeNestedCoRec2(depth - 1)) + + private fun makeNestedCoRec2(depth: Int): CoRec2? = + if (depth <= 0) null else CoRec2(makeNestedRecs(depth - 1)) + + // Build a RecUnion chain that is exactly 'depth' unions deep; the innermost union holds the + // non-recursive 'leaf' so the value is finite and writable. + private fun makeNestedUnion(depth: Int): RecUnion = + if (depth <= 1) RecUnion(RecUnion._Fields.LEAF, 0.toShort()) + else RecUnion(RecUnion._Fields.CHILD, makeNestedUnion(depth - 1)) + + // Build a RecError chain that is exactly 'depth' exceptions deep. + private fun makeNestedError(depth: Int): RecError = + if (depth <= 1) RecError(leaf = 0) else RecError(child = makeNestedError(depth - 1)) + + // Serialize via the generated write() over a protocol with the given limit. + private fun serialize(data: TBase<*, *>, recursionLimit: Int): ByteArray { + val buf = TMemoryBuffer(config(recursionLimit), 1024) + data.write(TBinaryProtocol(buf)) + return buf.array.copyOf(buf.length()) + } + + // Deserialize via the generated read() over a protocol with the given limit. + private fun > deserialize(into: T, bytes: ByteArray, recursionLimit: Int): T { + into.read(TBinaryProtocol(TMemoryInputTransport(config(recursionLimit), bytes))) + return into + } + + @Test + fun roundTripOneBelowLimitSucceeds() { + val bytes = serialize(makeNestedRecs(limit - 1)!!, limit) + deserialize(CoRec(), bytes, limit) + } + + @Test + fun roundTripAtLimitSucceeds() { + val bytes = serialize(makeNestedRecs(limit)!!, limit) + deserialize(CoRec(), bytes, limit) + } + + @Test + fun writeOneOverLimitThrows() { + val ex = + assertFailsWith { serialize(makeNestedRecs(limit + 1)!!, limit) } + assertEquals(TProtocolException.DEPTH_LIMIT, ex.type) + } + + @Test + fun readOneOverLimitThrows() { + // Produce a valid over-limit payload with a higher write limit, then read it + // back with the real limit -- mimicking a message arriving from the network. + val bytes = serialize(makeNestedRecs(limit + 1)!!, limit + 1) + val ex = assertFailsWith { deserialize(CoRec(), bytes, limit) } + assertEquals(TProtocolException.DEPTH_LIMIT, ex.type) + } + + @Test + fun wideStructureRoundTrips() { + // Many siblings (>> limit) must still round-trip: this only holds if the + // counter is decremented for each sibling back to depth 1. + val children = + (0 until limit * 3).map { RecTree(children = emptyList(), item = it.toShort()) } + val bytes = serialize(RecTree(children = children, item = 0.toShort()), limit) + deserialize(RecTree(), bytes, limit) + } + + @Test + fun cyclicGraphThrows() { + val data = makeNestedRecs(2)!! // CoRec -> CoRec2 -> null + data.other!!.other = data // close the loop: CoRec2.other -> CoRec + val ex = assertFailsWith { serialize(data, limit) } + assertEquals(TProtocolException.DEPTH_LIMIT, ex.type) + } + + @Test + fun unionRoundTripAtLimitSucceeds() { + val bytes = serialize(makeNestedUnion(limit), limit) + deserialize(RecUnion(), bytes, limit) + } + + @Test + fun unionWriteOneOverLimitThrows() { + val ex = + assertFailsWith { serialize(makeNestedUnion(limit + 1), limit) } + assertEquals(TProtocolException.DEPTH_LIMIT, ex.type) + } + + @Test + fun unionReadOneOverLimitThrows() { + val bytes = serialize(makeNestedUnion(limit + 1), limit + 1) + val ex = assertFailsWith { deserialize(RecUnion(), bytes, limit) } + assertEquals(TProtocolException.DEPTH_LIMIT, ex.type) + } + + @Test + fun exceptionRoundTripAtLimitSucceeds() { + val bytes = serialize(makeNestedError(limit), limit) + deserialize(RecError(), bytes, limit) + } + + @Test + fun exceptionWriteOneOverLimitThrows() { + val ex = + assertFailsWith { serialize(makeNestedError(limit + 1), limit) } + assertEquals(TProtocolException.DEPTH_LIMIT, ex.type) + } + + @Test + fun exceptionReadOneOverLimitThrows() { + val bytes = serialize(makeNestedError(limit + 1), limit + 1) + val ex = assertFailsWith { deserialize(RecError(), bytes, limit) } + assertEquals(TProtocolException.DEPTH_LIMIT, ex.type) + } +} diff --git a/lib/kotlin/src/test/resources/RecursionDepthTest.thrift b/lib/kotlin/src/test/resources/RecursionDepthTest.thrift new file mode 100644 index 00000000000..92e1e16e4a1 --- /dev/null +++ b/lib/kotlin/src/test/resources/RecursionDepthTest.thrift @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Recursive types mirroring test/Recursive.thrift, used by RecursionDepthTest +// to drive the generated struct/union/exception read/write path. CoRec <-> +// CoRec2 form a mutually recursive chain; RecTree is a wide tree of nested +// structs; RecUnion and RecError are self-recursive union and exception types, +// each carrying a non-recursive leaf so a finite value can be constructed. + +namespace java org.apache.thrift.recursion + +struct CoRec { + 1: CoRec2 other +} + +struct CoRec2 { + 1: CoRec other +} + +struct RecTree { + 1: list children + 2: i16 item +} + +union RecUnion { + 1: RecUnion child + 2: i16 leaf +} + +exception RecError { + 1: RecError child + 2: i16 leaf +}