Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .agents
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@ object ProtoNames {

val DESCRIPTOR_NAME = Name.identifier("DESCRIPTOR")
val MARSHALLER_NAME = Name.identifier("MARSHALLER")
val PRESENCE_INDICES_NAME = Name.identifier("PresenceIndices")
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,9 @@ internal object FirGeneratedProtoMessageBuilderPropertyKey : GeneratedDeclaratio
return "FirGeneratedProtoMessageBuilderPropertyKey"
}
}

internal object FirGeneratedProtoMessageBuilderFunctionKey : GeneratedDeclarationKey() {
override fun toString(): String {
return "FirGeneratedProtoMessageBuilderFunctionKey"
}
}
Comment on lines +53 to +57
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is actually needed, or if we could just use the PropertyKey

Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@ import org.jetbrains.kotlin.fir.caches.getValue
import org.jetbrains.kotlin.fir.declarations.utils.isInterface
import org.jetbrains.kotlin.fir.extensions.*
import org.jetbrains.kotlin.fir.plugin.createCompanionObject
import org.jetbrains.kotlin.fir.plugin.createMemberFunction
import org.jetbrains.kotlin.fir.plugin.createMemberProperty
import org.jetbrains.kotlin.fir.plugin.createNestedClass
import org.jetbrains.kotlin.fir.resolve.defaultType
import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirClassLikeSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.SpecialNames

Expand All @@ -34,17 +37,35 @@ class FirProtobufMessageGenerator(
@Suppress("unused")
private val logger: MessageCollector,
) : FirDeclarationGenerationExtension(session) {
private val messageCallablesCache: FirCache<FirClassSymbol<*>, Map<Name, FirCallableSymbol<*>>, Nothing?> =

private class GeneratedNames(
val propertyNames: Map<Name, FirPropertySymbol> = emptyMap(),
val functionNames: Map<Name, FirPropertySymbol> = emptyMap(),
)

private val messageCallablesCache: FirCache<FirClassSymbol<*>, GeneratedNames, Nothing?> =
session.firCachesFactory.createCache { messageClassSymbol: FirClassSymbol<*>, _ ->
buildMap {
vsApi {
messageClassSymbol.forAllCallablesVS(session) {
if (it is FirPropertySymbol) {
put(it.name, it)
val propertyNames = mutableMapOf<Name, FirPropertySymbol>()
val functionNames = mutableMapOf<Name, FirPropertySymbol>()
val presenceTrackedPropertyNames = messageClassSymbol.presenceTrackedPropertyNames(session)
vsApi {
messageClassSymbol.forAllCallablesVS(session) { it ->
if (it is FirPropertySymbol) {
propertyNames[it.name] = it

if (it.name in presenceTrackedPropertyNames) {
// if the property is presence tracked, construct clear<PropertyName> function
val functionName = Name.identifier(
"clear${
it.name.asString()
.replaceFirstChar { char -> char.uppercase() }
}")
functionNames[functionName] = it
}
}
}
}
GeneratedNames(propertyNames, functionNames)
}

override fun FirDeclarationPredicateRegistrar.registerPredicates() {
Expand Down Expand Up @@ -118,7 +139,8 @@ class FirProtobufMessageGenerator(
val messageClassSymbol = classSymbol.generatedProtoMessageBuilderKey?.message
?: return super.getCallableNamesForClass(classSymbol, context)

return messageCallablesCache.getValue(messageClassSymbol).keys
val generatedNames = messageCallablesCache.getValue(messageClassSymbol)
return generatedNames.propertyNames.keys + generatedNames.functionNames.keys
}

override fun generateProperties(
Expand All @@ -130,7 +152,8 @@ class FirProtobufMessageGenerator(
val messageClassSymbol = context.owner.generatedProtoMessageBuilderKey?.message
?: return super.generateProperties(callableId, context)

val property = messageCallablesCache.getValue(messageClassSymbol)[callableId.callableName]
val property = messageCallablesCache.getValue(messageClassSymbol)
.propertyNames[callableId.callableName]
?: return super.generateProperties(callableId, context)

return listOf(
Expand All @@ -153,4 +176,59 @@ class FirProtobufMessageGenerator(
}.symbol
)
}

override fun generateFunctions(
callableId: CallableId,
context: MemberGenerationContext?
): List<FirNamedFunctionSymbol> {
context ?: return super.generateFunctions(callableId, context)

val messageClassSymbol = context.owner.generatedProtoMessageBuilderKey?.message
?: return super.generateFunctions(callableId, context)

val property = messageCallablesCache.getValue(messageClassSymbol)
.functionNames[callableId.callableName] ?: return super.generateFunctions(callableId, context)

return listOf(
createMemberFunction(
owner = context.owner,
key = FirGeneratedProtoMessageBuilderFunctionKey,
name = callableId.callableName,
returnType = session.builtinTypes.unitType.coneType
) {
visibility = Visibilities.Public
modality = Modality.ABSTRACT
vsApi {
sourceVS = property.source?.fakeElement(KtFakeSourceElementKind.PluginGenerated)
}
}.symbol
)
}
}

/**
* Returns the names of all proto fields of this generated proto message class that are presence tracked.
*
* It finds those names by searching for the `PresenceIndices` object in the internal message class.
* The `PresenceIndices` object contains a field for each presence-tracked field.
* The field name is the same as the proto field name.
*/
private fun FirClassSymbol<*>.presenceTrackedPropertyNames(session: FirSession): Set<Name> {
val internalClass = vsApi {
session.getRegularClassSymbolByClassIdVS(
classId.internalMessageClassId()
)
} ?: return emptySet()

val presenceIndices = vsApi {
internalClass.declarationsVS(session)
.filterIsInstance<FirRegularClassSymbol>()
.find { it.classKind == ClassKind.OBJECT && it.name == ProtoNames.PRESENCE_INDICES_NAME }
} ?: return emptySet()

return vsApi {
presenceIndices.declarationsVS(session)
.filterIsInstance<FirPropertySymbol>()
.mapTo(mutableSetOf()) { it.name }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

package kotlinx.rpc.codegen

import kotlinx.rpc.codegen.common.ProtoNames
import org.jetbrains.kotlin.KtSourceElement
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.declarations.FirClassLikeDeclaration
Expand Down Expand Up @@ -97,3 +98,16 @@ fun FirClassLikeDeclaration.markAsDeprecatedHidden(session: FirSession) {
replaceAnnotations(annotations + listOf(createDeprecatedHiddenAnnotation(session)))
replaceDeprecationsProvider(getDeprecationsProvider(session))
}


/**
* Returns the [ClassId] corresponding to the generated internal message class for this [ClassId].
*/
fun ClassId.internalMessageClassId(): ClassId {
val names = relativeClassName.pathSegments()
.map { Name.identifier(ProtoNames.internalName(it.asString())) }

return ClassId(packageFqName, names.first()).let { topLevel ->
names.drop(1).fold(topLevel) { acc, name -> acc.createNestedClassId(name) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import kotlinx.rpc.codegen.common.ProtoClassId
import kotlinx.rpc.codegen.common.ProtoNames
import kotlinx.rpc.codegen.common.RpcClassId
import kotlinx.rpc.codegen.doesMatchesClassId
import kotlinx.rpc.codegen.internalMessageClassId
import kotlinx.rpc.codegen.vsApi
import org.jetbrains.kotlin.descriptors.ClassKind
import org.jetbrains.kotlin.diagnostics.DiagnosticReporter
Expand Down Expand Up @@ -62,12 +63,7 @@ object FirProtoMessageAnnotationChecker {
containingDeclarations = listOf()
)?.filterIsInstance<FirRegularClass>()?.toList()?.reversed() ?: emptyList()

val internalClassId = transformClassId(
declaration = declaration,
parentClasses = parentClasses,
transformer = ProtoNames::internalName,
)

val internalClassId = declaration.symbol.classId.internalMessageClassId()
val internalDeclaration = vsApi {
context.session.getRegularClassSymbolByClassIdVS(internalClassId)
}
Expand Down Expand Up @@ -124,23 +120,4 @@ object FirProtoMessageAnnotationChecker {
}
}
}

private fun transformClassId(
declaration: FirRegularClass,
parentClasses: List<FirRegularClass>,
transformer: (String) -> String,
): ClassId {
val topLevelNames = (parentClasses + listOf(declaration))
.map { Name.identifier(transformer(it.name.asString())) }

// for nested classes, we need to construct ClassId properly using createNestedClassId for each level
return ClassId(
packageFqName = declaration.symbol.classId.packageFqName,
topLevelName = topLevelNames.first()
).let {
topLevelNames.drop(1).fold(it) { acc, name ->
acc.createNestedClassId(name)
}
}
}
}
7 changes: 7 additions & 0 deletions docs/pages/kotlinx-rpc/topics/grpc-configuration.topic
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@
includeWkt = true
errorFormat = BufGenerateExtension.ErrorFormat.Json
indentSize = 4
optionalFieldOrNullGetters = true

comments {
copyComments = true
Expand Down Expand Up @@ -466,6 +467,12 @@
Controls the indentation size used by the built-in Kotlin generators.
This is a <code>kotlinx.rpc</code>-specific option, not a Buf flag.
</def>
<def title="generate.optionalFieldOrNullGetters" id="grpc-buf-generate-optional-field-or-null-getters">
Generates additional <code>fooOrNull</code> accessors for optional protobuf fields.
The main generated property keeps returning the protobuf default value when the field is absent,
while the generated <code>OrNull</code> accessor returns <code>null</code> when the field is not present.
Default value: <code>false</code>. This is a <code>kotlinx.rpc</code>-specific option, not a Buf flag.
</def>
<def title="generate.comments.copyComments" id="grpc-buf-generate-copy-comments">
Controls whether comments from the original <path>.proto</path> files are copied to generated Kotlin
sources. This is a <code>kotlinx.rpc</code>-specific option, not a Buf flag.
Expand Down
50 changes: 50 additions & 0 deletions docs/pages/kotlinx-rpc/topics/grpc-generated-code.topic
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,56 @@
</code-block>
</chapter>

<chapter title="Optional fields and clearing" id="grpc-optional-fields-and-clearing">
<p>
Optional protobuf fields are generated as non-null Kotlin properties.
When such a field is absent, the generated property returns the protobuf default value for that field.
To check whether the field is actually present, use <code>presence.has&lt;Field&gt;</code>.
Generated builders and <code>copy { ... }</code> blocks also provide
<code>clear&lt;Field&gt;()</code>, which removes the field from the message instead of assigning its default value.
</p>
<p>
If you enable
<a href="grpc-configuration.topic" anchor="grpc-buf-generate-optional-field-or-null-getters">
<code>rpc.protoc.buf.generate.optionalFieldOrNullGetters</code>
</a>,
the generator also adds <code>&lt;field&gt;OrNull</code> accessors for optional fields.
</p>
<code-block lang="protobuf">
package com.example;

message User {
optional string nickname = 1;
}
</code-block>
<code-block lang="Kotlin">
import com.example.User
import com.example.copy
import com.example.invoke

val original = User {
nickname = "neo"
}

check(original.presence.hasNickname)
check(original.nickname == "neo")

val cleared = original.copy {
clearNickname()
}

check(!cleared.presence.hasNickname)
check(cleared.nickname == "")

// Available only when optionalFieldOrNullGetters is enabled:
check(cleared.nicknameOrNull == null)
</code-block>
<p>
Clearing a field affects subsequent serialization as well: once cleared, the field is absent from newly
encoded data and stays absent after decode.
</p>
</chapter>

<chapter title="Enums" id="grpc-enums">
<p>
Proto enums become Kotlin sealed types. Unknown values received over the wire
Expand Down
19 changes: 19 additions & 0 deletions gradle-plugin/src/main/kotlin/kotlinx/rpc/buf/BufExtensions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,25 @@ public open class BufGenerateExtension @Inject internal constructor(internal val
public fun comments(configure: Action<BufCommentsExtension>) {
configure.execute(comments)
}

/**
* Option to additionally generate nullable getters for optional message fields.
*
* Example:
* ```proto
* message Foo {
* optional string bar = 1;
* }
* ```
* ```kotlin
* val foo: Foo = ...
* val bar: String? = foo.barOrNull
* ```
*
* Default value: `false`.
*/
public val optionalFieldOrNullGetters: Property<Boolean> = project.objects.property<Boolean>()
.convention(false)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ public fun Project.configureLocalProtocGenDevelopmentDependency(

// init
rpcExtension().protoc {

buf {
generate {
optionalFieldOrNullGetters.set(true)
}
}

plugins {
kotlinMultiplatform {
local {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ internal open class DefaultProtocExtension @Inject constructor(

options.put("generateComments", buf.generate.comments.copyComments)
options.put("generateFileLevelComments", buf.generate.comments.includeFileLevelComments)
options.put("generateOptionalFieldOrNullGetters", buf.generate.optionalFieldOrNullGetters)
options.put("indentSize", buf.generate.indentSize)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class GrpcTimeoutTest : GrpcTestBase() {
proceed(it)
}
) {
val request = EchoRequest { message = "Echo"; timeout = 2u }
val request = EchoRequest { message = "Echo"; timeout = 200u }
it.withService<EchoService>().UnaryEcho(request)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ public class MsgFieldDelegate<T>(
@Suppress("UNCHECKED_CAST")
return value as T
}

/**
* Clears the value if it was set. This is used by the generated clear() function on message builders
* to clear the value from within the copy function body.
*/
public fun clearField(thisRef: InternalMessage) {
valueSet = false
value = null
presenceIdx?.let { thisRef.presenceMask[it] = false }
}
}

@InternalRpcApi
Expand Down
Loading
Loading