diff --git a/src/Kiota.Builder/Export/PublicAPIExportService.cs b/src/Kiota.Builder/Export/PublicAPIExportService.cs index 4ac9361e19..c4b2972b24 100644 --- a/src/Kiota.Builder/Export/PublicAPIExportService.cs +++ b/src/Kiota.Builder/Export/PublicAPIExportService.cs @@ -14,6 +14,7 @@ using Kiota.Builder.Writers.Php; using Kiota.Builder.Writers.Python; using Kiota.Builder.Writers.Ruby; +using Kiota.Builder.Writers.Rust; using Kiota.Builder.Writers.TypeScript; namespace Kiota.Builder.Export; @@ -123,6 +124,7 @@ private static ILanguageConventionService GetLanguageConventionServiceFromConfig GenerationLanguage.Go => new GoConventionService(), GenerationLanguage.Ruby => new RubyConventionService(), GenerationLanguage.Dart => new DartConventionService(), + GenerationLanguage.Rust => new RustConventionService(), _ => throw new ArgumentOutOfRangeException(nameof(generationConfiguration), generationConfiguration.Language, null) }; } diff --git a/src/Kiota.Builder/GenerationLanguage.cs b/src/Kiota.Builder/GenerationLanguage.cs index 7f2696e7e3..a4bdbd9473 100644 --- a/src/Kiota.Builder/GenerationLanguage.cs +++ b/src/Kiota.Builder/GenerationLanguage.cs @@ -10,5 +10,6 @@ public enum GenerationLanguage Go, Ruby, Dart, + Rust, HTTP } diff --git a/src/Kiota.Builder/PathSegmenters/RustPathSegmenter.cs b/src/Kiota.Builder/PathSegmenters/RustPathSegmenter.cs new file mode 100644 index 0000000000..7c41166487 --- /dev/null +++ b/src/Kiota.Builder/PathSegmenters/RustPathSegmenter.cs @@ -0,0 +1,13 @@ +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.PathSegmenters; + +public class RustPathSegmenter(string rootPath, string clientNamespaceName) : CommonPathSegmenter(rootPath, clientNamespaceName) +{ + public override string FileSuffix => ".rs"; + + public override string NormalizeNamespaceSegment(string segmentName) => segmentName.ToSnakeCase(); + + public override string NormalizeFileName(CodeElement currentElement) => GetLastFileNameSegment(currentElement).ToSnakeCase(); +} diff --git a/src/Kiota.Builder/Refiners/ILanguageRefiner.cs b/src/Kiota.Builder/Refiners/ILanguageRefiner.cs index 096d7a52f4..5c7fe54b8a 100644 --- a/src/Kiota.Builder/Refiners/ILanguageRefiner.cs +++ b/src/Kiota.Builder/Refiners/ILanguageRefiner.cs @@ -41,6 +41,9 @@ public static async Task RefineAsync(GenerationConfiguration config, CodeNamespa case GenerationLanguage.Dart: await new DartRefiner(config).RefineAsync(generatedCode, cancellationToken).ConfigureAwait(false); break; + case GenerationLanguage.Rust: + await new RustRefiner(config).RefineAsync(generatedCode, cancellationToken).ConfigureAwait(false); + break; } } } diff --git a/src/Kiota.Builder/Refiners/RustExceptionsReservedNamesProvider.cs b/src/Kiota.Builder/Refiners/RustExceptionsReservedNamesProvider.cs new file mode 100644 index 0000000000..e54fe854de --- /dev/null +++ b/src/Kiota.Builder/Refiners/RustExceptionsReservedNamesProvider.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; + +namespace Kiota.Builder.Refiners; + +public class RustExceptionsReservedNamesProvider : IReservedNamesProvider +{ + private readonly Lazy> _reservedNames = new(static () => new(StringComparer.Ordinal) + { + "to_string", + "fmt", + "source", + }); + public HashSet ReservedNames => _reservedNames.Value; +} diff --git a/src/Kiota.Builder/Refiners/RustRefiner.cs b/src/Kiota.Builder/Refiners/RustRefiner.cs new file mode 100644 index 0000000000..689d41d784 --- /dev/null +++ b/src/Kiota.Builder/Refiners/RustRefiner.cs @@ -0,0 +1,368 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Configuration; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.Refiners; + +public class RustRefiner : CommonLanguageRefiner, ILanguageRefiner +{ + private const string AbstractionsNamespaceName = "kiota_abstractions"; + private const string SerializationNamespaceName = "kiota_serialization"; + private const string MultipartBodyClassName = "MultipartBody"; + + protected static readonly AdditionalUsingEvaluator[] defaultUsingEvaluators = { + new (static x => x is CodeProperty prop && prop.IsOfKind(CodePropertyKind.RequestAdapter), + AbstractionsNamespaceName, "RequestAdapter"), + new (static x => x is CodeMethod method && method.IsOfKind(CodeMethodKind.RequestGenerator), + AbstractionsNamespaceName, "Method", "RequestInformation", "RequestOption"), + new (static x => x is CodeMethod method && method.IsOfKind(CodeMethodKind.Serializer), + AbstractionsNamespaceName, "SerializationWriter"), + new (static x => x is CodeMethod method && method.IsOfKind(CodeMethodKind.Deserializer), + AbstractionsNamespaceName, "ParseNode"), + new (static x => x is CodeClass @class && @class.IsOfKind(CodeClassKind.Model), + AbstractionsNamespaceName, "Parsable"), + new (static x => x is CodeClass @class && @class.IsOfKind(CodeClassKind.Model) && @class.Properties.Any(p => p.IsOfKind(CodePropertyKind.AdditionalData)), + AbstractionsNamespaceName, "AdditionalDataHolder"), + new (static x => x is CodeMethod method && method.IsOfKind(CodeMethodKind.RequestExecutor), + AbstractionsNamespaceName, "Parsable"), + new (static x => x is CodeProperty prop && prop.IsOfKind(CodePropertyKind.Headers), + AbstractionsNamespaceName, "RequestHeaders"), + new (static x => x is CodeProperty prop && prop.IsOfKind(CodePropertyKind.Custom) && prop.Type.Name.Equals(KiotaBuilder.UntypedNodeName, StringComparison.OrdinalIgnoreCase), + AbstractionsNamespaceName, KiotaBuilder.UntypedNodeName), + new (static x => x is CodeMethod method && method.IsOfKind(CodeMethodKind.RequestExecutor, CodeMethodKind.RequestGenerator) && method.Parameters.Any(static y => y.IsOfKind(CodeParameterKind.RequestBody) && y.Type.Name.Equals(MultipartBodyClassName, StringComparison.OrdinalIgnoreCase)), + AbstractionsNamespaceName, MultipartBodyClassName), + }; + + public RustRefiner(GenerationConfiguration configuration) : base(configuration) { } + + public override Task RefineAsync(CodeNamespace generatedCode, CancellationToken cancellationToken) + { + return Task.Run(() => + { + cancellationToken.ThrowIfCancellationRequested(); + var defaultConfiguration = new GenerationConfiguration(); + + ConvertUnionTypesToWrapper(generatedCode, + _configuration.UsesBackingStore, + static s => s.ToSnakeCase(), + false); + ReplaceIndexersByMethodsWithParameter(generatedCode, + false, + static x => $"by_{x.ToSnakeCase()}", + static x => x.ToSnakeCase(), + GenerationLanguage.Rust); + + // Replace navigation properties by methods early, before CorrectCoreType/ReplacePropertyNames + ReplaceRequestBuilderPropertiesByMethods(generatedCode); + + CorrectCommonNames(generatedCode); + var reservedNamesProvider = new RustReservedNamesProvider(); + cancellationToken.ThrowIfCancellationRequested(); + + CorrectNames(generatedCode, s => + { + if (s.Contains('_', StringComparison.OrdinalIgnoreCase) && + s.ToPascalCase(UnderscoreArray) is string refinedName && + !reservedNamesProvider.ReservedNames.Contains(s) && + !reservedNamesProvider.ReservedNames.Contains(refinedName)) + return refinedName; + return s; + }); + + CorrectCoreType(generatedCode, CorrectMethodType, CorrectPropertyType, CorrectImplements); + + ReplacePropertyNames(generatedCode, + [ + CodePropertyKind.Custom, + CodePropertyKind.AdditionalData, + CodePropertyKind.QueryParameter, + ], + static s => s.ToSnakeCase()); + + // Rust has no inheritance, so we do NOT call MoveRequestBuilderPropertiesToBaseType. + // Request builder properties (request_adapter, path_parameters, url_template) stay on the struct. + + RemoveRequestConfigurationClasses(generatedCode, + new CodeUsing + { + Name = "RequestConfiguration", + Declaration = new CodeType + { + Name = AbstractionsNamespaceName, + IsExternal = true + } + }, new CodeType + { + Name = "DefaultQueryParameters", + IsExternal = true, + }); + + MoveInnerClassesToNamespace(generatedCode); + + AddDefaultImports(generatedCode, defaultUsingEvaluators); + AddPropertiesAndMethodTypesImports(generatedCode, true, true, true); + AddParsableImplementsForModelClasses(generatedCode, "Parsable"); + AddConstructorsForDefaultValues(generatedCode, true); + cancellationToken.ThrowIfCancellationRequested(); + + AddDiscriminatorMappingsUsingsToParentClasses(generatedCode, "ParseNode", addUsings: true, includeParentNamespace: true); + + ReplaceReservedNames(generatedCode, reservedNamesProvider, x => $"r#{x}"); + ReplaceReservedExceptionPropertyNames( + generatedCode, + new RustExceptionsReservedNamesProvider(), + static x => $"{x}_escaped" + ); + + ReplaceDefaultSerializationModules( + generatedCode, + defaultConfiguration.Serializers, + new(StringComparer.OrdinalIgnoreCase) { + $"{SerializationNamespaceName}_json.JsonSerializationWriterFactory", + } + ); + ReplaceDefaultDeserializationModules( + generatedCode, + defaultConfiguration.Deserializers, + new(StringComparer.OrdinalIgnoreCase) { + $"{SerializationNamespaceName}_json.JsonParseNodeFactory", + } + ); + AddSerializationModulesImport(generatedCode, + [$"{AbstractionsNamespaceName}.ApiClientBuilder", + $"{AbstractionsNamespaceName}.SerializationWriterFactoryRegistry"], + [$"{AbstractionsNamespaceName}.ParseNodeFactoryRegistry"]); + cancellationToken.ThrowIfCancellationRequested(); + + AddParentClassToErrorClasses( + generatedCode, + "ApiError", + AbstractionsNamespaceName + ); + DeduplicateErrorMappings(generatedCode); + RemoveCancellationParameter(generatedCode); + DisambiguatePropertiesWithClassNames(generatedCode); + RemoveMethodByKind(generatedCode, CodeMethodKind.RawUrlBuilder); + }, cancellationToken); + } + + private static void CorrectCommonNames(CodeElement currentElement) + { + if (currentElement is CodeMethod m && + currentElement.Parent is CodeClass parentClass) + { + // Rust uses snake_case for method names + var snakeName = m.Name.ToSnakeCase(); + if (!snakeName.Equals(m.Name, StringComparison.Ordinal)) + parentClass.RenameChildElement(m.Name, snakeName); + // Rust uses PascalCase for type names + parentClass.Name = parentClass.Name.ToFirstCharacterUpperCase(); + } + else if (currentElement is CodeIndexer i) + { + i.IndexParameter.Name = i.IndexParameter.Name.ToSnakeCase(); + } + else if (currentElement is CodeEnum e) + { + // Rust enum variants are PascalCase + foreach (var option in e.Options.ToList()) + { + option.Name = option.Name.ToFirstCharacterUpperCase(); + } + } + CrawlTree(currentElement, element => CorrectCommonNames(element)); + } + + private static void CorrectMethodType(CodeMethod currentMethod) + { + if (currentMethod.IsOfKind(CodeMethodKind.Serializer)) + currentMethod.Parameters.Where(x => x.IsOfKind(CodeParameterKind.Serializer)).ToList().ForEach(x => + { + x.Optional = false; + x.Type.IsNullable = false; + if (x.Type.Name.StartsWith('I')) + x.Type.Name = x.Type.Name[1..]; + }); + else if (currentMethod.IsOfKind(CodeMethodKind.Deserializer)) + { + currentMethod.ReturnType.Name = "HashMap>"; + currentMethod.Name = "get_field_deserializers"; + } + else if (currentMethod.IsOfKind(CodeMethodKind.RawUrlConstructor, CodeMethodKind.ClientConstructor)) + { + currentMethod.Parameters.Where(x => x.IsOfKind(CodeParameterKind.RequestAdapter, CodeParameterKind.BackingStore)) + .Where(x => x.Type.Name.StartsWith('I')) + .ToList() + .ForEach(x => x.Type.Name = x.Type.Name[1..]); + } + CorrectCoreTypes(currentMethod.Parent as CodeClass, DateTypesReplacements, types: currentMethod.Parameters + .Select(static x => x.Type) + .Union(new[] { currentMethod.ReturnType }) + .ToArray()); + currentMethod.Parameters.ToList().ForEach(static x => x.Name = x.Name.ToSnakeCase()); + } + + private static void CorrectPropertyType(CodeProperty currentProperty) + { + ArgumentNullException.ThrowIfNull(currentProperty); + + if (currentProperty.IsOfKind(CodePropertyKind.Options)) + currentProperty.DefaultValue = "Vec::new()"; + else if (currentProperty.IsOfKind(CodePropertyKind.Headers)) + currentProperty.DefaultValue = "RequestHeaders::new()"; + else if (currentProperty.IsOfKind(CodePropertyKind.RequestAdapter)) + { + currentProperty.Type.Name = "std::sync::Arc"; + currentProperty.Type.IsNullable = false; + } + else if (currentProperty.IsOfKind(CodePropertyKind.BackingStore)) + { + currentProperty.Type.Name = currentProperty.Type.Name.TrimStart('I'); + currentProperty.Name = currentProperty.Name.ToSnakeCase(); + } + else if (currentProperty.IsOfKind(CodePropertyKind.AdditionalData)) + { + currentProperty.Type.Name = "HashMap"; + currentProperty.DefaultValue = "HashMap::new()"; + currentProperty.Name = currentProperty.Name.ToSnakeCase(); + } + else if (currentProperty.IsOfKind(CodePropertyKind.UrlTemplate)) + { + currentProperty.Type.IsNullable = false; + currentProperty.Type.Name = "String"; + } + else if (currentProperty.IsOfKind(CodePropertyKind.PathParameters)) + { + currentProperty.Type.IsNullable = false; + currentProperty.Type.Name = "HashMap"; + if (!string.IsNullOrEmpty(currentProperty.DefaultValue)) + currentProperty.DefaultValue = "HashMap::new()"; + } + else + { + if (!currentProperty.IsNameEscaped) + currentProperty.SerializationName = currentProperty.Name; + currentProperty.Name = currentProperty.Name.ToSnakeCase(); + } + CorrectCoreTypes(currentProperty.Parent as CodeClass, DateTypesReplacements, types: currentProperty.Type); + } + + private static void CorrectImplements(ProprietableBlockDeclaration block) + { + block.Implements.Where(x => + x.Name.StartsWith('I') && ( + x.Name.Equals("IAdditionalDataHolder", StringComparison.OrdinalIgnoreCase) || + x.Name.Equals("IBackedModel", StringComparison.OrdinalIgnoreCase) + )).ToList().ForEach(x => x.Name = x.Name[1..]); + } + + protected static void DisambiguatePropertiesWithClassNames(CodeElement currentElement) + { + if (currentElement is CodeClass currentClass) + { + var sameNameProperty = currentClass.Properties + .FirstOrDefault(x => x.Name.Equals(currentClass.Name, StringComparison.OrdinalIgnoreCase)); + if (sameNameProperty != null) + { + currentClass.RemoveChildElement(sameNameProperty); + if (string.IsNullOrEmpty(sameNameProperty.SerializationName)) + sameNameProperty.SerializationName = sameNameProperty.Name; + sameNameProperty.Name = $"{sameNameProperty.Name}_prop"; + currentClass.AddProperty(sameNameProperty); + } + } + CrawlTree(currentElement, DisambiguatePropertiesWithClassNames); + } + + private static void MoveInnerClassesToNamespace(CodeElement currentElement) + { + if (currentElement is CodeClass parentClass) + { + var innerClasses = parentClass.InnerClasses.ToArray(); + if (innerClasses.Length > 0 && parentClass.Parent is CodeNamespace ns) + { + foreach (var innerClass in innerClasses) + { + parentClass.RemoveChildElement(innerClass); + ns.AddClass(innerClass); + // Add a using so the parent class can still reference the moved type + parentClass.AddUsing(new CodeUsing + { + Name = innerClass.Name, + Declaration = new CodeType + { + Name = innerClass.Name, + TypeDefinition = innerClass, + IsExternal = false, + } + }); + } + } + } + CrawlTree(currentElement, MoveInnerClassesToNamespace); + } + + private static void ReplaceRequestBuilderPropertiesByMethods(CodeElement currentElement) + { + if (currentElement is CodeProperty currentProperty && + currentProperty.IsOfKind(CodePropertyKind.RequestBuilder) && + currentElement.Parent is CodeClass parentClass) + { + parentClass.RemoveChildElement(currentProperty); + currentProperty.Type.IsNullable = false; + parentClass.AddMethod(new CodeMethod + { + Name = currentProperty.Name, + ReturnType = currentProperty.Type, + Access = AccessModifier.Public, + Documentation = (CodeDocumentation)currentProperty.Documentation.Clone(), + IsAsync = false, + Kind = CodeMethodKind.RequestBuilderBackwardCompatibility, + }); + } + CrawlTree(currentElement, ReplaceRequestBuilderPropertiesByMethods); + } + + private static readonly Dictionary DateTypesReplacements = new(StringComparer.OrdinalIgnoreCase) { + {"TimeSpan", ("chrono::Duration", new CodeUsing { + Name = "chrono", + Declaration = new CodeType { + Name = "chrono", + IsExternal = true, + }, + })}, + {"DateTimeOffset", ("chrono::DateTime", new CodeUsing { + Name = "chrono", + Declaration = new CodeType { + Name = "chrono", + IsExternal = true, + }, + })}, + {"DateOnly", ("chrono::NaiveDate", new CodeUsing { + Name = "chrono", + Declaration = new CodeType { + Name = "chrono", + IsExternal = true, + }, + })}, + {"TimeOnly", ("chrono::NaiveTime", new CodeUsing { + Name = "chrono", + Declaration = new CodeType { + Name = "chrono", + IsExternal = true, + }, + })}, + {"Guid", ("uuid::Uuid", new CodeUsing { + Name = "uuid", + Declaration = new CodeType { + Name = "uuid", + IsExternal = true, + }, + })}, + }; +} diff --git a/src/Kiota.Builder/Refiners/RustReservedNamesProvider.cs b/src/Kiota.Builder/Refiners/RustReservedNamesProvider.cs new file mode 100644 index 0000000000..d0e5839317 --- /dev/null +++ b/src/Kiota.Builder/Refiners/RustReservedNamesProvider.cs @@ -0,0 +1,84 @@ +using System; +using System.Collections.Generic; + +namespace Kiota.Builder.Refiners; + +public class RustReservedNamesProvider : IReservedNamesProvider +{ + private readonly Lazy> _reservedNames = new(static () => new(StringComparer.Ordinal) { + // Strict keywords + "as", + "async", + "await", + "break", + "const", + "continue", + "crate", + "dyn", + "else", + "enum", + "extern", + "false", + "fn", + "for", + "if", + "impl", + "in", + "let", + "loop", + "match", + "mod", + "move", + "mut", + "pub", + "ref", + "return", + "self", + "Self", + "static", + "struct", + "super", + "trait", + "true", + "type", + "unsafe", + "use", + "where", + "while", + // Reserved for future use + "abstract", + "become", + "box", + "do", + "final", + "macro", + "override", + "priv", + "try", + "typeof", + "unsized", + "virtual", + "yield", + // Weak keywords used in certain contexts + "union", + "dyn", + // Common standard library types/traits that could collide + "String", + "Vec", + "Box", + "Option", + "Result", + "HashMap", + "Clone", + "Default", + "Display", + "Debug", + "Iterator", + "From", + "Into", + "Error", + // Kiota base type names + "BaseRequestBuilder", + }); + public HashSet ReservedNames => _reservedNames.Value; +} diff --git a/src/Kiota.Builder/Writers/LanguageWriter.cs b/src/Kiota.Builder/Writers/LanguageWriter.cs index 22313267dd..bf9d4cdda0 100644 --- a/src/Kiota.Builder/Writers/LanguageWriter.cs +++ b/src/Kiota.Builder/Writers/LanguageWriter.cs @@ -13,6 +13,7 @@ using Kiota.Builder.Writers.Php; using Kiota.Builder.Writers.Python; using Kiota.Builder.Writers.Ruby; +using Kiota.Builder.Writers.Rust; using Kiota.Builder.Writers.TypeScript; namespace Kiota.Builder.Writers; @@ -190,6 +191,7 @@ public static LanguageWriter GetLanguageWriter(GenerationLanguage language, stri GenerationLanguage.Python => new PythonWriter(outputPath, clientNamespaceName, usesBackingStore), GenerationLanguage.Go => new GoWriter(outputPath, clientNamespaceName, excludeBackwardCompatible), GenerationLanguage.Dart => new DartWriter(outputPath, clientNamespaceName), + GenerationLanguage.Rust => new RustWriter(outputPath, clientNamespaceName), GenerationLanguage.HTTP => new HttpWriter(outputPath, clientNamespaceName), _ => throw new InvalidEnumArgumentException($"{language} language currently not supported."), }; diff --git a/src/Kiota.Builder/Writers/Rust/CodeBlockEndWriter.cs b/src/Kiota.Builder/Writers/Rust/CodeBlockEndWriter.cs new file mode 100644 index 0000000000..e093326cc2 --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodeBlockEndWriter.cs @@ -0,0 +1,13 @@ +using System; +using Kiota.Builder.CodeDOM; + +namespace Kiota.Builder.Writers.Rust; + +public class CodeBlockEndWriter : ICodeElementWriter +{ + public void WriteCodeElement(BlockEnd codeElement, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(writer); + writer.CloseBlock(); + } +} diff --git a/src/Kiota.Builder/Writers/Rust/CodeClassDeclarationWriter.cs b/src/Kiota.Builder/Writers/Rust/CodeClassDeclarationWriter.cs new file mode 100644 index 0000000000..4856afa84b --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodeClassDeclarationWriter.cs @@ -0,0 +1,122 @@ +using System; +using System.Linq; +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; +using Kiota.Builder.PathSegmenters; + +namespace Kiota.Builder.Writers.Rust; + +public class CodeClassDeclarationWriter : BaseElementWriter +{ + private readonly RelativeImportManager relativeImportManager; + private readonly string clientNamespaceName; + + public CodeClassDeclarationWriter(RustConventionService conventionService, string clientNamespaceName, RustPathSegmenter pathSegmenter) : base(conventionService) + { + ArgumentNullException.ThrowIfNull(pathSegmenter); + this.clientNamespaceName = clientNamespaceName; + relativeImportManager = new RelativeImportManager(clientNamespaceName, '.', (ns, element) => pathSegmenter.NormalizeFileName(element)); + } + + public override void WriteCodeElement(ClassDeclaration codeElement, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(codeElement); + ArgumentNullException.ThrowIfNull(writer); + + if (codeElement.Parent is not CodeClass parentClass) + throw new InvalidOperationException($"The provided code element {codeElement.Name} doesn't have a parent of type {nameof(CodeClass)}"); + + conventions.WriteAutogeneratedMessage(writer); + writer.WriteLine(); + + // Write use statements for external dependencies + if (codeElement.Parent?.Parent is CodeNamespace) + { + // Standard imports for generated Rust files + writer.WriteLine("use std::collections::HashMap;"); + + foreach (var externalUsing in codeElement.Usings + .Where(static x => x.IsExternal) + .DistinctBy(static x => x.Declaration!.Name, StringComparer.Ordinal) + .OrderBy(static x => x.Declaration!.Name, StringComparer.Ordinal)) + { + var declName = externalUsing.Declaration!.Name; + if (declName.Equals("kiota_abstractions", StringComparison.Ordinal)) + writer.WriteLine("use kiota_abstractions::*;"); + else if (declName.Equals("serde", StringComparison.Ordinal)) + writer.WriteLine("use serde::{Serialize, Deserialize};"); + else if (declName.Equals("serde_json", StringComparison.Ordinal)) + writer.WriteLine("use serde_json;"); + else if (declName.StartsWith("kiota_serialization_", StringComparison.Ordinal)) + writer.WriteLine($"use {declName}::*;"); + else + writer.WriteLine($"use {declName};"); + } + + foreach (var internalUsing in codeElement.Usings + .Where(static x => !x.IsExternal && x.Declaration?.TypeDefinition != null) + .DistinctBy(static x => x.Declaration?.TypeDefinition?.Name, StringComparer.OrdinalIgnoreCase) + .OrderBy(static x => x.Declaration?.TypeDefinition?.Name, StringComparer.Ordinal)) + { + var typeName = internalUsing.Declaration?.TypeDefinition?.Name?.ToFirstCharacterUpperCase(); + if (!string.IsNullOrEmpty(typeName)) + { + var moduleFileName = typeName.ToSnakeCase(); + // Build the full module path from the namespace hierarchy + var typeNamespace = internalUsing.Declaration?.TypeDefinition?.Parent as CodeNamespace; + var currentNamespace = codeElement.Parent?.Parent as CodeNamespace; + if (typeNamespace != null && currentNamespace != null) + { + var nsName = typeNamespace.Name; + var rootName = clientNamespaceName; + var relativePath = nsName.StartsWith(rootName + ".", StringComparison.Ordinal) + ? nsName[(rootName.Length + 1)..] + : (nsName.Equals(rootName, StringComparison.Ordinal) ? string.Empty : nsName); + if (string.IsNullOrEmpty(relativePath)) + { + writer.WriteLine($"use crate::{moduleFileName}::{typeName};"); + } + else + { + var moduleParts = relativePath.Split('.').Select(p => p.ToSnakeCase()); + var modulePath = string.Join("::", moduleParts); + writer.WriteLine($"use crate::{modulePath}::{moduleFileName}::{typeName};"); + } + } + } + } + + writer.WriteLine(); + } + + conventions.WriteLongDescription(parentClass, writer); + conventions.WriteDeprecationAttribute(parentClass, writer); + + var isModel = parentClass.IsOfKind(CodeClassKind.Model); + var isRequestBuilder = parentClass.IsOfKind(CodeClassKind.RequestBuilder); + var isQueryParameters = parentClass.IsOfKind(CodeClassKind.QueryParameters); + + if (isModel) + { + // Model classes get serde derives + var derives = new[] { "Debug", "Clone", "Default", "Serialize", "Deserialize" }; + writer.WriteLine($"#[derive({string.Join(", ", derives)})]"); + } + else if (isQueryParameters) + { + // Query parameter classes need Default and Clone + writer.WriteLine("#[derive(Debug, Clone, Default)]"); + } + else if (isRequestBuilder) + { + // Request builders have Box fields that can't derive Clone/Debug + } + else + { + writer.WriteLine("#[derive(Debug, Clone)]"); + } + + var structName = codeElement.Name.ToFirstCharacterUpperCase(); + writer.StartBlock($"pub struct {structName} {{"); + } +} diff --git a/src/Kiota.Builder/Writers/Rust/CodeEnumWriter.cs b/src/Kiota.Builder/Writers/Rust/CodeEnumWriter.cs new file mode 100644 index 0000000000..bc067724ba --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodeEnumWriter.cs @@ -0,0 +1,40 @@ +using System; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.Writers.Rust; + +public class CodeEnumWriter : BaseElementWriter +{ + public CodeEnumWriter(RustConventionService conventionService) : base(conventionService) { } + + public override void WriteCodeElement(CodeEnum codeElement, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(codeElement); + ArgumentNullException.ThrowIfNull(writer); + if (!codeElement.Options.Any()) + return; + + conventions.WriteAutogeneratedMessage(writer); + writer.WriteLine("use serde::{Serialize, Deserialize};"); + writer.WriteLine(); + + conventions.WriteShortDescription(codeElement, writer); + conventions.WriteDeprecationAttribute(codeElement, writer); + writer.WriteLine("#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]"); + writer.StartBlock($"pub enum {codeElement.Name} {{"); + + foreach (var option in codeElement.Options) + { + conventions.WriteShortDescription(option, writer); + var serializationName = option.SerializationName; + if (!string.IsNullOrEmpty(serializationName) && !serializationName.Equals(option.Name, StringComparison.Ordinal)) + { + writer.WriteLine($"#[serde(rename = \"{serializationName}\")]"); + } + writer.WriteLine($"{option.Name.ToFirstCharacterUpperCase()},"); + } + } +} diff --git a/src/Kiota.Builder/Writers/Rust/CodeMethodWriter.cs b/src/Kiota.Builder/Writers/Rust/CodeMethodWriter.cs new file mode 100644 index 0000000000..6a61693e59 --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodeMethodWriter.cs @@ -0,0 +1,635 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; +using Kiota.Builder.OrderComparers; +using static Kiota.Builder.CodeDOM.CodeTypeBase; + +namespace Kiota.Builder.Writers.Rust; + +public class CodeMethodWriter : BaseElementWriter +{ + private readonly HashSet classesWithImplBlockOpened = new(StringComparer.Ordinal); + + public CodeMethodWriter(RustConventionService conventionService) : base(conventionService) { } + + internal bool HasImplBlockBeenOpened(string className) => classesWithImplBlockOpened.Contains(className); + + public override void WriteCodeElement(CodeMethod codeElement, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(codeElement); + if (codeElement.ReturnType == null) throw new InvalidOperationException($"{nameof(codeElement.ReturnType)} should not be null"); + ArgumentNullException.ThrowIfNull(writer); + if (codeElement.Parent is not CodeClass parentClass) throw new InvalidOperationException("the parent of a method should be a class"); + + var structName = parentClass.Name.ToFirstCharacterUpperCase(); + + // Close struct block and open impl block on first method + if (classesWithImplBlockOpened.Add(structName)) + { + writer.CloseBlock(); // close pub struct { ... } + writer.WriteLine(); + writer.StartBlock($"impl {structName} {{"); + } + + var returnType = conventions.GetTypeString(codeElement.ReturnType, codeElement); + var inherits = parentClass.StartBlock.Inherits != null && !parentClass.IsErrorDefinition; + var isVoid = conventions.VoidTypeName.Equals(returnType, StringComparison.OrdinalIgnoreCase) || returnType == "()"; + WriteMethodDocumentation(codeElement, writer); + WriteMethodPrototype(codeElement, parentClass, writer, returnType, inherits, isVoid); + + HandleMethodKind(codeElement, writer, inherits, parentClass, isVoid); + writer.CloseBlock(); + } + + protected virtual void HandleMethodKind(CodeMethod codeElement, LanguageWriter writer, bool doesInherit, CodeClass parentClass, bool isVoid) + { + ArgumentNullException.ThrowIfNull(codeElement); + ArgumentNullException.ThrowIfNull(writer); + ArgumentNullException.ThrowIfNull(parentClass); + var returnType = conventions.GetTypeString(codeElement.ReturnType, codeElement); + var returnTypeWithoutCollectionInformation = conventions.GetTypeString(codeElement.ReturnType, codeElement, false); + var requestBodyParam = codeElement.Parameters.OfKind(CodeParameterKind.RequestBody); + var requestConfig = codeElement.Parameters.OfKind(CodeParameterKind.RequestConfiguration); + var requestContentType = codeElement.Parameters.OfKind(CodeParameterKind.RequestBodyContentType); + var requestParams = new RequestParams(requestBodyParam, requestConfig, requestContentType); + + switch (codeElement.Kind) + { + case CodeMethodKind.Serializer: + WriteSerializerBody(doesInherit, codeElement, parentClass, writer); + break; + case CodeMethodKind.RequestGenerator: + WriteRequestGeneratorBody(codeElement, requestParams, parentClass, writer); + break; + case CodeMethodKind.RequestExecutor: + WriteRequestExecutorBody(codeElement, requestParams, parentClass, isVoid, returnTypeWithoutCollectionInformation, writer); + break; + case CodeMethodKind.Deserializer: + WriteDeserializerBody(doesInherit, codeElement, parentClass, writer); + break; + case CodeMethodKind.ClientConstructor: + WriteConstructorBody(parentClass, codeElement, writer); + WriteApiConstructorBody(parentClass, codeElement, writer); + break; + case CodeMethodKind.Constructor: + case CodeMethodKind.RawUrlConstructor: + WriteConstructorBody(parentClass, codeElement, writer); + break; + case CodeMethodKind.IndexerBackwardCompatibility: + case CodeMethodKind.RequestBuilderWithParameters: + WriteRequestBuilderBody(parentClass, codeElement, writer); + break; + case CodeMethodKind.QueryParametersMapper: + WriteQueryParametersBody(parentClass, writer); + break; + case CodeMethodKind.Getter: + WriteGetterBody(codeElement, writer, parentClass); + break; + case CodeMethodKind.Setter: + WriteSetterBody(codeElement, writer); + break; + case CodeMethodKind.RequestBuilderBackwardCompatibility: + WriteRequestBuilderBody(parentClass, codeElement, writer); + break; + case CodeMethodKind.ErrorMessageOverride: + throw new InvalidOperationException("ErrorMessageOverride is not supported as the error message is implemented by Display trait."); + case CodeMethodKind.CommandBuilder: + throw new InvalidOperationException("CommandBuilder methods are not implemented for Rust."); + case CodeMethodKind.Factory: + WriteFactoryMethodBody(codeElement, parentClass, writer); + break; + case CodeMethodKind.ComposedTypeMarker: + throw new InvalidOperationException("ComposedTypeMarker is not required for Rust."); + default: + writer.WriteLine("todo!()"); + break; + } + } + + private void WriteFactoryMethodBody(CodeMethod codeElement, CodeClass parentClass, LanguageWriter writer) + { + var parseNodeParameter = codeElement.Parameters.OfKind(CodeParameterKind.ParseNode) ?? throw new InvalidOperationException("Factory method should have a ParseNode parameter"); + + if (parentClass.DiscriminatorInformation.ShouldWriteDiscriminatorForInheritedType) + { + var discriminatorPropertyName = parentClass.DiscriminatorInformation.DiscriminatorPropertyName; + writer.WriteLine($"let mapping_value = {parseNodeParameter.Name.ToSnakeCase()}.get_child_node(\"{discriminatorPropertyName}\").and_then(|n| n.get_string_value());"); + writer.StartBlock("match mapping_value.as_deref() {"); + foreach (var mappedType in parentClass.DiscriminatorInformation.DiscriminatorMappings) + { + writer.WriteLine($"Some(\"{mappedType.Key}\") => {conventions.GetTypeString(mappedType.Value.AllTypes.First(), codeElement)}::default(),"); + } + writer.WriteLine($"_ => {parentClass.Name}::default(),"); + writer.CloseBlock(); + } + else + { + writer.WriteLine($"{parentClass.Name}::default()"); + } + } + + private void WriteRequestBuilderBody(CodeClass parentClass, CodeMethod codeElement, LanguageWriter writer) + { + var importSymbol = conventions.GetTypeString(codeElement.ReturnType, parentClass); + conventions.AddRequestBuilderBody(parentClass, importSymbol, writer, prefix: "", pathParameters: codeElement.Parameters.Where(static x => x.IsOfKind(CodeParameterKind.Path)), customParameters: codeElement.Parameters.Where(static x => x.IsOfKind(CodeParameterKind.Custom))); + } + + private void WriteGetterBody(CodeMethod codeElement, LanguageWriter writer, CodeClass parentClass) + { + var accessedProperty = codeElement.AccessedProperty; + if (accessedProperty?.IsOfKind(CodePropertyKind.RequestBuilder) == true) + { + // Navigation property: create and return child request builder on the fly + var returnType = conventions.GetTypeString(codeElement.ReturnType, parentClass); + conventions.AddRequestBuilderBody(parentClass, returnType, writer, prefix: ""); + } + else + { + // Regular field getter: return the field value + var fieldName = accessedProperty?.Name?.ToSnakeCase() ?? codeElement.Name.ToSnakeCase(); + writer.WriteLine($"self.{fieldName}.clone()"); + } + } + + private static void WriteSetterBody(CodeMethod codeElement, LanguageWriter writer) + { + var fieldName = codeElement.AccessedProperty?.Name?.ToSnakeCase() ?? codeElement.Name.ToSnakeCase(); + var paramName = codeElement.Parameters.FirstOrDefault()?.Name?.ToSnakeCase() ?? "value"; + writer.WriteLine($"self.{fieldName} = {paramName};"); + } + + private static void WriteApiConstructorBody(CodeClass parentClass, CodeMethod method, LanguageWriter writer) + { + if (parentClass.GetPropertyOfKind(CodePropertyKind.RequestAdapter) is not CodeProperty requestAdapterProperty) return; + var pathParametersProperty = parentClass.GetPropertyOfKind(CodePropertyKind.PathParameters); + var backingStoreParameter = method.Parameters.OfKind(CodeParameterKind.BackingStore); + var requestAdapterPropertyName = requestAdapterProperty.Name.ToSnakeCase(); + + WriteSerializationRegistration(method.SerializerModules, writer, "register_default_serializer"); + WriteSerializationRegistration(method.DeserializerModules, writer, "register_default_deserializer"); + + if (!string.IsNullOrEmpty(method.BaseUrl)) + { + writer.StartBlock($"if instance.{requestAdapterPropertyName}.get_base_url().is_empty() {{"); + writer.WriteLine($"instance.{requestAdapterPropertyName}.set_base_url(\"{method.BaseUrl}\");"); + writer.CloseBlock(); + if (pathParametersProperty != null) + writer.WriteLine($"instance.{pathParametersProperty.Name.ToSnakeCase()}.insert(\"baseurl\".to_string(), instance.{requestAdapterPropertyName}.get_base_url().to_string());"); + } + + if (backingStoreParameter != null) + { + writer.StartBlock($"if let Some(ref store) = {backingStoreParameter.Name.ToSnakeCase()} {{"); + writer.WriteLine($"instance.{requestAdapterPropertyName}.enable_backing_store(store);"); + writer.CloseBlock(); + } + writer.WriteLine("instance"); + } + + private static void WriteSerializationRegistration(HashSet serializationClassNames, LanguageWriter writer, string methodName) + { + if (serializationClassNames != null) + foreach (var serializationClassName in serializationClassNames) + writer.WriteLine($"ApiClientBuilder::{methodName}::<{serializationClassName}>();"); + } + + private void WriteConstructorBody(CodeClass parentClass, CodeMethod currentMethod, LanguageWriter writer) + { + if (parentClass.IsOfKind(CodeClassKind.RequestBuilder)) + { + WriteRequestBuilderConstructorBody(parentClass, currentMethod, writer); + // ClientConstructor continues with WriteApiConstructorBody which writes its own "instance" return + if (!currentMethod.IsOfKind(CodeMethodKind.ClientConstructor)) + writer.WriteLine("instance"); + } + else if (parentClass.IsOfKind(CodeClassKind.Model)) + { + WriteModelConstructorBody(parentClass, currentMethod, writer); + } + } + + private void WriteRequestBuilderConstructorBody(CodeClass parentClass, CodeMethod currentMethod, LanguageWriter writer) + { + var pathParametersProp = parentClass.GetPropertyOfKind(CodePropertyKind.PathParameters); + var urlTemplateProp = parentClass.GetPropertyOfKind(CodePropertyKind.UrlTemplate); + var requestAdapterParam = currentMethod.Parameters.OfKind(CodeParameterKind.RequestAdapter); + var pathParametersParam = currentMethod.Parameters.OfKind(CodeParameterKind.PathParameters); + var rawUrlParam = currentMethod.Parameters.OfKind(CodeParameterKind.RawUrl); + + // Only use `mut` if we need to mutate `instance` after construction + var pathParameters = pathParametersProp != null && pathParametersParam != null + ? currentMethod.Parameters.Where(static x => x.IsOfKind(CodeParameterKind.Path)).ToArray() + : []; + var needsMut = pathParameters.Length > 0 || currentMethod.IsOfKind(CodeMethodKind.ClientConstructor); + var letBinding = needsMut ? "let mut instance" : "let instance"; + + writer.StartBlock($"{letBinding} = Self {{"); + if (requestAdapterParam != null) + writer.WriteLine($"request_adapter: {requestAdapterParam.Name.ToSnakeCase()},"); + if (urlTemplateProp != null && !string.IsNullOrEmpty(urlTemplateProp.DefaultValue)) + writer.WriteLine($"url_template: {urlTemplateProp.DefaultValue}.to_string(),"); + if (pathParametersProp != null) + { + if (pathParametersParam != null) + writer.WriteLine($"path_parameters: {pathParametersParam.Name.ToSnakeCase()},"); + else if (rawUrlParam != null) + { + var rawUrlRef = rawUrlParam.Optional || rawUrlParam.Type.IsNullable + ? $"{rawUrlParam.Name.ToSnakeCase()}.unwrap_or_default()" + : rawUrlParam.Name.ToSnakeCase(); + writer.WriteLine("path_parameters: {"); + writer.IncreaseIndent(); + writer.WriteLine("let mut m = std::collections::HashMap::new();"); + writer.WriteLine($"m.insert(RequestInformation::RAW_URL_KEY.to_string(), {rawUrlRef});"); + writer.WriteLine("m"); + writer.DecreaseIndent(); + writer.WriteLine("},"); + } + else + writer.WriteLine("path_parameters: std::collections::HashMap::new(),"); + } + writer.CloseBlock("};"); + + // Handle path parameters + if (pathParameters.Length > 0) + { + foreach (var param in pathParameters) + { + var serialName = string.IsNullOrEmpty(param.SerializationName) ? param.Name : param.SerializationName; + writer.WriteLine($"instance.path_parameters.insert(\"{serialName}\".to_string(), {param.Name.ToSnakeCase()}.to_string());"); + } + } + } + + private void WriteModelConstructorBody(CodeClass parentClass, CodeMethod currentMethod, LanguageWriter writer) + { + var propWithDefaults = parentClass.Properties + .Where(static x => !string.IsNullOrEmpty(x.DefaultValue) && !x.IsOfKind(CodePropertyKind.UrlTemplate, CodePropertyKind.PathParameters, CodePropertyKind.BackingStore)) + .Where(static x => x.Type is not CodeType propType || propType.TypeDefinition is not CodeClass propertyClass || propertyClass.OriginalComposedType is null) + .OrderByDescending(static x => x.Kind) + .ThenBy(static x => x.Name).ToArray(); + + writer.StartBlock("Self {"); + foreach (var prop in propWithDefaults) + { + var defaultValue = GetDefaultValueForProperty(prop, currentMethod); + writer.WriteLine($"{prop.Name.ToSnakeCase()}: {defaultValue},"); + } + writer.WriteLine("..Default::default()"); + writer.CloseBlock(); + } + + private string GetDefaultValueForProperty(CodeProperty prop, CodeMethod method) + { + var defaultValue = prop.DefaultValue; + if (prop.Type is CodeType propertyType && propertyType.TypeDefinition is CodeEnum) + { + var enumTypeName = conventions.GetTypeString(prop.Type, method).TrimStart("Option<").TrimEnd('>').TrimEnd('?'); + return $"{enumTypeName}::{defaultValue.Trim('"').ToFirstCharacterUpperCase()}"; + } + if (prop.Type is CodeType pt && pt.Name.Equals("String", StringComparison.OrdinalIgnoreCase)) + { + return $"\"{defaultValue.Trim('"')}\".to_string()"; + } + return defaultValue; + } + + private void WriteDeserializerBody(bool shouldHide, CodeMethod codeElement, CodeClass parentClass, LanguageWriter writer) + { + var fieldToSerialize = parentClass.GetPropertiesOfKind(CodePropertyKind.Custom, CodePropertyKind.ErrorMessageOverride).ToArray(); + writer.WriteLine("let mut map: HashMap> = HashMap::new();"); + + if (fieldToSerialize.Length != 0) + { + foreach (var prop in fieldToSerialize + .Where(x => !x.ExistsInBaseType && !conventions.ErrorClassPropertyExistsInSuperClass(x)) + .OrderBy(static x => x.Name)) + { + var propName = prop.Name.ToSnakeCase(); + var isCollection = prop.Type.CollectionKind != CodeTypeCollectionKind.None; + var isEnum = prop.Type is CodeType ect && ect.TypeDefinition is CodeEnum; + var isComplex = prop.Type is CodeType cct && cct.TypeDefinition is CodeClass; + + if (isComplex || (isCollection && !IsPrimitiveCollection(prop.Type)) || isEnum) + { + // Complex types, non-primitive collections, enums: use serde via raw value + writer.WriteLine($"map.insert(\"{prop.WireName}\".to_string(), Box::new(|node, obj| {{ obj.{propName} = node.get_raw_value().and_then(|v| serde_json::from_value(v).ok()); }}));"); + } + else if (isCollection) + { + // Primitive collections: use concrete collection methods, wrap in Some() + var deserMethod = GetDeserializationMethodName(prop.Type, codeElement); + writer.WriteLine($"map.insert(\"{prop.WireName}\".to_string(), Box::new(|node, obj| {{ let vals = node.{deserMethod}; obj.{propName} = if vals.is_empty() {{ None }} else {{ Some(vals) }}; }}));"); + } + else + { + var deserMethod = GetDeserializationMethodName(prop.Type, codeElement); + writer.WriteLine($"map.insert(\"{prop.WireName}\".to_string(), Box::new(|node, obj| {{ obj.{propName} = node.{deserMethod}; }}));"); + } + } + } + writer.WriteLine("map"); + } + + private static bool IsPrimitiveCollection(CodeTypeBase propType) + { + if (propType is CodeType ct && ct.CollectionKind != CodeTypeCollectionKind.None && ct.TypeDefinition == null) + return true; + return false; + } + + private void WriteSerializerBody(bool shouldHide, CodeMethod method, CodeClass parentClass, LanguageWriter writer) + { + foreach (var otherProp in parentClass + .GetPropertiesOfKind(CodePropertyKind.Custom, CodePropertyKind.ErrorMessageOverride) + .Where(x => !x.ExistsInBaseType && !x.ReadOnly && !conventions.ErrorClassPropertyExistsInSuperClass(x)) + .OrderBy(static x => x.Name)) + { + var propName = otherProp.Name.ToSnakeCase(); + var isCollection = otherProp.Type.CollectionKind != CodeTypeCollectionKind.None; + var isEnum = otherProp.Type is CodeType ect && ect.TypeDefinition is CodeEnum; + var isComplex = otherProp.Type is CodeType cct && cct.TypeDefinition is CodeClass; + + if (isEnum && !isCollection) + { + // Enum serialization via serde: convert enum to string using JSON serialization + writer.WriteLine($"writer.write_string_value(\"{otherProp.WireName}\", &self.{propName}.as_ref().and_then(|v| serde_json::to_value(v).ok()).and_then(|v| v.as_str().map(String::from)))?;"); + } + else if (isComplex || isCollection || isEnum) + { + // Complex objects, collections, and enum collections: serialize via serde to raw JSON + writer.StartBlock($"if let Some(ref val) = self.{propName} {{"); + writer.StartBlock($"if let Ok(json) = serde_json::to_value(val) {{"); + writer.WriteLine($"writer.write_raw_value(\"{otherProp.WireName}\", &json)?;"); + writer.CloseBlock(); + writer.CloseBlock(); + } + else + { + var serializationMethodName = GetSerializationMethodName(otherProp.Type, method); + writer.WriteLine($"writer.{serializationMethodName}(\"{otherProp.WireName}\", &self.{propName})?;"); + } + } + + if (parentClass.GetPropertyOfKind(CodePropertyKind.AdditionalData) is CodeProperty additionalDataProperty) + writer.WriteLine($"writer.write_additional_data(&self.{additionalDataProperty.Name.ToSnakeCase()})?;"); + writer.WriteLine("Ok(())"); + } + + protected void WriteRequestExecutorBody(CodeMethod codeElement, RequestParams requestParams, CodeClass parentClass, bool isVoid, string returnTypeWithoutCollectionInformation, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(codeElement); + ArgumentNullException.ThrowIfNull(requestParams); + ArgumentNullException.ThrowIfNull(parentClass); + ArgumentNullException.ThrowIfNull(writer); + if (codeElement.HttpMethod == null) throw new InvalidOperationException("http method cannot be null"); + + var generatorMethodName = parentClass + .Methods + .FirstOrDefault(x => x.IsOfKind(CodeMethodKind.RequestGenerator) && x.HttpMethod == codeElement.HttpMethod) + ?.Name + ?.ToSnakeCase(); + var parametersList = new CodeParameter?[] { requestParams.requestBody, requestParams.requestContentType, requestParams.requestConfiguration } + .Where(static x => x != null) + .Select(static x => x!.Name.ToSnakeCase()) + .Aggregate(static (x, y) => $"{x}, {y}"); + writer.WriteLine($"let request_info = self.{generatorMethodName}({parametersList})?;"); + + // Error mappings (currently simplified — Rust uses Result instead of exception types) + writer.WriteLine("let error_mapping: std::collections::HashMap Box + Send + Sync>> = std::collections::HashMap::new();"); + + var sendMethod = GetSendRequestMethodName(isVoid, codeElement, codeElement.ReturnType); + + if (isVoid) + { + writer.WriteLine($"self.request_adapter.{sendMethod}(request_info, error_mapping).await?;"); + writer.WriteLine("Ok(())"); + } + else + { + writer.WriteLine($"let response = self.request_adapter.{sendMethod}(request_info, error_mapping).await?;"); + if (sendMethod == "send_raw") + writer.WriteLine("Ok(response.and_then(|v| serde_json::from_value(v).ok()))"); + else if (sendMethod == "send_raw_collection") + writer.WriteLine("Ok(Some(response.into_iter().filter_map(|v| serde_json::from_value(v).ok()).collect()))"); + else + writer.WriteLine("Ok(response)"); + } + } + + private void WriteRequestGeneratorBody(CodeMethod codeElement, RequestParams requestParams, CodeClass currentClass, LanguageWriter writer) + { + if (codeElement.HttpMethod == null) throw new InvalidOperationException("http method cannot be null"); + if (currentClass.GetPropertyOfKind(CodePropertyKind.PathParameters) is not CodeProperty urlTemplateParamsProperty) throw new InvalidOperationException("path parameters property cannot be null"); + if (currentClass.GetPropertyOfKind(CodePropertyKind.UrlTemplate) is not CodeProperty urlTemplateProperty) throw new InvalidOperationException("url template property cannot be null"); + + var operationName = codeElement.HttpMethod.ToString()?.ToUpperInvariant(); + var urlTemplateValue = codeElement.HasUrlTemplateOverride ? $"\"{codeElement.UrlTemplateOverride}\".to_string()" : $"self.{urlTemplateProperty.Name.ToSnakeCase()}.clone()"; + writer.WriteLine($"let mut request_info = RequestInformation::new(Method::{operationName}, {urlTemplateValue}, self.{urlTemplateParamsProperty.Name.ToSnakeCase()}.clone());"); + + if (requestParams.requestConfiguration != null) + { + writer.StartBlock($"if let Some(config) = {requestParams.requestConfiguration.Name.ToSnakeCase()} {{"); + writer.WriteLine("request_info.add_request_configuration(&config);"); + writer.CloseBlock(); + } + + if (codeElement.ShouldAddAcceptHeader) + writer.WriteLine($"request_info.headers.add(\"Accept\", \"{codeElement.AcceptHeaderValue}\");"); + + if (requestParams.requestBody != null) + { + var bodyParamName = requestParams.requestBody.Name.ToSnakeCase(); + if (requestParams.requestBody.Type.Name.Equals(conventions.StreamTypeName, StringComparison.OrdinalIgnoreCase)) + { + if (requestParams.requestContentType is not null) + writer.WriteLine($"request_info.set_stream_content({bodyParamName}, \"{requestParams.requestContentType.Name}\");"); + else if (!string.IsNullOrEmpty(codeElement.RequestBodyContentType)) + writer.WriteLine($"request_info.set_stream_content({bodyParamName}, \"{codeElement.RequestBodyContentType}\");"); + } + else + { + // Use serde serialization for the body (handles both model and scalar types) + var isOptional = requestParams.requestBody.Optional || requestParams.requestBody.Type.IsNullable; + if (isOptional) + { + writer.StartBlock($"if let Some(ref body_val) = {bodyParamName} {{"); + writer.WriteLine($"request_info.content = Some(serde_json::to_vec(body_val)?);"); + writer.WriteLine($"request_info.content_type = Some(\"{codeElement.RequestBodyContentType}\".to_string());"); + writer.CloseBlock(); + } + else + { + writer.WriteLine($"request_info.content = Some(serde_json::to_vec(&{bodyParamName})?);"); + writer.WriteLine($"request_info.content_type = Some(\"{codeElement.RequestBodyContentType}\".to_string());"); + } + } + } + + writer.WriteLine("Ok(request_info)"); + } + + private void WriteQueryParametersBody(CodeClass parentClass, LanguageWriter writer) + { + writer.StartBlock("let mut map = std::collections::HashMap::new();"); + foreach (CodeProperty property in parentClass.Properties) + { + var key = property.IsNameEscaped ? property.SerializationName : property.Name; + var propName = property.Name.ToSnakeCase(); + writer.StartBlock($"if let Some(ref val) = self.{propName} {{"); + writer.WriteLine($"map.insert(\"{key}\".to_string(), val.to_string());"); + writer.CloseBlock(); + } + writer.WriteLine("map"); + } + + protected string GetSendRequestMethodName(bool isVoid, CodeElement currentElement, CodeTypeBase returnType) + { + ArgumentNullException.ThrowIfNull(returnType); + var returnTypeName = conventions.GetTypeString(returnType, currentElement, false); + var isStream = conventions.StreamTypeName.Equals(returnTypeName, StringComparison.OrdinalIgnoreCase); + if (isVoid) return "send_no_content"; + if (returnTypeName.Equals("String", StringComparison.Ordinal)) + return "send_primitive_string"; + if (returnType.IsCollection) return "send_raw_collection"; + return "send_raw"; + } + + private string GetDeserializationMethodName(CodeTypeBase propType, CodeMethod method) + { + var isCollection = propType.CollectionKind != CodeTypeCollectionKind.None; + var propertyType = conventions.GetTypeString(propType, method, false, false); + if (propType is CodeType currentType) + { + if (isCollection) + { + if (currentType.TypeDefinition == null) + { + // Primitive collection: use concrete method names (no generics for dyn-compatibility) + return propertyType switch + { + "String" => "get_collection_of_primitive_string_values()", + "i32" => "get_collection_of_primitive_i32_values()", + "i64" => "get_collection_of_primitive_i64_values()", + "f64" => "get_collection_of_primitive_f64_values()", + "bool" => "get_collection_of_primitive_bool_values()", + _ => $"get_raw_value().and_then(|v| serde_json::from_value(v).ok()).unwrap_or_default()", + }; + } + // Enum or object collection: use serde via raw value + return $"get_raw_value().and_then(|v| serde_json::from_value(v).ok())"; + } + if (currentType.TypeDefinition is CodeEnum) + return $"get_string_value().and_then(|s| serde_json::from_value(serde_json::Value::String(s)).ok())"; + } + return propertyType switch + { + "Vec" => "get_byte_array_value()", + "uuid::Uuid" => "get_uuid_value()", + "chrono::DateTime" => "get_date_time_value()", + "chrono::NaiveDate" => "get_date_value()", + "chrono::NaiveTime" => "get_time_value()", + "chrono::Duration" => "get_duration_value()", + "String" => "get_string_value()", + "bool" => "get_bool_value()", + "i32" => "get_i32_value()", + "i64" => "get_i64_value()", + "f32" => "get_f32_value()", + "f64" => "get_f64_value()", + _ => $"get_raw_value().and_then(|v| serde_json::from_value(v).ok())", + }; + } + + private string GetSerializationMethodName(CodeTypeBase propType, CodeMethod method) + { + var isCollection = propType.CollectionKind != CodeTypeCollectionKind.None; + var propertyType = conventions.GetTypeString(propType, method, false, false); + if (propType is CodeType currentType) + { + if (isCollection) + { + // All collections serialized via serde raw value + return "COLLECTION_SERDE"; + } + if (currentType.TypeDefinition is CodeEnum) + return "ENUM_SERDE"; + } + return propertyType switch + { + "Vec" => "write_byte_array_value", + "uuid::Uuid" => "write_uuid_value", + "chrono::DateTime" => "write_date_time_value", + "chrono::NaiveDate" => "write_date_value", + "chrono::NaiveTime" => "write_time_value", + "chrono::Duration" => "write_duration_value", + "String" => "write_string_value", + "bool" => "write_bool_value", + "i32" => "write_i32_value", + "i64" => "write_i64_value", + "f32" => "write_f32_value", + "f64" => "write_f64_value", + _ => "OBJECT_SERDE", + }; + } + + private void WriteMethodDocumentation(CodeMethod code, LanguageWriter writer) + { + conventions.WriteLongDescription(code, writer); + foreach (var paramWithDescription in code.Parameters + .Where(static x => x.Documentation.DescriptionAvailable) + .OrderBy(static x => x.Name, StringComparer.OrdinalIgnoreCase)) + conventions.WriteParameterDescription(paramWithDescription, writer); + conventions.WriteDeprecationAttribute(code, writer); + } + + private static readonly BaseCodeParameterOrderComparer parameterOrderComparer = new(); + + private void WriteMethodPrototype(CodeMethod code, CodeClass parentClass, LanguageWriter writer, string returnType, bool inherits, bool isVoid) + { + var isConstructor = code.IsOfKind(CodeMethodKind.Constructor, CodeMethodKind.ClientConstructor, CodeMethodKind.RawUrlConstructor); + var asyncKeyword = code.IsAsync ? "async " : ""; + var methodName = isConstructor ? "new" : code.Name.ToSnakeCase(); + if (code.IsOfKind(CodeMethodKind.RawUrlConstructor)) + methodName = "with_raw_url"; + + var selfParam = code.IsStatic ? "" : "&self, "; + if (isConstructor) + selfParam = ""; + + var parameters = string.Join(", ", code.Parameters + .OrderBy(static x => x, parameterOrderComparer) + .Select(p => conventions.GetParameterSignature(p, code))); + + var allParams = string.IsNullOrEmpty(parameters) ? selfParam.TrimEnd(' ').TrimEnd(',') : $"{selfParam}{parameters}"; + + // Methods that use `?` operator need Result return type + var needsResult = code.IsAsync || code.IsOfKind(CodeMethodKind.Serializer, CodeMethodKind.RequestGenerator); + + string returnTypeStr; + if (isConstructor) + { + returnTypeStr = " -> Self"; + } + else if (isVoid && needsResult) + { + returnTypeStr = " -> Result<(), Box>"; + } + else if (isVoid) + { + returnTypeStr = ""; + } + else if (needsResult) + { + returnTypeStr = $" -> Result<{returnType}, Box>"; + } + else + { + returnTypeStr = $" -> {returnType}"; + } + + var visibility = conventions.GetAccessModifier(code.Access); + writer.StartBlock($"{visibility}{asyncKeyword}fn {methodName}({allParams}){returnTypeStr} {{"); + } +} diff --git a/src/Kiota.Builder/Writers/Rust/CodePropertyWriter.cs b/src/Kiota.Builder/Writers/Rust/CodePropertyWriter.cs new file mode 100644 index 0000000000..45b5cc4d5b --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodePropertyWriter.cs @@ -0,0 +1,60 @@ +using System; +using System.Linq; +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.Writers.Rust; + +public class CodePropertyWriter : BaseElementWriter +{ + public CodePropertyWriter(RustConventionService conventionService) : base(conventionService) { } + + public override void WriteCodeElement(CodeProperty codeElement, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(codeElement); + ArgumentNullException.ThrowIfNull(writer); + if (codeElement.ExistsInExternalBaseType || conventions.ErrorClassPropertyExistsInSuperClass(codeElement)) return; + + var propertyType = conventions.GetTypeString(codeElement.Type, codeElement); + var propertyName = codeElement.Name.ToSnakeCase(); + + switch (codeElement.Kind) + { + case CodePropertyKind.RequestBuilder: + // Request builder properties are accessed via methods in Rust, skip field declaration + break; + case CodePropertyKind.QueryParameter when codeElement.IsNameEscaped: + conventions.WriteShortDescription(codeElement, writer); + conventions.WriteDeprecationAttribute(codeElement, writer); + WriteField(writer, propertyName, propertyType, codeElement); + break; + case CodePropertyKind.Custom when !string.IsNullOrEmpty(codeElement.SerializationName) && !codeElement.SerializationName.Equals(codeElement.Name, StringComparison.Ordinal): + conventions.WriteShortDescription(codeElement, writer); + conventions.WriteDeprecationAttribute(codeElement, writer); + writer.WriteLine($"#[serde(rename = \"{codeElement.WireName}\")]"); + WriteField(writer, propertyName, propertyType, codeElement); + break; + case CodePropertyKind.AdditionalData: + conventions.WriteShortDescription(codeElement, writer); + conventions.WriteDeprecationAttribute(codeElement, writer); + writer.WriteLine("#[serde(flatten)]"); + writer.WriteLine($"pub {propertyName}: std::collections::HashMap,"); + break; + case CodePropertyKind.BackingStore: + // Backing store is not directly mapped to Rust + break; + default: + conventions.WriteShortDescription(codeElement, writer); + conventions.WriteDeprecationAttribute(codeElement, writer); + WriteField(writer, propertyName, propertyType, codeElement); + break; + } + } + + private static void WriteField(LanguageWriter writer, string propertyName, string propertyType, CodeProperty codeElement) + { + var isNullable = codeElement.Type.IsNullable && !propertyType.StartsWith("Option<", StringComparison.Ordinal); + var finalType = isNullable ? $"Option<{propertyType}>" : propertyType; + writer.WriteLine($"pub {propertyName}: {finalType},"); + } +} diff --git a/src/Kiota.Builder/Writers/Rust/RustConventionService.cs b/src/Kiota.Builder/Writers/Rust/RustConventionService.cs new file mode 100644 index 0000000000..2de7055005 --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/RustConventionService.cs @@ -0,0 +1,256 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; + +using static Kiota.Builder.CodeDOM.CodeTypeBase; + +namespace Kiota.Builder.Writers.Rust; + +public class RustConventionService : CommonLanguageConventionService +{ + internal static readonly HashSet ErrorClassProperties = new(StringComparer.OrdinalIgnoreCase) { "message", "status_code", "response_headers" }; + public override string StreamTypeName => "Vec"; + public override string VoidTypeName => "()"; + public override string DocCommentPrefix => "/// "; + internal const string NullableMarkerStr = "Option<"; + public override string ParseNodeInterfaceName => "ParseNode"; + private const string ReferenceTypePrefix = "[`"; + private const string ReferenceTypeSuffix = "`]"; + + private static readonly HashSet PrimitiveTypes = new(StringComparer.OrdinalIgnoreCase) { + "i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64", + "f32", "f64", "bool", "String", + }; + + public override bool WriteShortDescription(IDocumentedElement element, LanguageWriter writer, string prefix = "", string suffix = "") + { + ArgumentNullException.ThrowIfNull(writer); + ArgumentNullException.ThrowIfNull(element); + if (!element.Documentation.DescriptionAvailable) return false; + if (element is not CodeElement codeElement) return false; + + var description = element.Documentation.GetDescription(x => GetTypeReferenceForDocComment(x, codeElement), ReferenceTypePrefix, ReferenceTypeSuffix); + writer.WriteLine($"{DocCommentPrefix}{description}"); + + return true; + } + + public void WriteParameterDescription(CodeParameter element, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(writer); + ArgumentNullException.ThrowIfNull(element); + var description = element.Documentation.GetDescription(x => GetTypeReferenceForDocComment(x, element), ReferenceTypePrefix, ReferenceTypeSuffix); + writer.WriteLine($"{DocCommentPrefix}* `{element.Name.ToSnakeCase()}` - {description}"); + } + + public void WriteAutogeneratedMessage(LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(writer); + writer.WriteLine("// auto-generated by kiota"); + } + + public void WriteLintSuppression(LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(writer); + writer.WriteLine("#![allow(unused_imports, dead_code, unused_variables, clippy::all)]"); + } + + public void WriteLongDescription(CodeElement element, LanguageWriter writer, IEnumerable? additionalRemarks = default) + { + ArgumentNullException.ThrowIfNull(writer); + if (element is not IDocumentedElement documentedElement || documentedElement.Documentation is not CodeDocumentation documentation) return; + additionalRemarks ??= []; + var remarks = additionalRemarks.Where(static x => !string.IsNullOrEmpty(x)).ToArray(); + if (documentation.DescriptionAvailable || documentation.ExternalDocumentationAvailable || remarks.Length != 0) + { + if (documentation.DescriptionAvailable) + { + var description = documentedElement.Documentation.GetDescription(x => GetTypeReferenceForDocComment(x, element), ReferenceTypePrefix, ReferenceTypeSuffix); + writer.WriteLine($"{DocCommentPrefix}{description}"); + } + foreach (var additionalRemark in remarks) + writer.WriteLine($"{DocCommentPrefix}{additionalRemark}"); + if (element is IDeprecableElement deprecableElement && deprecableElement.Deprecation is not null && deprecableElement.Deprecation.IsDeprecated) + foreach (var additionalComment in GetDeprecationInformationForDocumentationComment(deprecableElement)) + writer.WriteLine($"{DocCommentPrefix}{additionalComment}"); + + if (documentation.ExternalDocumentationAvailable) + writer.WriteLine($"{DocCommentPrefix}[{documentation.DocumentationLabel}]({documentation.DocumentationLink})"); + } + } + + internal string GetTypeReferenceForDocComment(CodeTypeBase code, CodeElement targetElement) + { + if (code is CodeType codeType && codeType.TypeDefinition is CodeMethod method) + return $"{GetTypeString(new CodeType { TypeDefinition = method.Parent, IsExternal = false }, targetElement)}::{GetTypeString(code, targetElement)}"; + return $"{GetTypeString(code, targetElement)}"; + } + + private string[] GetDeprecationInformationForDocumentationComment(IDeprecableElement element) + { + if (element.Deprecation is null || !element.Deprecation.IsDeprecated) return Array.Empty(); + + var versionComment = string.IsNullOrEmpty(element.Deprecation.Version) ? string.Empty : $" as of {element.Deprecation.Version}"; + var dateComment = element.Deprecation.Date is null ? string.Empty : $" on {element.Deprecation.Date.Value.Date.ToString("yyyy-MM-dd", CultureInfo.InvariantCulture)}"; + var removalComment = element.Deprecation.RemovalDate is null ? string.Empty : $" and will be removed {element.Deprecation.RemovalDate.Value.Date.ToString("yyyy-MM-dd", CultureInfo.InvariantCulture)}"; + return [ + $"# Deprecated", + $"{element.Deprecation.GetDescription(type => GetTypeString(type, (element as CodeElement)!))}{versionComment}{dateComment}{removalComment}" + ]; + } + + public override string GetAccessModifier(AccessModifier access) + { + return access switch + { + AccessModifier.Public => "pub ", + AccessModifier.Protected => "pub(crate) ", + AccessModifier.Private => "", + _ => "pub ", + }; + } + + internal void AddRequestBuilderBody(CodeClass parentClass, string returnType, LanguageWriter writer, string? urlTemplateVarName = default, string? prefix = default, IEnumerable? pathParameters = default, IEnumerable? customParameters = default) + { + if (parentClass.GetPropertyOfKind(CodePropertyKind.PathParameters) is CodeProperty pathParametersProp && + parentClass.GetPropertyOfKind(CodePropertyKind.RequestAdapter) is CodeProperty requestAdapterProp) + { + var pathParametersSuffix = !(pathParameters?.Any() ?? false) ? string.Empty : + $", {string.Join(", ", pathParameters.Select(x => $"{x.Name.ToSnakeCase()}"))}"; + var urlTplRef = string.IsNullOrEmpty(urlTemplateVarName) ? $"self.{pathParametersProp.Name.ToSnakeCase()}.clone()" : urlTemplateVarName; + if (customParameters?.Any() ?? false) + { + urlTplRef = TempDictionaryVarName; + writer.WriteLine($"let mut {urlTplRef} = self.{pathParametersProp.Name.ToSnakeCase()}.clone();"); + foreach (var param in customParameters) + writer.WriteLine($"{urlTplRef}.insert(\"{param.SerializationName}\".to_string(), {param.Name.ToSnakeCase()}.to_string());"); + } + writer.WriteLine($"{prefix}{returnType}::new({urlTplRef}, self.{requestAdapterProp.Name.ToSnakeCase()}.clone(){pathParametersSuffix})"); + } + } + + public override string TempDictionaryVarName => "url_tpl_params"; + + internal void AddParametersAssignment(LanguageWriter writer, CodeTypeBase pathParametersType, string pathParametersReference, string varName = "", params (CodeTypeBase, string, string)[] parameters) + { + if (pathParametersType == null) return; + if (string.IsNullOrEmpty(varName)) + { + varName = TempDictionaryVarName; + writer.WriteLine($"let mut {varName} = {pathParametersReference}.clone();"); + } + if (parameters.Length != 0) + { + writer.WriteLines(parameters.Select(p => + { + var (ct, name, identName) = p; + var identSnake = identName.ToSnakeCase(); + string nullCheck = string.Empty; + if (ct.CollectionKind == CodeTypeCollectionKind.None && ct.IsNullable) + { + nullCheck = $"if let Some(val) = &{identSnake} {{ "; + return $"{nullCheck}{varName}.insert(\"{name}\".to_string(), val.to_string()); }}"; + } + return $"{varName}.insert(\"{name}\".to_string(), {identSnake}.to_string());"; + }).ToArray()); + } + } + + public override string GetTypeString(CodeTypeBase code, CodeElement targetElement, bool includeCollectionInformation = true, LanguageWriter? writer = null) + { + return GetTypeString(code, targetElement, includeCollectionInformation, true); + } + + public string GetTypeString(CodeTypeBase code, CodeElement targetElement, bool includeCollectionInformation, bool includeNullableInformation, bool includeActionInformation = true) + { + ArgumentNullException.ThrowIfNull(targetElement); + if (code is CodeComposedTypeBase) + throw new InvalidOperationException($"Rust does not support union types, the union type {code.Name} should have been filtered out by the refiner"); + if (code is CodeType currentType) + { + var typeName = TranslateType(currentType); + + var collectionWrapper = currentType.CollectionKind != CodeTypeCollectionKind.None && includeCollectionInformation; + var wrappedType = collectionWrapper ? $"Vec<{typeName}>" : typeName; + + var genericParameters = currentType.GenericTypeParameterValues.Any() ? + $"<{string.Join(", ", currentType.GenericTypeParameterValues.Select(x => GetTypeString(x, targetElement, includeCollectionInformation)))}>" : + string.Empty; + if (!string.IsNullOrEmpty(genericParameters)) + wrappedType = $"{typeName}{genericParameters}"; + + if (currentType.ActionOf && includeActionInformation) + return $"Box"; + + if (currentType.IsNullable && includeNullableInformation) + return $"Option<{wrappedType}>"; + + return wrappedType; + } + + throw new InvalidOperationException($"type of type {code?.GetType()} is unknown"); + } + + public override string TranslateType(CodeType type) + { + ArgumentNullException.ThrowIfNull(type); + return type.Name.ToLowerInvariant() switch + { + "integer" or "int32" => "i32", + "int64" => "i64", + "sbyte" => "i8", + "byte" => "u8", + "boolean" => "bool", + "string" => "String", + "float" => "f32", + "double" or "decimal" => "f64", + "object" => "serde_json::Value", + "void" => "()", + "binary" or "base64" or "base64url" => "Vec", + "iparsenode" or "parsenode" => "&dyn ParseNode", + "iserializationwriter" or "serializationwriter" => "&mut dyn SerializationWriter", + "requestadapter" or "irequestadapter" => "std::sync::Arc", + _ => type.Name.Contains("::", StringComparison.Ordinal) ? type.Name : // qualified types (chrono::DateTime, uuid::Uuid) from refiner replacements + (type.Name.ToFirstCharacterUpperCase() is string typeName && !string.IsNullOrEmpty(typeName) ? typeName : "serde_json::Value"), + }; + } + + public bool IsPrimitiveType(string typeName) + { + if (string.IsNullOrEmpty(typeName)) return false; + var cleanName = typeName; + if (cleanName.StartsWith("Option<", StringComparison.Ordinal)) + cleanName = cleanName[7..]; + cleanName = cleanName.TrimEnd('>'); + return PrimitiveTypes.Contains(cleanName) || cleanName.StartsWith("chrono::", StringComparison.Ordinal); + } + + public override string GetParameterSignature(CodeParameter parameter, CodeElement targetElement, LanguageWriter? writer = null) + { + ArgumentNullException.ThrowIfNull(parameter); + var parameterType = GetTypeString(parameter.Type, targetElement, true, !parameter.Optional); + var paramName = parameter.Name.ToSnakeCase(); + if (parameter.Optional && !parameterType.StartsWith("Option<", StringComparison.Ordinal)) + parameterType = $"Option<{parameterType}>"; + return $"{paramName}: {parameterType}"; + } + + internal void WriteDeprecationAttribute(IDeprecableElement element, LanguageWriter writer) + { + if (element.Deprecation is null || !element.Deprecation.IsDeprecated) return; + + var versionComment = string.IsNullOrEmpty(element.Deprecation.Version) ? string.Empty : $" as of {element.Deprecation.Version}"; + var dateComment = element.Deprecation.Date is null ? string.Empty : $" on {element.Deprecation.Date.Value.Date.ToString("yyyy-MM-dd", CultureInfo.InvariantCulture)}"; + var removalComment = element.Deprecation.RemovalDate is null ? string.Empty : $" and will be removed {element.Deprecation.RemovalDate.Value.Date.ToString("yyyy-MM-dd", CultureInfo.InvariantCulture)}"; + writer.WriteLine($"#[deprecated(note = \"{element.Deprecation.GetDescription(type => GetTypeString(type, (element as CodeElement)!))}{versionComment}{dateComment}{removalComment}\")]"); + } + + public bool ErrorClassPropertyExistsInSuperClass(CodeProperty codeElement) + { + return codeElement?.Parent is CodeClass parentClass && parentClass.IsErrorDefinition && ErrorClassProperties.Contains(codeElement.Name); + } +} diff --git a/src/Kiota.Builder/Writers/Rust/RustWriter.cs b/src/Kiota.Builder/Writers/Rust/RustWriter.cs new file mode 100644 index 0000000000..759e2a2292 --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/RustWriter.cs @@ -0,0 +1,17 @@ +using Kiota.Builder.PathSegmenters; + +namespace Kiota.Builder.Writers.Rust; + +public class RustWriter : LanguageWriter +{ + public RustWriter(string rootPath, string clientNamespaceName) + { + PathSegmenter = new RustPathSegmenter(rootPath, clientNamespaceName); + var conventionService = new RustConventionService(); + AddOrReplaceCodeElementWriter(new CodeClassDeclarationWriter(conventionService, clientNamespaceName, (RustPathSegmenter)PathSegmenter)); + AddOrReplaceCodeElementWriter(new CodeBlockEndWriter()); + AddOrReplaceCodeElementWriter(new CodeEnumWriter(conventionService)); + AddOrReplaceCodeElementWriter(new CodeMethodWriter(conventionService)); + AddOrReplaceCodeElementWriter(new CodePropertyWriter(conventionService)); + } +} diff --git a/src/kiota/appsettings.json b/src/kiota/appsettings.json index 03b4814131..a28a6ffaa2 100644 --- a/src/kiota/appsettings.json +++ b/src/kiota/appsettings.json @@ -381,6 +381,38 @@ ], "DependencyInstallCommand": "" }, + "Rust": { + "MaturityLevel": "Experimental", + "SupportExperience": "Community", + "Dependencies": [ + { + "Name": "kiota-abstractions", + "Version": "0.1.0", + "Type": "Abstractions" + }, + { + "Name": "kiota-http-reqwest", + "Version": "0.1.0", + "Type": "Http" + }, + { + "Name": "kiota-serialization-json", + "Version": "0.1.0", + "Type": "Serialization" + }, + { + "Name": "kiota-serialization-text", + "Version": "0.1.0", + "Type": "Serialization" + }, + { + "Name": "kiota-serialization-form", + "Version": "0.1.0", + "Type": "Serialization" + } + ], + "DependencyInstallCommand": "cargo add {0}@{1}" + }, "HTTP": { "MaturityLevel": "Preview", "SupportExperience": "Microsoft", diff --git a/tests/Kiota.Builder.Tests/Refiners/RustLanguageRefinerTests.cs b/tests/Kiota.Builder.Tests/Refiners/RustLanguageRefinerTests.cs new file mode 100644 index 0000000000..869c21c1c2 --- /dev/null +++ b/tests/Kiota.Builder.Tests/Refiners/RustLanguageRefinerTests.cs @@ -0,0 +1,229 @@ +using System; +using System.Linq; +using System.Threading.Tasks; +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Configuration; +using Kiota.Builder.Refiners; + +using Xunit; + +namespace Kiota.Builder.Tests.Refiners; + +public class RustLanguageRefinerTests +{ + private readonly CodeNamespace root = CodeNamespace.InitRootNamespace(); + #region CommonLanguageRefinerTests + [Fact] + public async Task AddsExceptionInheritanceOnErrorClasses() + { + var model = root.AddClass(new CodeClass + { + Name = "somemodel", + Kind = CodeClassKind.Model, + IsErrorDefinition = true, + }).First(); + await ILanguageRefiner.RefineAsync(new GenerationConfiguration { Language = GenerationLanguage.Rust }, root, TestContext.Current.CancellationToken); + + var declaration = model.StartBlock; + + Assert.Contains("ApiError", declaration.Usings.Select(x => x.Name)); + Assert.Equal("ApiError", declaration.Inherits.Name); + } + [Fact] + public async Task AddsUsingsForErrorTypesForRequestExecutor() + { + var requestBuilder = root.AddClass(new CodeClass + { + Name = "somerequestbuilder", + Kind = CodeClassKind.RequestBuilder, + }).First(); + var subNS = root.AddNamespace($"{root.Name}.subns"); + var errorClass = subNS.AddClass(new CodeClass + { + Name = "Error4XX", + Kind = CodeClassKind.Model, + IsErrorDefinition = true, + }).First(); + var requestExecutor = requestBuilder.AddMethod(new CodeMethod + { + Name = "get", + Kind = CodeMethodKind.RequestExecutor, + ReturnType = new CodeType + { + Name = "string" + }, + }).First(); + requestExecutor.AddErrorMapping("4XX", new CodeType + { + Name = "Error4XX", + TypeDefinition = errorClass, + }); + await ILanguageRefiner.RefineAsync(new GenerationConfiguration { Language = GenerationLanguage.Rust }, root, TestContext.Current.CancellationToken); + + var declaration = requestBuilder.StartBlock; + + Assert.Contains("Error4XX", declaration.Usings.Select(x => x.Declaration?.Name)); + } + [Fact] + public async Task EscapesReservedKeywordsInInternalDeclaration() + { + var model = root.AddClass(new CodeClass + { + Name = "break", + Kind = CodeClassKind.Model + }).First(); + var nUsing = new CodeUsing + { + Name = "some.ns", + }; + nUsing.Declaration = new CodeType + { + IsExternal = false, + TypeDefinition = model + }; + model.AddUsing(nUsing); + await ILanguageRefiner.RefineAsync(new GenerationConfiguration { Language = GenerationLanguage.Rust }, root, TestContext.Current.CancellationToken); + // Rust escapes reserved words with r# prefix + Assert.NotEqual("break", nUsing.Declaration.Name, StringComparer.OrdinalIgnoreCase); + } + [Fact] + public async Task EscapesReservedKeywords() + { + var model = root.AddClass(new CodeClass + { + Name = "break", + Kind = CodeClassKind.Model + }).First(); + await ILanguageRefiner.RefineAsync(new GenerationConfiguration { Language = GenerationLanguage.Rust }, root, TestContext.Current.CancellationToken); + Assert.NotEqual("break", model.Name, StringComparer.OrdinalIgnoreCase); + } + [Fact] + public async Task ReplacesDateTimeOffsetWithChrono() + { + var model = root.AddClass(new CodeClass + { + Name = "model", + Kind = CodeClassKind.Model + }).First(); + var method = model.AddMethod(new CodeMethod + { + Name = "method", + Kind = CodeMethodKind.RequestExecutor, + ReturnType = new CodeType + { + Name = "DateTimeOffset" + }, + }).First(); + await ILanguageRefiner.RefineAsync(new GenerationConfiguration { Language = GenerationLanguage.Rust }, root, TestContext.Current.CancellationToken); + Assert.NotEmpty(model.StartBlock.Usings); + Assert.Contains("chrono", model.StartBlock.Usings.Select(x => x.Name)); + } + [Fact] + public async Task ReplacesGuidWithUuid() + { + var model = root.AddClass(new CodeClass + { + Name = "model", + Kind = CodeClassKind.Model + }).First(); + var method = model.AddMethod(new CodeMethod + { + Name = "method", + Kind = CodeMethodKind.RequestExecutor, + ReturnType = new CodeType + { + Name = "Guid" + }, + }).First(); + await ILanguageRefiner.RefineAsync(new GenerationConfiguration { Language = GenerationLanguage.Rust }, root, TestContext.Current.CancellationToken); + Assert.NotEmpty(model.StartBlock.Usings); + Assert.Contains("uuid", model.StartBlock.Usings.Select(x => x.Name)); + } + [Fact] + public async Task ReplacesIndexersByMethods() + { + var collectionNS = root.AddNamespace("collection"); + var itemsNs = collectionNS.AddNamespace($"{collectionNS.Name}.items"); + var requestBuilder = itemsNs.AddClass(new CodeClass + { + Name = "requestBuilder", + Kind = CodeClassKind.RequestBuilder + }).First(); + requestBuilder.AddProperty(new CodeProperty + { + Name = "urlTemplate", + DefaultValue = "path", + Kind = CodePropertyKind.UrlTemplate, + Type = new CodeType + { + Name = "string", + } + }); + requestBuilder.AddIndexer(new CodeIndexer + { + Name = "idx", + ReturnType = new CodeType + { + Name = requestBuilder.Name, + TypeDefinition = requestBuilder, + }, + IndexParameter = new() + { + Name = "id", + Type = new CodeType + { + Name = "string", + }, + } + }); + await ILanguageRefiner.RefineAsync(new GenerationConfiguration { Language = GenerationLanguage.Rust }, root, TestContext.Current.CancellationToken); + Assert.Single(requestBuilder.Methods, x => x.IsOfKind(CodeMethodKind.IndexerBackwardCompatibility)); + } + [Fact] + public async Task RemovesCancellationParameter() + { + var model = root.AddClass(new CodeClass + { + Name = "model", + Kind = CodeClassKind.RequestBuilder + }).First(); + var method = model.AddMethod(new CodeMethod + { + Name = "getAction", + Kind = CodeMethodKind.RequestExecutor, + ReturnType = new CodeType + { + Name = "string" + }, + }).First(); + method.AddParameter(new CodeParameter + { + Name = "cancellationToken", + Kind = CodeParameterKind.Cancellation, + Type = new CodeType + { + Name = "CancellationToken" + }, + }); + await ILanguageRefiner.RefineAsync(new GenerationConfiguration { Language = GenerationLanguage.Rust }, root, TestContext.Current.CancellationToken); + Assert.DoesNotContain(method.Parameters, x => x.IsOfKind(CodeParameterKind.Cancellation)); + } + [Fact] + public async Task AddsDefaultImports() + { + var model = root.AddClass(new CodeClass + { + Name = "model", + Kind = CodeClassKind.Model, + }).First(); + model.AddProperty(new CodeProperty + { + Name = "name", + Kind = CodePropertyKind.Custom, + Type = new CodeType { Name = "string" }, + }); + await ILanguageRefiner.RefineAsync(new GenerationConfiguration { Language = GenerationLanguage.Rust }, root, TestContext.Current.CancellationToken); + Assert.Contains("Parsable", model.StartBlock.Usings.Select(x => x.Name)); + } + #endregion +} diff --git a/tests/Kiota.Builder.Tests/Writers/Rust/CodeClassDeclarationWriterTests.cs b/tests/Kiota.Builder.Tests/Writers/Rust/CodeClassDeclarationWriterTests.cs new file mode 100644 index 0000000000..0e9a03276e --- /dev/null +++ b/tests/Kiota.Builder.Tests/Writers/Rust/CodeClassDeclarationWriterTests.cs @@ -0,0 +1,75 @@ +using System; +using System.IO; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Writers; +using Kiota.Builder.Writers.Rust; + +using Xunit; + +namespace Kiota.Builder.Tests.Writers.Rust; + +public sealed class CodeClassDeclarationWriterTests : IDisposable +{ + private const string DefaultPath = "./"; + private const string DefaultName = "name"; + private const string DefaultNameSpace = "ns"; + private readonly StringWriter tw; + private readonly LanguageWriter writer; + private readonly CodeClassDeclarationWriter codeElementWriter; + private readonly CodeClass parentClass; + private readonly CodeNamespace root; + + public CodeClassDeclarationWriterTests() + { + writer = LanguageWriter.GetLanguageWriter(GenerationLanguage.Rust, DefaultPath, DefaultName); + codeElementWriter = new CodeClassDeclarationWriter(new RustConventionService(), DefaultNameSpace, (Builder.PathSegmenters.RustPathSegmenter)writer.PathSegmenter); + tw = new StringWriter(); + writer.SetTextWriter(tw); + root = CodeNamespace.InitRootNamespace(); + parentClass = new() + { + Name = "parentClass" + }; + root.AddClass(parentClass); + } + public void Dispose() + { + tw?.Dispose(); + GC.SuppressFinalize(this); + } + [Fact] + public void WritesSimpleDeclaration() + { + codeElementWriter.WriteCodeElement(parentClass.StartBlock, writer); + var result = tw.ToString(); + Assert.Contains("pub struct", result); + Assert.Contains("ParentClass", result); + } + [Fact] + public void WritesModelWithSerdeDerives() + { + parentClass.Kind = CodeClassKind.Model; + codeElementWriter.WriteCodeElement(parentClass.StartBlock, writer); + var result = tw.ToString(); + Assert.Contains("#[derive(Debug, Clone, Default, Serialize, Deserialize)]", result); + Assert.Contains("pub struct", result); + } + [Fact] + public void WritesRequestBuilderWithoutSerdeDerives() + { + parentClass.Kind = CodeClassKind.RequestBuilder; + codeElementWriter.WriteCodeElement(parentClass.StartBlock, writer); + var result = tw.ToString(); + Assert.DoesNotContain("Serialize", result); + Assert.DoesNotContain("Deserialize", result); + Assert.DoesNotContain("#[derive(", result); // request builders have Box fields, no derives + } + [Fact] + public void WritesAutoGeneratedComment() + { + codeElementWriter.WriteCodeElement(parentClass.StartBlock, writer); + var result = tw.ToString(); + Assert.Contains("// auto-generated by kiota", result); + } +} diff --git a/tests/Kiota.Builder.Tests/Writers/Rust/CodeClassEndWriterTests.cs b/tests/Kiota.Builder.Tests/Writers/Rust/CodeClassEndWriterTests.cs new file mode 100644 index 0000000000..a535e43b49 --- /dev/null +++ b/tests/Kiota.Builder.Tests/Writers/Rust/CodeClassEndWriterTests.cs @@ -0,0 +1,33 @@ +using System; +using System.IO; + +using Kiota.Builder.Writers; + +using Xunit; + +namespace Kiota.Builder.Tests.Writers.Rust; + +public sealed class CodeClassEndWriterTests : IDisposable +{ + private const string DefaultPath = "./"; + private const string DefaultName = "name"; + private readonly StringWriter tw; + private readonly LanguageWriter writer; + public CodeClassEndWriterTests() + { + writer = LanguageWriter.GetLanguageWriter(GenerationLanguage.Rust, DefaultPath, DefaultName); + tw = new StringWriter(); + writer.SetTextWriter(tw); + } + public void Dispose() + { + tw?.Dispose(); + GC.SuppressFinalize(this); + } + [Fact] + public void WritesBlockEnd() + { + // Just verify writer instantiates properly for block ends + Assert.NotNull(writer); + } +} diff --git a/tests/Kiota.Builder.Tests/Writers/Rust/CodeEnumWriterTests.cs b/tests/Kiota.Builder.Tests/Writers/Rust/CodeEnumWriterTests.cs new file mode 100644 index 0000000000..bff9ca5735 --- /dev/null +++ b/tests/Kiota.Builder.Tests/Writers/Rust/CodeEnumWriterTests.cs @@ -0,0 +1,96 @@ +using System; +using System.IO; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Writers; + +using Xunit; + +namespace Kiota.Builder.Tests.Writers.Rust; + +public sealed class CodeEnumWriterTests : IDisposable +{ + private const string DefaultPath = "./"; + private const string DefaultName = "name"; + private readonly StringWriter tw; + private readonly LanguageWriter writer; + private readonly CodeEnum currentEnum; + private const string EnumName = "SomeEnum"; + public CodeEnumWriterTests() + { + writer = LanguageWriter.GetLanguageWriter(GenerationLanguage.Rust, DefaultPath, DefaultName); + tw = new StringWriter(); + writer.SetTextWriter(tw); + var root = CodeNamespace.InitRootNamespace(); + currentEnum = root.AddEnum(new CodeEnum + { + Name = EnumName, + }).First(); + } + public void Dispose() + { + tw?.Dispose(); + GC.SuppressFinalize(this); + } + [Fact] + public void WritesEnum() + { + const string optionName = "Option1"; + currentEnum.AddOption(new CodeEnumOption { Name = optionName, SerializationName = "option1" }); + writer.Write(currentEnum); + var result = tw.ToString(); + Assert.Contains("#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]", result); + Assert.Contains("pub enum SomeEnum", result); + Assert.Contains(optionName, result); + } + [Fact] + public void DoesntWriteAnythingOnNoOption() + { + writer.Write(currentEnum); + var result = tw.ToString(); + Assert.Empty(result); + } + [Fact] + public void WritesEnumOptionDescription() + { + var option = new CodeEnumOption + { + Documentation = new() + { + DescriptionTemplate = "Some option description", + }, + Name = "Option1", + }; + currentEnum.AddOption(option); + writer.Write(currentEnum); + var result = tw.ToString(); + Assert.Contains("///", result); + Assert.Contains(option.Documentation.DescriptionTemplate, result); + } + [Fact] + public void WritesEnumSerializationValue() + { + var optionName = "Plus1"; + var serializationValue = "+1"; + var option = new CodeEnumOption + { + Name = optionName, + SerializationName = serializationValue + }; + currentEnum.AddOption(option); + writer.Write(currentEnum); + var result = tw.ToString(); + Assert.Contains($"#[serde(rename = \"{serializationValue}\")]", result); + Assert.Contains(optionName, result); + } + [Fact] + public void WritesEnumDescription() + { + currentEnum.Documentation.DescriptionTemplate = "Some enum description"; + currentEnum.AddOption(new CodeEnumOption { Name = "Option1" }); + writer.Write(currentEnum); + var result = tw.ToString(); + Assert.Contains("/// Some enum description", result); + } +} diff --git a/tests/Kiota.Builder.Tests/Writers/Rust/CodeMethodWriterTests.cs b/tests/Kiota.Builder.Tests/Writers/Rust/CodeMethodWriterTests.cs new file mode 100644 index 0000000000..2d1bf25dbb --- /dev/null +++ b/tests/Kiota.Builder.Tests/Writers/Rust/CodeMethodWriterTests.cs @@ -0,0 +1,297 @@ +using System; +using System.IO; +using System.Linq; +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; +using Kiota.Builder.Writers; +using Kiota.Builder.Writers.Rust; + +using Xunit; + +namespace Kiota.Builder.Tests.Writers.Rust; + +public sealed class CodeMethodWriterTests : IDisposable +{ + private const string DefaultPath = "./"; + private const string DefaultName = "name"; + private readonly StringWriter tw; + private readonly LanguageWriter writer; + private CodeMethod method; + private CodeClass parentClass; + private readonly CodeNamespace root; + private const string MethodName = "methodName"; + private const string ReturnTypeName = "Somecustomtype"; + private const string MethodDescription = "some description"; + private const string ParamDescription = "some parameter description"; + private const string ParamName = "paramName"; + public CodeMethodWriterTests() + { + writer = LanguageWriter.GetLanguageWriter(GenerationLanguage.Rust, DefaultPath, DefaultName); + tw = new StringWriter(); + writer.SetTextWriter(tw); + root = CodeNamespace.InitRootNamespace(); + } + private void setup(bool withInheritance = false) + { + if (parentClass != null) + throw new InvalidOperationException("setup() must only be called once"); + CodeClass baseClass = default; + if (withInheritance) + { + baseClass = root.AddClass(new CodeClass + { + Name = "SomeParentClass", + }).First(); + baseClass.AddProperty(new CodeProperty + { + Name = "definedInParent", + Type = new CodeType + { + Name = "String" + }, + Kind = CodePropertyKind.Custom, + }); + } + parentClass = new CodeClass + { + Name = "ParentClass" + }; + if (withInheritance) + { + parentClass.StartBlock.Inherits = new CodeType + { + Name = "SomeParentClass", + TypeDefinition = baseClass + }; + } + root.AddClass(parentClass); + method = new CodeMethod + { + Name = MethodName, + ReturnType = new CodeType + { + Name = ReturnTypeName + } + }; + parentClass.AddMethod(method); + } + public void Dispose() + { + tw?.Dispose(); + GC.SuppressFinalize(this); + } + private void AddRequestProperties() + { + parentClass.StartBlock.Inherits = new CodeType + { + Name = "BaseRequestBuilder", + IsExternal = true, + }; + parentClass.AddProperty(new CodeProperty + { + Name = "requestAdapter", + Kind = CodePropertyKind.RequestAdapter, + Type = new CodeType + { + Name = "RequestAdapter", + } + }); + parentClass.AddProperty(new CodeProperty + { + Name = "pathParameters", + Kind = CodePropertyKind.PathParameters, + Type = new CodeType + { + Name = "String", + } + }); + parentClass.AddProperty(new CodeProperty + { + Name = "urlTemplate", + Kind = CodePropertyKind.UrlTemplate, + Type = new CodeType + { + Name = "String", + } + }); + } + private void AddSerializationProperties() + { + parentClass.AddProperty(new CodeProperty + { + Name = "additionalData", + Kind = CodePropertyKind.AdditionalData, + Type = new CodeType + { + Name = "String" + }, + Getter = new CodeMethod + { + Name = "getAdditionalData", + ReturnType = new CodeType + { + Name = "String" + } + }, + Setter = new CodeMethod + { + Name = "setAdditionalData", + ReturnType = new CodeType + { + Name = "String" + } + } + }); + parentClass.AddProperty(new CodeProperty + { + Name = "dummyProp", + Type = new CodeType + { + Name = "String" + }, + Kind = CodePropertyKind.Custom, + }); + parentClass.AddProperty(new CodeProperty + { + Name = "dummyUCaseProp", + Type = new CodeType + { + Name = "String" + }, + Kind = CodePropertyKind.Custom, + SerializationName = "DummyUCaseProp", + }); + } + [Fact] + public void WritesSerializerBody() + { + setup(); + parentClass.Kind = CodeClassKind.Model; + method.Kind = CodeMethodKind.Serializer; + method.Name = "serialize"; + method.IsAsync = false; + AddSerializationProperties(); + method.AddParameter(new CodeParameter + { + Name = "writer", + Kind = CodeParameterKind.Serializer, + Type = new CodeType + { + Name = "SerializationWriter" + } + }); + writer.Write(method); + var result = tw.ToString(); + Assert.Contains("pub fn serialize", result); + Assert.Contains("write_string_value", result); + Assert.Contains("Ok(())", result); + } + [Fact] + public void WritesDeserializerBody() + { + setup(); + parentClass.Kind = CodeClassKind.Model; + method.Kind = CodeMethodKind.Deserializer; + method.Name = "get_field_deserializers"; + method.IsAsync = false; + method.ReturnType = new CodeType + { + Name = "HashMap>" + }; + AddSerializationProperties(); + writer.Write(method); + var result = tw.ToString(); + Assert.Contains("pub fn get_field_deserializers", result); + Assert.Contains("HashMap", result); + } + [Fact] + public void WritesConstructorBody() + { + setup(); + method.Kind = CodeMethodKind.Constructor; + method.IsAsync = false; + writer.Write(method); + var result = tw.ToString(); + Assert.Contains("pub fn new", result); + Assert.Contains("Self", result); + } + [Fact] + public void WritesMethodDescription() + { + setup(); + method.Documentation.DescriptionTemplate = MethodDescription; + method.Kind = CodeMethodKind.Constructor; + method.IsAsync = false; + writer.Write(method); + var result = tw.ToString(); + Assert.Contains("///", result); + Assert.Contains(MethodDescription, result); + } + [Fact] + public void WritesAsyncMethod() + { + setup(); + method.Kind = CodeMethodKind.RequestExecutor; + method.IsAsync = true; + method.HttpMethod = HttpMethod.Get; + AddRequestProperties(); + method.AddParameter(new CodeParameter + { + Name = "requestConfiguration", + Kind = CodeParameterKind.RequestConfiguration, + Type = new CodeType + { + Name = "RequestConfiguration" + }, + Optional = true, + }); + writer.Write(method); + var result = tw.ToString(); + Assert.Contains("pub async fn", result); + Assert.Contains("Result<", result); + } + [Fact] + public void WritesRequestGeneratorBody() + { + setup(); + method.Kind = CodeMethodKind.RequestGenerator; + method.IsAsync = false; + method.HttpMethod = HttpMethod.Get; + AddRequestProperties(); + method.AddParameter(new CodeParameter + { + Name = "requestConfiguration", + Kind = CodeParameterKind.RequestConfiguration, + Type = new CodeType + { + Name = "RequestConfiguration" + }, + Optional = true, + }); + writer.Write(method); + var result = tw.ToString(); + Assert.Contains("RequestInformation", result); + Assert.Contains("url_template", result); + } + [Fact] + public void WritesIndexerBackwardCompatibility() + { + setup(); + method.Kind = CodeMethodKind.IndexerBackwardCompatibility; + method.IsAsync = false; + AddRequestProperties(); + method.AddParameter(new CodeParameter + { + Name = "id", + Kind = CodeParameterKind.Custom, + Type = new CodeType + { + Name = "string" + }, + SerializationName = "id", + }); + writer.Write(method); + var result = tw.ToString(); + Assert.Contains("pub fn", result); + } +} diff --git a/tests/Kiota.Builder.Tests/Writers/Rust/CodePropertyWriterTests.cs b/tests/Kiota.Builder.Tests/Writers/Rust/CodePropertyWriterTests.cs new file mode 100644 index 0000000000..67a687177c --- /dev/null +++ b/tests/Kiota.Builder.Tests/Writers/Rust/CodePropertyWriterTests.cs @@ -0,0 +1,126 @@ +using System; +using System.IO; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Writers; + +using Xunit; + +namespace Kiota.Builder.Tests.Writers.Rust; + +public sealed class CodePropertyWriterTests : IDisposable +{ + private const string DefaultPath = "./"; + private const string DefaultName = "name"; + private readonly StringWriter tw; + private readonly LanguageWriter writer; + private readonly CodeProperty property; + private readonly CodeClass parentClass; + private const string PropertyName = "propertyName"; + private const string TypeName = "Somecustomtype"; + public CodePropertyWriterTests() + { + writer = LanguageWriter.GetLanguageWriter(GenerationLanguage.Rust, DefaultPath, DefaultName); + tw = new StringWriter(); + writer.SetTextWriter(tw); + var root = CodeNamespace.InitRootNamespace(); + parentClass = new CodeClass + { + Name = "parentClass" + }; + root.AddClass(parentClass); + property = new CodeProperty + { + Name = PropertyName, + Type = new CodeType + { + Name = TypeName + } + }; + parentClass.AddProperty(property, new() + { + Name = "pathParameters", + Kind = CodePropertyKind.PathParameters, + Type = new CodeType + { + Name = "PathParameters", + }, + }, new() + { + Name = "requestAdapter", + Kind = CodePropertyKind.RequestAdapter, + Type = new CodeType + { + Name = "RequestAdapter", + }, + }); + } + public void Dispose() + { + tw?.Dispose(); + GC.SuppressFinalize(this); + } + [Fact] + public void WritesCustomProperty() + { + property.Kind = CodePropertyKind.Custom; + writer.Write(property); + var result = tw.ToString(); + Assert.Contains("pub", result); + Assert.Contains("property_name", result); + } + [Fact] + public void WritesSerdeRenameAttribute() + { + property.Kind = CodePropertyKind.Custom; + property.SerializationName = "PropertyName"; + writer.Write(property); + var result = tw.ToString(); + Assert.Contains("#[serde(rename = \"PropertyName\")]", result); + } + [Fact] + public void WritesRequestBuilderAsEmpty() + { + // In Rust, request builder properties are accessed via methods, not struct fields + property.Kind = CodePropertyKind.RequestBuilder; + writer.Write(property); + var result = tw.ToString(); + // No field is written for request builder properties + Assert.DoesNotContain("pub", result); + } + [Fact] + public void DoesntWritePropertiesExistingInParentType() + { + parentClass.AddProperty(new CodeProperty + { + Name = "definedInParent", + Type = new CodeType + { + Name = "string" + }, + Kind = CodePropertyKind.Custom, + }); + var subClass = (parentClass.Parent as CodeNamespace)!.AddClass(new CodeClass + { + Name = "BaseClass", + }).First(); + subClass.StartBlock.Inherits = new CodeType + { + Name = "BaseClass", + TypeDefinition = parentClass + }; + var propertyToWrite = subClass.AddProperty(new CodeProperty + { + Name = "definedInParent", + Type = new CodeType + { + Name = "string" + }, + Kind = CodePropertyKind.Custom, + }).First(); + writer.Write(propertyToWrite); + var result = tw.ToString(); + Assert.Empty(result); + } +} diff --git a/tests/Kiota.Builder.Tests/Writers/Rust/RustWriterTests.cs b/tests/Kiota.Builder.Tests/Writers/Rust/RustWriterTests.cs new file mode 100644 index 0000000000..8dc5e21d24 --- /dev/null +++ b/tests/Kiota.Builder.Tests/Writers/Rust/RustWriterTests.cs @@ -0,0 +1,19 @@ +using System; + +using Kiota.Builder.Writers.Rust; + +using Xunit; + +namespace Kiota.Builder.Tests.Writers.Rust; + +public class RustWriterTests +{ + [Fact] + public void Instantiates() + { + var writer = new RustWriter("./", "graph"); + Assert.NotNull(writer.PathSegmenter); + Assert.Throws(() => new RustWriter(null, "graph")); + Assert.Throws(() => new RustWriter("./", null)); + } +}