diff --git a/src/Kiota.Builder/Configuration/GenerationConfiguration.cs b/src/Kiota.Builder/Configuration/GenerationConfiguration.cs index e82f1079d2..4ffc1b981a 100644 --- a/src/Kiota.Builder/Configuration/GenerationConfiguration.cs +++ b/src/Kiota.Builder/Configuration/GenerationConfiguration.cs @@ -107,6 +107,7 @@ public bool ShouldWriteBarrelsIfClassExists } private static readonly HashSet BarreledLanguages = [ GenerationLanguage.Ruby, + GenerationLanguage.Rust, ]; private static readonly HashSet BarreledLanguagesWithConstantFileName = []; public bool CleanOutput diff --git a/src/Kiota.Builder/GenerationLanguage.cs b/src/Kiota.Builder/GenerationLanguage.cs index 7f2696e7e3..97eb665571 100644 --- a/src/Kiota.Builder/GenerationLanguage.cs +++ b/src/Kiota.Builder/GenerationLanguage.cs @@ -10,5 +10,6 @@ public enum GenerationLanguage Go, Ruby, Dart, - HTTP + HTTP, + Rust } diff --git a/src/Kiota.Builder/PathSegmenters/RustPathSegmenter.cs b/src/Kiota.Builder/PathSegmenters/RustPathSegmenter.cs new file mode 100644 index 0000000000..60d256802a --- /dev/null +++ b/src/Kiota.Builder/PathSegmenters/RustPathSegmenter.cs @@ -0,0 +1,43 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.PathSegmenters; + +public class RustPathSegmenter(string rootPath, string clientNamespaceName) : CommonPathSegmenter(rootPath, clientNamespaceName) +{ + public override string FileSuffix => ".rs"; + public override IEnumerable GetAdditionalSegment(CodeElement currentElement, string fileName) + { + if (currentElement is CodeNamespace ns && IsRootNamespace(ns)) + return Enumerable.Empty(); // lib.rs at output root, no subdirectory + + return currentElement switch + { + CodeNamespace => new[] { GetLastFileNameSegment(currentElement) }, + _ => Enumerable.Empty(), + }; + } + + public override string NormalizeFileName(CodeElement currentElement) + { + if (currentElement is CodeNamespace ns && IsRootNamespace(ns)) + return "lib"; // root namespace becomes lib.rs + + return currentElement switch + { + CodeNamespace => "mod", + _ => GetLastFileNameSegment(currentElement).ToSnakeCase(), + }; + } + + public override string NormalizeNamespaceSegment(string segmentName) => segmentName?.ToSnakeCase() ?? string.Empty; + + private bool IsRootNamespace(CodeNamespace ns) + { + return ns.Name.Equals(ClientNamespaceName, StringComparison.OrdinalIgnoreCase); + } +} diff --git a/src/Kiota.Builder/Refiners/ILanguageRefiner.cs b/src/Kiota.Builder/Refiners/ILanguageRefiner.cs index 096d7a52f4..5c7fe54b8a 100644 --- a/src/Kiota.Builder/Refiners/ILanguageRefiner.cs +++ b/src/Kiota.Builder/Refiners/ILanguageRefiner.cs @@ -41,6 +41,9 @@ public static async Task RefineAsync(GenerationConfiguration config, CodeNamespa case GenerationLanguage.Dart: await new DartRefiner(config).RefineAsync(generatedCode, cancellationToken).ConfigureAwait(false); break; + case GenerationLanguage.Rust: + await new RustRefiner(config).RefineAsync(generatedCode, cancellationToken).ConfigureAwait(false); + break; } } } diff --git a/src/Kiota.Builder/Refiners/RustExceptionsReservedNamesProvider.cs b/src/Kiota.Builder/Refiners/RustExceptionsReservedNamesProvider.cs new file mode 100644 index 0000000000..4d8e733c16 --- /dev/null +++ b/src/Kiota.Builder/Refiners/RustExceptionsReservedNamesProvider.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; + +namespace Kiota.Builder.Refiners; + +public class RustExceptionsReservedNamesProvider : IReservedNamesProvider +{ + private readonly Lazy> _reservedNames = new(static () => new(StringComparer.OrdinalIgnoreCase) + { + "error", + "source", + "description", + }); + public HashSet ReservedNames => _reservedNames.Value; +} diff --git a/src/Kiota.Builder/Refiners/RustRefiner.cs b/src/Kiota.Builder/Refiners/RustRefiner.cs new file mode 100644 index 0000000000..e984b960d3 --- /dev/null +++ b/src/Kiota.Builder/Refiners/RustRefiner.cs @@ -0,0 +1,271 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Configuration; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.Refiners; + +public class RustRefiner : CommonLanguageRefiner, ILanguageRefiner +{ + private const string AbstractionsNamespaceName = "kiota_abstractions"; + private const string SerializationNamespaceName = "kiota_abstractions::serialization"; + + public RustRefiner(GenerationConfiguration configuration) : base(configuration) { } + + public override Task RefineAsync(CodeNamespace generatedCode, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + _configuration.NamespaceNameSeparator = "::"; + return Task.Run(() => + { + cancellationToken.ThrowIfCancellationRequested(); + + DeduplicateErrorMappings(generatedCode); + MoveRequestBuilderPropertiesToBaseType(generatedCode, + new CodeUsing + { + Name = "BaseRequestBuilder", + Declaration = new CodeType + { + Name = AbstractionsNamespaceName, + IsExternal = true + } + }, + accessModifier: AccessModifier.Public); + ReplaceIndexersByMethodsWithParameter( + generatedCode, + false, + static x => $"by_{x.ToSnakeCase()}", + static x => x.ToSnakeCase(), + GenerationLanguage.Rust); + cancellationToken.ThrowIfCancellationRequested(); + + AddInnerClasses(generatedCode, true, string.Empty, false); + cancellationToken.ThrowIfCancellationRequested(); + + RemoveRequestConfigurationClasses(generatedCode, + new CodeUsing + { + Name = "RequestConfiguration", + Declaration = new CodeType { Name = AbstractionsNamespaceName, IsExternal = true } + }, + new CodeType { Name = "DefaultQueryParameters", IsExternal = true }); + RemoveCancellationParameter(generatedCode); + cancellationToken.ThrowIfCancellationRequested(); + + ConvertUnionTypesToWrapper( + generatedCode, + _configuration.UsesBackingStore, + static s => s.ToSnakeCase(), + true, + string.Empty, + string.Empty, + "is_composed_type" + ); + PromoteComposedTypesToNamespace(generatedCode); + cancellationToken.ThrowIfCancellationRequested(); + + ReplaceReservedNames( + generatedCode, + new RustReservedNamesProvider(), + x => $"r#{x}", + shouldReplaceCallback: static x => x is not CodeEnumOption && x is not CodeEnum); + ReplaceReservedExceptionPropertyNames( + generatedCode, + new RustExceptionsReservedNamesProvider(), + static x => $"{x}_prop"); + + AddPropertiesAndMethodTypesImports(generatedCode, true, false, true); + AddDefaultImports(generatedCode, defaultUsingEvaluators); + cancellationToken.ThrowIfCancellationRequested(); + + CorrectCoreType(generatedCode, CorrectMethodType, CorrectPropertyType, CorrectImplements); + DisableActionOf(generatedCode, CodeParameterKind.RequestConfiguration); + + AddGetterAndSetterMethods(generatedCode, + new() + { + CodePropertyKind.Custom, + CodePropertyKind.AdditionalData, + CodePropertyKind.BackingStore, + }, + static (_, s) => s.ToSnakeCase(), + _configuration.UsesBackingStore, + false, "get_", "set_"); + AddConstructorsForDefaultValues(generatedCode, true, true); + MakeModelPropertiesNullable(generatedCode); + cancellationToken.ThrowIfCancellationRequested(); + + var defaultConfiguration = new GenerationConfiguration(); + ReplaceDefaultSerializationModules(generatedCode, defaultConfiguration.Serializers, + new(StringComparer.OrdinalIgnoreCase) + { + "kiota_serialization_json::JsonSerializationWriterFactory", + "kiota_serialization_text::TextSerializationWriterFactory", + "kiota_serialization_form::FormSerializationWriterFactory", + }); + ReplaceDefaultDeserializationModules(generatedCode, defaultConfiguration.Deserializers, + new(StringComparer.OrdinalIgnoreCase) + { + "kiota_serialization_json::JsonParseNodeFactory", + "kiota_serialization_text::TextParseNodeFactory", + "kiota_serialization_form::FormParseNodeFactory", + }); + AddParentClassToErrorClasses(generatedCode, "ApiError", AbstractionsNamespaceName); + AddDiscriminatorMappingsUsingsToParentClasses(generatedCode, "ParseNode", true); + AddParsableImplementsForModelClasses(generatedCode, "Parsable"); + cancellationToken.ThrowIfCancellationRequested(); + + ReplacePropertyNames(generatedCode, + new() + { + CodePropertyKind.Custom, + CodePropertyKind.QueryParameter, + }, + static s => s.ToSnakeCase()); + AddPrimaryErrorMessage(generatedCode, "error_message", + () => new CodeType { Name = "String", IsNullable = false, IsExternal = true }); + NormalizeEnumNames(generatedCode); + }, cancellationToken); + } + + /// Promotes model-kind classes (composed type wrappers) out of request builder classes + /// and into the request builder's parent namespace so they become separate .rs files. + private static void PromoteComposedTypesToNamespace(CodeElement currentElement) + { + if (currentElement is CodeClass parentClass && parentClass.IsOfKind(CodeClassKind.RequestBuilder)) + { + var parentNamespace = parentClass.GetImmediateParentOfType(); + if (parentNamespace != null) + { + var toPromote = parentClass.InnerClasses + .Where(static c => !c.IsOfKind(CodeClassKind.QueryParameters, CodeClassKind.RequestConfiguration, CodeClassKind.ParameterSet)) + .ToList(); + foreach (var inner in toPromote) + { + parentClass.RemoveChildElementByName(inner.Name); + inner.Parent = parentNamespace; + parentNamespace.AddClass(inner); + } + } + } + CrawlTree(currentElement, PromoteComposedTypesToNamespace); + } + + /// Normalize enum names: "Order_status" -> "OrderStatus" + private static void NormalizeEnumNames(CodeElement currentElement) + { + if (currentElement is CodeEnum codeEnum) + { + var newName = string.Join("", codeEnum.Name.Split('_').Select(static s => s.ToFirstCharacterUpperCase())); + if (!newName.Equals(codeEnum.Name, StringComparison.Ordinal)) + codeEnum.Name = newName; + } + CrawlTree(currentElement, NormalizeEnumNames); + } + + private static readonly AdditionalUsingEvaluator[] defaultUsingEvaluators = + [ + new(static x => x is CodeProperty prop && prop.IsOfKind(CodePropertyKind.RequestAdapter), + AbstractionsNamespaceName, "RequestAdapter"), + new(static x => x is CodeMethod method && method.IsOfKind(CodeMethodKind.RequestGenerator), + AbstractionsNamespaceName, "RequestInformation", "HttpMethod", "RequestOption"), + new(static x => x is CodeMethod method && method.IsOfKind(CodeMethodKind.Serializer), + SerializationNamespaceName, "SerializationWriter"), + new(static x => x is CodeMethod method && method.IsOfKind(CodeMethodKind.Deserializer, CodeMethodKind.Factory), + SerializationNamespaceName, "ParseNode", "Parsable"), + new(static x => x is CodeClass cls && cls.IsOfKind(CodeClassKind.Model), + SerializationNamespaceName, "Parsable"), + new(static x => x is CodeClass cls && cls.IsOfKind(CodeClassKind.Model) && + cls.Properties.Any(static p => p.IsOfKind(CodePropertyKind.AdditionalData)), + SerializationNamespaceName, "AdditionalDataHolder"), + new(static x => x is CodeProperty prop && prop.IsOfKind(CodePropertyKind.Headers), + AbstractionsNamespaceName, "RequestHeaders"), + ]; + + private static void CorrectMethodType(CodeMethod currentMethod) + { + if (currentMethod.IsOfKind(CodeMethodKind.Serializer)) + { + currentMethod.Parameters + .Where(static x => x.Type.Name.StartsWith('I')) + .ToList() + .ForEach(static x => x.Type.Name = x.Type.Name[1..]); + } + else if (currentMethod.IsOfKind(CodeMethodKind.Deserializer)) + { + currentMethod.ReturnType.Name = "FieldDeserializers"; + currentMethod.Name = "get_field_deserializers"; + } + else if (currentMethod.IsOfKind(CodeMethodKind.Factory)) + { + currentMethod.Parameters + .Where(static x => x.IsOfKind(CodeParameterKind.ParseNode) && x.Type.Name.StartsWith('I')) + .ToList() + .ForEach(static x => x.Type.Name = x.Type.Name[1..]); + } + else if (currentMethod.IsOfKind(CodeMethodKind.ClientConstructor, CodeMethodKind.Constructor, CodeMethodKind.RawUrlConstructor)) + { + currentMethod.Parameters + .Where(static x => x.IsOfKind(CodeParameterKind.RequestAdapter) && x.Type.Name.StartsWith('I')) + .ToList() + .ForEach(static x => x.Type.Name = x.Type.Name[1..]); + + if (currentMethod.Parameters.OfKind(CodeParameterKind.PathParameters) is CodeParameter pathsParam) + { + pathsParam.Type.Name = "HashMap"; + pathsParam.Type.IsNullable = true; + } + } + + currentMethod.Parameters + .ToList() + .ForEach(static x => x.Name = x.Name.ToFirstCharacterLowerCase()); + } + + private static void CorrectPropertyType(CodeProperty currentProperty) + { + if (currentProperty.IsOfKind(CodePropertyKind.RequestAdapter)) + { + if (currentProperty.Type.Name.StartsWith('I')) + currentProperty.Type.Name = currentProperty.Type.Name[1..]; + } + else if (currentProperty.IsOfKind(CodePropertyKind.AdditionalData)) + { + currentProperty.Type.Name = "HashMap"; + currentProperty.DefaultValue = "HashMap::new()"; + } + else if (currentProperty.IsOfKind(CodePropertyKind.PathParameters)) + { + currentProperty.Type.IsNullable = true; + currentProperty.Type.Name = "HashMap"; + if (!string.IsNullOrEmpty(currentProperty.DefaultValue)) + currentProperty.DefaultValue = "HashMap::new()"; + } + else if (currentProperty.IsOfKind(CodePropertyKind.Headers)) + { + currentProperty.DefaultValue = "RequestHeaders::new()"; + } + else if (currentProperty.IsOfKind(CodePropertyKind.Options)) + { + currentProperty.Type.IsNullable = false; + currentProperty.Type.Name = "Vec>"; + } + else if (currentProperty.IsOfKind(CodePropertyKind.BackingStore)) + { + if (currentProperty.Type.Name.StartsWith('I')) + currentProperty.Type.Name = currentProperty.Type.Name[1..]; + } + } + + private void CorrectImplements(ProprietableBlockDeclaration block) + { + block.ReplaceImplementByName(KiotaBuilder.AdditionalHolderInterface, "AdditionalDataHolder"); + block.ReplaceImplementByName(KiotaBuilder.BackedModelInterface, "BackedModel"); + } +} diff --git a/src/Kiota.Builder/Refiners/RustReservedNamesProvider.cs b/src/Kiota.Builder/Refiners/RustReservedNamesProvider.cs new file mode 100644 index 0000000000..69dedf0caa --- /dev/null +++ b/src/Kiota.Builder/Refiners/RustReservedNamesProvider.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; + +namespace Kiota.Builder.Refiners; + +public class RustReservedNamesProvider : IReservedNamesProvider +{ + private readonly Lazy> _reservedNames = new(static () => new(StringComparer.OrdinalIgnoreCase) { + // Strict keywords + "as", "break", "const", "continue", "crate", "else", "enum", "extern", + "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", + "mod", "move", "mut", "pub", "ref", "return", "self", "Self", + "static", "struct", "super", "trait", "true", "type", "unsafe", + "use", "where", "while", + // Async/await + "async", "await", "dyn", + // Reserved for future use + "abstract", "become", "box", "do", "final", "macro", "override", + "priv", "try", "typeof", "unsized", "virtual", "yield", + }); + public HashSet ReservedNames => _reservedNames.Value; +} diff --git a/src/Kiota.Builder/Writers/LanguageWriter.cs b/src/Kiota.Builder/Writers/LanguageWriter.cs index 22313267dd..18917d1e2d 100644 --- a/src/Kiota.Builder/Writers/LanguageWriter.cs +++ b/src/Kiota.Builder/Writers/LanguageWriter.cs @@ -13,6 +13,7 @@ using Kiota.Builder.Writers.Php; using Kiota.Builder.Writers.Python; using Kiota.Builder.Writers.Ruby; +using Kiota.Builder.Writers.Rust; using Kiota.Builder.Writers.TypeScript; namespace Kiota.Builder.Writers; @@ -191,6 +192,7 @@ public static LanguageWriter GetLanguageWriter(GenerationLanguage language, stri GenerationLanguage.Go => new GoWriter(outputPath, clientNamespaceName, excludeBackwardCompatible), GenerationLanguage.Dart => new DartWriter(outputPath, clientNamespaceName), GenerationLanguage.HTTP => new HttpWriter(outputPath, clientNamespaceName), + GenerationLanguage.Rust => new RustWriter(outputPath, clientNamespaceName), _ => throw new InvalidEnumArgumentException($"{language} language currently not supported."), }; } diff --git a/src/Kiota.Builder/Writers/Rust/CodeBlockEndWriter.cs b/src/Kiota.Builder/Writers/Rust/CodeBlockEndWriter.cs new file mode 100644 index 0000000000..8ee8a632f5 --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodeBlockEndWriter.cs @@ -0,0 +1,22 @@ +using System.Linq; +using Kiota.Builder.CodeDOM; + +namespace Kiota.Builder.Writers.Rust; + +public class CodeBlockEndWriter : ICodeElementWriter +{ + public void WriteCodeElement(BlockEnd codeElement, LanguageWriter writer) + { + if (codeElement?.Parent is CodeNamespace) return; + if (codeElement?.Parent is CodeEnum) return; + if (codeElement?.Parent is CodeClass cls) + { + // skip nested classes (query params, config their parent impl is closed by the nested writer) + if (cls.Parent is CodeClass) return; + // if this class has query param children, the impl was already closed by the query param writer + if (cls.InnerClasses.Any(static c => c.IsOfKind(CodeClassKind.QueryParameters))) + return; + } + writer?.CloseBlock(); + } +} diff --git a/src/Kiota.Builder/Writers/Rust/CodeClassDeclarationWriter.cs b/src/Kiota.Builder/Writers/Rust/CodeClassDeclarationWriter.cs new file mode 100644 index 0000000000..d1d3b719a9 --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodeClassDeclarationWriter.cs @@ -0,0 +1,499 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.Writers.Rust; + +public class CodeClassDeclarationWriter(RustConventionService conventionService) : BaseElementWriter(conventionService) +{ + public override void WriteCodeElement(ClassDeclaration codeElement, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(codeElement); + ArgumentNullException.ThrowIfNull(writer); + + if (codeElement.Parent is not CodeClass parentClass) + throw new InvalidOperationException("ClassDeclaration parent is not a CodeClass"); + + // nested classes: config/query params remain after composed types were promoted + if (parentClass.Parent is CodeClass) + { + if (parentClass.IsOfKind(CodeClassKind.QueryParameters)) + { + WriteQueryParametersClass(parentClass, writer); + return; + } + // skip config classes and others + return; + } + + var className = parentClass.Name.ToFirstCharacterUpperCase(); + var isModel = parentClass.IsOfKind(CodeClassKind.Model); + var isRequestBuilder = parentClass.IsOfKind(CodeClassKind.RequestBuilder); + var hasBase = isRequestBuilder && codeElement.Inherits?.AllTypes?.Any() == true; + + // figure out the root client namespace so we can strip it from import paths + if (_clientNamespace == null) + { + var ns = parentClass.Parent as CodeNamespace; + while (ns?.Parent is CodeNamespace parent && !string.IsNullOrEmpty(parent.Name)) + ns = parent; + _clientNamespace = ns?.Name ?? string.Empty; + } + + // file header + imports + writer.WriteLine("// Code generated by Microsoft Kiota - DO NOT EDIT."); + writer.WriteLine(); + WriteImports(parentClass, isModel, isRequestBuilder, writer); + + // struct definition + conventions.WriteShortDescription(parentClass, writer); + writer.WriteLine("#[derive(Debug, Clone, Default, PartialEq)]"); + writer.WriteLine($"pub struct {className} {{"); + writer.IncreaseIndent(); + + if (hasBase) + writer.WriteLine("pub base: BaseRequestBuilder,"); + + foreach (var prop in parentClass.Properties + .Where(p => p.Kind is not (CodePropertyKind.RequestBuilder or CodePropertyKind.AdditionalData)) + .Where(p => !hasBase || !IsBaseProperty(p)) + .OrderBy(static p => p.Name, StringComparer.OrdinalIgnoreCase)) + { + var propType = conventions.GetTypeString(prop.Type, prop); + var propName = prop.Name.ToSnakeCase(); + conventions.WriteShortDescription(prop, writer); + writer.WriteLine($"pub {propName}: {propType},"); + } + + // additional data uses a fixed type + if (parentClass.GetPropertyOfKind(CodePropertyKind.AdditionalData) is not null) + writer.WriteLine("pub additional_data: HashMap,"); + + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.WriteLine(); + + // Parsable trait impl for model classes + if (isModel) + WriteParsableImpl(parentClass, className, writer); + + // regular impl block for methods + writer.WriteLine($"impl {className} {{"); + writer.IncreaseIndent(); + } + + private void WriteParsableImpl(CodeClass parentClass, string className, LanguageWriter writer) + { + var customProps = parentClass.GetPropertiesOfKind(CodePropertyKind.Custom) + .OrderBy(static p => p.Name, StringComparer.OrdinalIgnoreCase) + .ToList(); + + writer.WriteLine($"impl Parsable for {className} {{"); + writer.IncreaseIndent(); + + // field_names + writer.WriteLine("fn field_names(&self) -> Vec<&'static str> {"); + writer.IncreaseIndent(); + var names = string.Join(", ", customProps.Select(p => + { + var wn = p.WireName; + return $"\"{(string.IsNullOrEmpty(wn) ? p.Name : wn)}\""; + })); + writer.WriteLine($"vec![{names}]"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.WriteLine(); + + // assign_field + writer.WriteLine("fn assign_field(&mut self, field: &str, node: &dyn ParseNode) -> Result<(), KiotaError> {"); + writer.IncreaseIndent(); + writer.WriteLine("match field {"); + writer.IncreaseIndent(); + foreach (var prop in customProps) + { + var wireName = prop.WireName; + if (string.IsNullOrEmpty(wireName)) + wireName = prop.Name; + var propName = prop.Name.ToSnakeCase(); + + var isCollection = prop.Type.CollectionKind != CodeTypeBase.CodeTypeCollectionKind.None; + + if (isCollection) + { + var innerType = prop.Type is CodeType collType ? collType : null; + var innerTypeDef = innerType?.TypeDefinition; + if (innerTypeDef is CodeClass innerClass) + { + // collection of objects + var innerTypeName = innerClass.Name.ToFirstCharacterUpperCase(); + writer.WriteLine($"\"{wireName}\" => {{"); + writer.IncreaseIndent(); + writer.WriteLine($"let mut items = Vec::new();"); + writer.WriteLine($"for child in node.get_child_nodes()? {{"); + writer.IncreaseIndent(); + writer.WriteLine($"let mut obj = {innerTypeName}::default();"); + writer.WriteLine($"for f in obj.field_names().to_vec() {{"); + writer.IncreaseIndent(); + writer.WriteLine($"if let Ok(Some(n)) = child.get_child_node(f) {{ obj.assign_field(f, n.as_ref())?; }}"); + writer.DecreaseIndent(); + writer.WriteLine($"}}"); + writer.WriteLine($"items.push(obj);"); + writer.DecreaseIndent(); + writer.WriteLine($"}}"); + writer.WriteLine($"self.{propName} = items;"); + writer.DecreaseIndent(); + writer.WriteLine($"}}"); + } + else + { + // collection of primitives (strings, ints, etc.) + writer.WriteLine($"\"{wireName}\" => self.{propName} = node.get_collection_of_string_values()?,"); + } + } + else if (prop.Type is CodeType ct && ct.TypeDefinition is CodeEnum enumDef) + { + var enumType = enumDef.Name.ToFirstCharacterUpperCase(); + writer.WriteLine($"\"{wireName}\" => self.{propName} = node.get_string_value()?.and_then(|s| {enumType}::parse(&s)),"); + } + else if (prop.Type is CodeType ct2 && ct2.TypeDefinition is CodeClass objClass) + { + // nested object + var objTypeName = objClass.Name.ToFirstCharacterUpperCase(); + writer.WriteLine($"\"{wireName}\" => {{"); + writer.IncreaseIndent(); + writer.WriteLine($"let mut obj = {objTypeName}::default();"); + writer.WriteLine($"for f in obj.field_names().to_vec() {{"); + writer.IncreaseIndent(); + writer.WriteLine($"if let Ok(Some(n)) = node.get_child_node(f) {{ obj.assign_field(f, n.as_ref())?; }}"); + writer.DecreaseIndent(); + writer.WriteLine($"}}"); + writer.WriteLine($"self.{propName} = Some(obj);"); + writer.DecreaseIndent(); + writer.WriteLine($"}}"); + } + else + { + var rawType = prop.Type is CodeType ct3 ? conventions.TranslateType(ct3) : "String"; + var readMethod = GetReadMethodForType(rawType); + writer.WriteLine($"\"{wireName}\" => self.{propName} = node.{readMethod}()?,"); + } + } + writer.WriteLine("_ => {}"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.WriteLine("Ok(())"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.WriteLine(); + + // serialize + writer.WriteLine("fn serialize(&self, writer: &mut dyn SerializationWriter) -> Result<(), KiotaError> {"); + writer.IncreaseIndent(); + foreach (var prop in customProps) + { + var wireName = prop.WireName; + if (string.IsNullOrEmpty(wireName)) + wireName = prop.Name; + var propName = prop.Name.ToSnakeCase(); + var rawType = prop.Type is CodeType ct ? conventions.TranslateType(ct) : "String"; + var isEnum = prop.Type is CodeType ct2 && ct2.TypeDefinition is CodeEnum; + var writeMethod = GetWriteMethodForType(rawType); + var needsRef = rawType == "String" || rawType.StartsWith("chrono::", StringComparison.Ordinal) + || rawType.StartsWith("uuid::", StringComparison.Ordinal) || rawType == "IsoDuration"; + + var isCollection = prop.Type.CollectionKind != CodeTypeBase.CodeTypeCollectionKind.None; + + if (isCollection) + { + // collection fields are never Option, they default to empty vec + if (RustConventionService.IsPrimitiveType(rawType)) + writer.WriteLine($"writer.write_collection_of_string_values(Some(\"{wireName}\"), &self.{propName}.iter().map(|v| v.to_string()).collect::>())?;"); + else + writer.WriteLine($"// TODO: serialize collection of objects for {propName}"); + } + else if (isEnum) + { + // enums implement Display, serialize as string + if (prop.Type.IsNullable) + writer.WriteLine($"if let Some(ref val) = self.{propName} {{ writer.write_string_value(Some(\"{wireName}\"), &val.to_string())?; }}"); + else + writer.WriteLine($"writer.write_string_value(Some(\"{wireName}\"), &self.{propName}.to_string())?;"); + } + else if (prop.Type.IsNullable && RustConventionService.IsPrimitiveType(rawType)) + { + if (needsRef) + writer.WriteLine($"if let Some(ref val) = self.{propName} {{ writer.{writeMethod}(Some(\"{wireName}\"), val)?; }}"); + else + writer.WriteLine($"if let Some(val) = self.{propName} {{ writer.{writeMethod}(Some(\"{wireName}\"), val)?; }}"); + } + else if (RustConventionService.IsPrimitiveType(rawType)) + { + writer.WriteLine($"writer.{writeMethod}(Some(\"{wireName}\"), self.{propName}.clone())?;"); + } + else if (prop.Type.IsNullable) + { + writer.WriteLine($"if let Some(ref val) = self.{propName} {{ writer.write_object_value(Some(\"{wireName}\"), val as &dyn Parsable, &[])?; }}"); + } + else + { + writer.WriteLine($"writer.write_object_value(Some(\"{wireName}\"), &self.{propName} as &dyn Parsable, &[])?;"); + } + } + + if (parentClass.GetPropertyOfKind(CodePropertyKind.AdditionalData) is not null) + writer.WriteLine("writer.write_additional_data(&self.additional_data)?;"); + + writer.WriteLine("Ok(())"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.WriteLine(); + + // as_any for downcasting + writer.WriteLine("fn as_any(self: Box) -> Box {"); + writer.IncreaseIndent(); + writer.WriteLine("self"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.WriteLine(); + } + + private static string GetReadMethodForType(string typeName) + { + return typeName switch + { + "String" => "get_string_value", + "bool" => "get_bool_value", + "i32" => "get_i32_value", + "i64" => "get_i64_value", + "f32" => "get_f32_value", + "f64" => "get_f64_value", + "u8" => "get_u8_value", + "i8" => "get_i8_value", + "uuid::Uuid" => "get_uuid_value", + "Vec" => "get_byte_array_value", + "chrono::DateTime" => "get_date_time_value", + "chrono::NaiveDate" => "get_date_only_value", + "chrono::NaiveTime" => "get_time_only_value", + "IsoDuration" => "get_duration_value", + _ => "get_string_value", + }; + } + + private static string GetWriteMethodForType(string typeName) + { + return typeName switch + { + "String" => "write_string_value", + "bool" => "write_bool_value", + "i32" => "write_i32_value", + "i64" => "write_i64_value", + "f32" => "write_f32_value", + "f64" => "write_f64_value", + "u8" => "write_u8_value", + "i8" => "write_i8_value", + "uuid::Uuid" => "write_uuid_value", + "Vec" => "write_byte_array_value", + "chrono::DateTime" => "write_date_time_value", + "chrono::NaiveDate" => "write_date_only_value", + "chrono::NaiveTime" => "write_time_only_value", + "IsoDuration" => "write_duration_value", + _ => "write_string_value", + }; + } + + private static void WriteImports(CodeClass parentClass, bool isModel, bool isRequestBuilder, LanguageWriter writer) + { + var hasAdditionalData = parentClass.GetPropertyOfKind(CodePropertyKind.AdditionalData) is not null; + var hasExecutors = parentClass.Methods.Any(m => m.IsOfKind(CodeMethodKind.RequestExecutor)); + var hasConstructors = parentClass.Methods.Any(m => m.IsOfKind( + CodeMethodKind.Constructor, CodeMethodKind.ClientConstructor, CodeMethodKind.RawUrlConstructor)); + var hasGenerators = parentClass.Methods.Any(m => m.IsOfKind(CodeMethodKind.RequestGenerator)); + + if (isModel && hasAdditionalData) + writer.WriteLine("use std::collections::HashMap;"); + if (isRequestBuilder) + { + var abstractions = new List { "BaseRequestBuilder" }; + if (hasConstructors) + abstractions.Add("RequestAdapter"); + if (hasExecutors) + { + abstractions.Add("KiotaError"); + abstractions.Add("Parsable"); + abstractions.Add("ParseNode"); + } + if (hasGenerators) + { + abstractions.Add("DefaultQueryParameters"); + abstractions.Add("HttpMethod"); + abstractions.Add("KiotaError"); + abstractions.Add("RequestConfiguration"); + abstractions.Add("RequestInformation"); + } + // check if any method param is MultipartBody + if (parentClass.Methods.Any(m => m.Parameters.Any(p => + p.Type.Name.Equals("MultipartBody", StringComparison.OrdinalIgnoreCase)))) + abstractions.Add("MultipartBody"); + abstractions = abstractions.Distinct().OrderBy(static x => x, StringComparer.Ordinal).ToList(); + writer.WriteLine($"use kiota_abstractions::{{{string.Join(", ", abstractions)}}};"); + } + if (isModel) + { + writer.WriteLine("use kiota_abstractions::{KiotaError, Parsable, ParseNode, SerializationWriter};"); + } + + // external crate imports based on field types + var allTypes = parentClass.Properties + .Where(static p => p.Type is CodeType) + .Select(static p => ((CodeType)p.Type).Name.ToLowerInvariant()) + .ToHashSet(StringComparer.OrdinalIgnoreCase); + if (allTypes.Any(t => t.Contains("datetime", StringComparison.OrdinalIgnoreCase) || t.Contains("dateonly", StringComparison.OrdinalIgnoreCase) || t.Contains("timeonly", StringComparison.OrdinalIgnoreCase))) + writer.WriteLine("use chrono;"); + if (allTypes.Contains("guid") || allTypes.Contains("uuid")) + writer.WriteLine("use uuid;"); + + // collect external type references from methods/properties, excluding self + var referencedTypes = new HashSet(StringComparer.Ordinal); + var selfName = parentClass.Name; + foreach (var method in parentClass.Methods) + { + CollectTypeRef(method.ReturnType, referencedTypes, selfName); + foreach (var param in method.Parameters) + CollectTypeRef(param.Type, referencedTypes, selfName); + } + foreach (var prop in parentClass.Properties) + CollectTypeRef(prop.Type, referencedTypes, selfName); + + foreach (var import in referencedTypes.OrderBy(static x => x, StringComparer.Ordinal)) + writer.WriteLine($"use crate::{import};"); + + writer.WriteLine(); + } + + private static readonly HashSet BuiltinTypes = new(StringComparer.OrdinalIgnoreCase) + { + "String", "bool", "i8", "u8", "i32", "i64", "f32", "f64", + "void", "Vec", "Option", "HashMap", "serde_json", + "BaseRequestBuilder", "RequestAdapter", "RequestInformation", + "RequestConfiguration", "DefaultQueryParameters", "HttpMethod", + "KiotaError", "Parsable", "ParseNode", "SerializationWriter", + "MultipartBody", + }; + + private static string? _clientNamespace; + + private static void CollectTypeRef(CodeTypeBase typeBase, HashSet refs, string selfName) + { + if (typeBase is not CodeType ct) return; + if (BuiltinTypes.Contains(ct.Name)) return; + + // handle both CodeClass and CodeEnum references + CodeElement? refElement = ct.TypeDefinition; + if (refElement is not (CodeClass or CodeEnum)) return; + if (refElement.Name.Equals(selfName, StringComparison.OrdinalIgnoreCase)) return; + + // find the containing namespace — use GetImmediateParentOfType to handle + // both direct namespace children and classes nested inside other classes + var ns = refElement.GetImmediateParentOfType()?.Name ?? string.Empty; + if (_clientNamespace != null && ns.StartsWith(_clientNamespace, StringComparison.Ordinal)) + { + ns = ns.Length > _clientNamespace.Length + ? ns[(_clientNamespace.Length + 1)..] + : string.Empty; + } + + var parts = ns.Split('.', StringSplitOptions.RemoveEmptyEntries) + .Select(s => s.ToSnakeCase()).ToList(); + var typeName = refElement.Name.ToSnakeCase(); + parts.Add(typeName); + parts.Add(refElement.Name.ToFirstCharacterUpperCase()); + var path = string.Join("::", parts); + if (!string.IsNullOrEmpty(path)) + refs.Add(path); + } + + private void WriteQueryParametersClass(CodeClass parentClass, LanguageWriter writer) + { + // close the parent request builder's impl block + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.WriteLine(); + + var qpName = parentClass.Name.ToFirstCharacterUpperCase(); + var props = parentClass.Properties + .OrderBy(static p => p.Name, StringComparer.OrdinalIgnoreCase) + .ToList(); + + // struct + writer.WriteLine($"#[derive(Debug, Clone, Default)]"); + writer.WriteLine($"pub struct {qpName} {{"); + writer.IncreaseIndent(); + foreach (var prop in props) + { + conventions.WriteShortDescription(prop, writer); + var propName = prop.Name.ToSnakeCase(); + // query params are always strings or collections of strings + var isCollection = prop.Type.CollectionKind != CodeTypeBase.CodeTypeCollectionKind.None; + string propType; + if (isCollection) + propType = "Vec"; + else if (prop.Type.IsNullable) + propType = "Option"; + else + propType = "String"; + writer.WriteLine($"pub {propName}: {propType},"); + } + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.WriteLine(); + + // QueryParameters trait impl + writer.WriteLine($"impl kiota_abstractions::QueryParameters for {qpName} {{"); + writer.IncreaseIndent(); + writer.WriteLine("fn to_query_parameters(&self) -> std::collections::HashMap {"); + writer.IncreaseIndent(); + writer.WriteLine("let mut params = std::collections::HashMap::new();"); + foreach (var prop in props) + { + var propName = prop.Name.ToSnakeCase(); + var wireName = !string.IsNullOrEmpty(prop.WireName) ? prop.WireName + : !string.IsNullOrEmpty(prop.SerializationName) ? prop.SerializationName + : prop.Name; + var rawType = prop.Type is CodeType ct ? conventions.TranslateType(ct) : "String"; + + var isCollection = prop.Type.CollectionKind != CodeTypeBase.CodeTypeCollectionKind.None; + if (isCollection) + { + writer.WriteLine($"if !self.{propName}.is_empty() {{ params.insert(\"{wireName}\".to_string(), self.{propName}.join(\",\")); }}"); + } + else if (prop.Type.IsNullable) + { + writer.WriteLine($"if let Some(ref v) = self.{propName} {{ params.insert(\"{wireName}\".to_string(), v.to_string()); }}"); + } + else + { + writer.WriteLine($"params.insert(\"{wireName}\".to_string(), self.{propName}.to_string());"); + } + } + writer.WriteLine("params"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + private static bool IsBaseProperty(CodeProperty prop) + { + return prop.Kind is CodePropertyKind.PathParameters + or CodePropertyKind.RequestAdapter + or CodePropertyKind.UrlTemplate + or CodePropertyKind.Headers + or CodePropertyKind.Options; + } +} diff --git a/src/Kiota.Builder/Writers/Rust/CodeEnumWriter.cs b/src/Kiota.Builder/Writers/Rust/CodeEnumWriter.cs new file mode 100644 index 0000000000..510495c3ec --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodeEnumWriter.cs @@ -0,0 +1,83 @@ +using System; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.Writers.Rust; + +public class CodeEnumWriter(RustConventionService conventionService) : BaseElementWriter(conventionService) +{ + public override void WriteCodeElement(CodeEnum codeElement, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(codeElement); + ArgumentNullException.ThrowIfNull(writer); + + var enumName = codeElement.Name.ToFirstCharacterUpperCase(); + + writer.WriteLine("// Code generated by Microsoft Kiota - DO NOT EDIT."); + writer.WriteLine(); + + conventions.WriteShortDescription(codeElement, writer); + + if (codeElement.Flags) + writer.WriteLine("#[derive(Debug, Clone, PartialEq, Eq, Hash)]"); + else + writer.WriteLine("#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]"); + + writer.WriteLine($"pub enum {enumName} {{"); + writer.IncreaseIndent(); + foreach (var option in codeElement.Options) + { + conventions.WriteShortDescription(option, writer); + writer.WriteLine($"{option.Name.ToFirstCharacterUpperCase()},"); + } + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.WriteLine(); + + // Display impl for serialization + writer.WriteLine($"impl std::fmt::Display for {enumName} {{"); + writer.IncreaseIndent(); + writer.WriteLine("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {"); + writer.IncreaseIndent(); + writer.WriteLine("match self {"); + writer.IncreaseIndent(); + foreach (var option in codeElement.Options) + { + var wireName = option.WireName; + if (string.IsNullOrEmpty(wireName)) + wireName = option.Name; + writer.WriteLine($"Self::{option.Name.ToFirstCharacterUpperCase()} => write!(f, \"{wireName}\"),"); + } + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.WriteLine(); + + // Parse function + writer.WriteLine($"impl {enumName} {{"); + writer.IncreaseIndent(); + writer.WriteLine("pub fn parse(s: &str) -> Option {"); + writer.IncreaseIndent(); + writer.WriteLine("match s {"); + writer.IncreaseIndent(); + foreach (var option in codeElement.Options) + { + var wireName = option.WireName; + if (string.IsNullOrEmpty(wireName)) + wireName = option.Name; + writer.WriteLine($"\"{wireName}\" => Some(Self::{option.Name.ToFirstCharacterUpperCase()}),"); + } + writer.WriteLine("_ => None,"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } +} diff --git a/src/Kiota.Builder/Writers/Rust/CodeFileBlockEndWriter.cs b/src/Kiota.Builder/Writers/Rust/CodeFileBlockEndWriter.cs new file mode 100644 index 0000000000..a9b420dcfe --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodeFileBlockEndWriter.cs @@ -0,0 +1,11 @@ +using Kiota.Builder.CodeDOM; + +namespace Kiota.Builder.Writers.Rust; + +public class CodeFileBlockEndWriter : ICodeElementWriter +{ + public void WriteCodeElement(CodeFileBlockEnd codeElement, LanguageWriter writer) + { + // No file-level closing needed in Rust + } +} diff --git a/src/Kiota.Builder/Writers/Rust/CodeFileDeclarationWriter.cs b/src/Kiota.Builder/Writers/Rust/CodeFileDeclarationWriter.cs new file mode 100644 index 0000000000..2e1eab85b6 --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodeFileDeclarationWriter.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.Writers.Rust; + +public class CodeFileDeclarationWriter(RustConventionService conventionService) : BaseElementWriter(conventionService) +{ + public override void WriteCodeElement(CodeFileDeclaration codeElement, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(codeElement); + ArgumentNullException.ThrowIfNull(writer); + + writer.WriteLine("// Code generated by Microsoft Kiota - DO NOT EDIT."); + writer.WriteLine("// Changes may cause incorrect behavior and will be lost if the code is regenerated."); + writer.WriteLine(); + + var externalUsings = codeElement.Usings + .Where(static x => x.Declaration is { IsExternal: true }) + .Select(static x => x.Declaration!.Name) + .Distinct(StringComparer.OrdinalIgnoreCase) + .OrderBy(static x => x, StringComparer.OrdinalIgnoreCase) + .ToList(); + + foreach (var u in externalUsings) + { + writer.WriteLine($"use {u};"); + } + + var internalUsings = codeElement.Usings + .Where(static x => x.Declaration is { IsExternal: false }) + .Select(static x => x.Name) + .Distinct(StringComparer.OrdinalIgnoreCase) + .OrderBy(static x => x, StringComparer.OrdinalIgnoreCase) + .ToList(); + + if (internalUsings.Count > 0 && externalUsings.Count > 0) + writer.WriteLine(); + + foreach (var u in internalUsings) + { + writer.WriteLine($"use crate::{u.ToSnakeCase()};"); + } + + if (externalUsings.Count > 0 || internalUsings.Count > 0) + writer.WriteLine(); + } +} diff --git a/src/Kiota.Builder/Writers/Rust/CodeMethodWriter.cs b/src/Kiota.Builder/Writers/Rust/CodeMethodWriter.cs new file mode 100644 index 0000000000..19d156a6c5 --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodeMethodWriter.cs @@ -0,0 +1,400 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.Writers.Rust; + +public class CodeMethodWriter(RustConventionService conventionService) : BaseElementWriter(conventionService) +{ + public override void WriteCodeElement(CodeMethod codeElement, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(codeElement); + ArgumentNullException.ThrowIfNull(writer); + + if (codeElement.ReturnType is null) + throw new InvalidOperationException($"Method {codeElement.Name} has no return type"); + + var parentClass = codeElement.Parent as CodeClass; + + // skip methods on nested classes (config/query params — composed types were promoted by refiner) + if (parentClass?.Parent is CodeClass) + return; + + switch (codeElement.Kind) + { + case CodeMethodKind.Serializer: + case CodeMethodKind.Deserializer: + // handled by CodeClassDeclarationWriter in the Parsable impl block + break; + case CodeMethodKind.Constructor: + WriteConstructorBody(codeElement, parentClass!, writer); + break; + case CodeMethodKind.ClientConstructor: + WriteClientConstructorBody(codeElement, parentClass!, writer); + break; + case CodeMethodKind.RawUrlConstructor: + WriteRawUrlConstructorBody(parentClass!, writer); + break; + case CodeMethodKind.Factory: + WriteFactoryMethodBody(codeElement, parentClass!, writer); + break; + case CodeMethodKind.RequestGenerator: + WriteRequestGeneratorBody(codeElement, parentClass!, writer); + break; + case CodeMethodKind.RequestExecutor: + WriteRequestExecutorBody(codeElement, parentClass!, writer); + break; + case CodeMethodKind.Getter: + WriteGetterBody(codeElement, writer); + break; + case CodeMethodKind.Setter: + WriteSetterBody(codeElement, writer); + break; + case CodeMethodKind.RequestBuilderWithParameters: + case CodeMethodKind.IndexerBackwardCompatibility: + WriteIndexerBody(codeElement, writer); + break; + case CodeMethodKind.RequestBuilderBackwardCompatibility: + WriteNavBody(codeElement, writer); + break; + default: + // skip unknown method kinds + break; + } + } + + private void WriteConstructorBody(CodeMethod method, CodeClass parentClass, LanguageWriter writer) + { + if (parentClass.IsOfKind(CodeClassKind.RequestBuilder) && method.Parameters.Any()) + { + var urlTemplate = parentClass.GetPropertyOfKind(CodePropertyKind.UrlTemplate) + ?.DefaultValue?.Trim('"') ?? string.Empty; + writer.WriteLine("pub fn new(path_parameters: std::collections::HashMap, request_adapter: std::sync::Arc) -> Self {"); + writer.IncreaseIndent(); + writer.WriteLine("Self {"); + writer.IncreaseIndent(); + writer.WriteLine($"base: BaseRequestBuilder::new(request_adapter, \"{urlTemplate}\".to_string(), path_parameters),"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + else + { + writer.WriteLine("pub fn new() -> Self {"); + writer.IncreaseIndent(); + writer.WriteLine("Self::default()"); + } + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + private void WriteClientConstructorBody(CodeMethod method, CodeClass parentClass, LanguageWriter writer) + { + var urlTemplate = parentClass.GetPropertyOfKind(CodePropertyKind.UrlTemplate) + ?.DefaultValue?.Trim('"') ?? "{+baseurl}"; + var baseUrl = method.BaseUrl ?? string.Empty; + + writer.WriteLine("pub fn new(request_adapter: std::sync::Arc) -> Self {"); + writer.IncreaseIndent(); + writer.WriteLine("let mut path_parameters = std::collections::HashMap::new();"); + + if (!string.IsNullOrEmpty(baseUrl)) + { + writer.WriteLine($"if request_adapter.base_url().is_empty() {{"); + writer.IncreaseIndent(); + writer.WriteLine($"path_parameters.insert(\"baseurl\".to_string(), \"{baseUrl}\".to_string());"); + writer.DecreaseIndent(); + writer.WriteLine("} else {"); + writer.IncreaseIndent(); + writer.WriteLine("path_parameters.insert(\"baseurl\".to_string(), request_adapter.base_url().to_string());"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + writer.WriteLine("Self {"); + writer.IncreaseIndent(); + writer.WriteLine($"base: BaseRequestBuilder::new(request_adapter, \"{urlTemplate}\".to_string(), path_parameters),"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + private static void WriteRawUrlConstructorBody(CodeClass parentClass, LanguageWriter writer) + { + writer.WriteLine("pub fn with_url(raw_url: &str, request_adapter: std::sync::Arc) -> Self {"); + writer.IncreaseIndent(); + writer.WriteLine("let mut path_parameters = std::collections::HashMap::new();"); + writer.WriteLine("path_parameters.insert(\"request-raw-url\".to_string(), raw_url.to_string());"); + writer.WriteLine("Self {"); + writer.IncreaseIndent(); + writer.WriteLine("base: BaseRequestBuilder::new(request_adapter, String::new(), path_parameters),"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + private void WriteFactoryMethodBody(CodeMethod method, CodeClass parentClass, LanguageWriter writer) + { + writer.WriteLine("pub fn create_from_discriminator_value(_parse_node: &dyn ParseNode) -> Result {"); + writer.IncreaseIndent(); + + var disc = parentClass.DiscriminatorInformation; + if (disc?.ShouldWriteDiscriminatorForInheritedType == true && + !string.IsNullOrEmpty(disc.DiscriminatorPropertyName)) + { + writer.WriteLine($"if let Ok(Some(child)) = _parse_node.get_child_node(\"{disc.DiscriminatorPropertyName}\") {{"); + writer.IncreaseIndent(); + writer.WriteLine("if let Ok(Some(val)) = child.get_string_value() {"); + writer.IncreaseIndent(); + writer.WriteLine("match val.as_str() {"); + writer.IncreaseIndent(); + foreach (var m in disc.DiscriminatorMappings.OrderBy(static x => x.Key, StringComparer.OrdinalIgnoreCase)) + { + var t = m.Value.AllTypes.FirstOrDefault()?.TypeDefinition?.Name?.ToFirstCharacterUpperCase(); + if (!string.IsNullOrEmpty(t)) + writer.WriteLine($"\"{m.Key}\" => return Ok({t}::default()),"); + } + writer.WriteLine("_ => {}"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + writer.WriteLine("Ok(Self::default())"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + private void WriteRequestGeneratorBody(CodeMethod method, CodeClass parentClass, LanguageWriter writer) + { + var httpMethod = method.HttpMethod?.ToString() ?? "Get"; + httpMethod = char.ToUpperInvariant(httpMethod[0]) + httpMethod[1..].ToLowerInvariant(); + var name = method.Name.ToSnakeCase(); + var configParam = method.Parameters.OfKind(CodeParameterKind.RequestConfiguration); + var bodyParam = method.Parameters.OfKind(CodeParameterKind.RequestBody); + + // find the query parameters type for this method + var qpTypeName = GetQueryParamsTypeName(configParam); + + var sig = new List { "&self" }; + if (bodyParam != null) + sig.Add($"body: &{conventions.GetTypeString(bodyParam.Type, method)}"); + if (configParam != null) + sig.Add($"config: Option<&RequestConfiguration<{qpTypeName}>>"); + + conventions.WriteShortDescription(method, writer); + writer.WriteLine($"pub fn {name}({string.Join(", ", sig)}) -> Result {{"); + writer.IncreaseIndent(); + writer.WriteLine($"let mut request_info = RequestInformation::new_with_method_and_url_template(HttpMethod::{httpMethod}, &self.base.url_template, self.base.path_parameters.clone());"); + + if (configParam != null) + { + writer.WriteLine("if let Some(c) = config {"); + writer.IncreaseIndent(); + writer.WriteLine("request_info.configure(c);"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + if (method.AcceptedResponseTypes?.Any() == true) + writer.WriteLine($"request_info.headers.try_add(\"Accept\", \"{string.Join(", ", method.AcceptedResponseTypes)}\");"); + + if (bodyParam != null) + { + var contentType = method.RequestBodyContentType ?? "application/json"; + var isBodyCollection = bodyParam.Type.CollectionKind != CodeTypeBase.CodeTypeCollectionKind.None; + if (isBodyCollection) + { + writer.WriteLine($"// TODO: serialize collection body for \"{contentType}\""); + } + else if (bodyParam.Type.IsNullable) + { + writer.WriteLine("if let Some(ref b) = body {"); + writer.IncreaseIndent(); + writer.WriteLine($"request_info.set_content_from_parsable(self.base.request_adapter.serialization_writer_factory(), \"{contentType}\", b as &dyn Parsable)?;"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + else + { + writer.WriteLine($"request_info.set_content_from_parsable(self.base.request_adapter.serialization_writer_factory(), \"{contentType}\", body as &dyn Parsable)?;"); + } + } + + writer.WriteLine("Ok(request_info)"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + private void WriteRequestExecutorBody(CodeMethod method, CodeClass parentClass, LanguageWriter writer) + { + var name = method.Name.ToSnakeCase(); + var returnType = conventions.GetTypeString(method.ReturnType, method); + var configParam = method.Parameters.OfKind(CodeParameterKind.RequestConfiguration); + var bodyParam = method.Parameters.OfKind(CodeParameterKind.RequestBody); + var isVoid = method.ReturnType.Name.Equals("void", StringComparison.OrdinalIgnoreCase); + var isStream = method.ReturnType.Name.Equals("binary", StringComparison.OrdinalIgnoreCase); + var isCollection = method.ReturnType.CollectionKind != CodeTypeBase.CodeTypeCollectionKind.None; + + var generatorMethod = parentClass.Methods + .FirstOrDefault(m => m.IsOfKind(CodeMethodKind.RequestGenerator) && m.HttpMethod == method.HttpMethod); + var genName = generatorMethod?.Name.ToSnakeCase() ?? $"to_{name}_request_information"; + + var qpTypeName = GetQueryParamsTypeName(configParam); + + var sig = new List { "&self" }; + if (bodyParam != null) + sig.Add($"body: &{conventions.GetTypeString(bodyParam.Type, method)}"); + if (configParam != null) + sig.Add($"config: Option<&RequestConfiguration<{qpTypeName}>>"); + + var args = new List(); + if (bodyParam != null) args.Add("body"); + if (configParam != null) args.Add("config"); + + conventions.WriteShortDescription(method, writer); + writer.WriteLine($"pub async fn {name}({string.Join(", ", sig)}) -> Result<{returnType}, KiotaError> {{"); + writer.IncreaseIndent(); + writer.WriteLine($"let request_info = self.{genName}({string.Join(", ", args)})?;"); + + if (isVoid) + { + writer.WriteLine("self.base.request_adapter.send_no_content(&request_info, None).await"); + } + else if (isStream) + { + writer.WriteLine("self.base.request_adapter.send_primitive(&request_info, None).await"); + } + else + { + // figure out the model type for the factory + var modelType = method.ReturnType is CodeType ct && ct.TypeDefinition is CodeClass modelClass + ? modelClass.Name.ToFirstCharacterUpperCase() + : null; + + if (modelType != null) + { + writer.WriteLine($"let factory: Box Result, KiotaError> + Send + Sync> ="); + writer.IncreaseIndent(); + writer.WriteLine($"Box::new(|node| Ok(Box::new({modelType}::create_from_discriminator_value(node)?)));"); + writer.DecreaseIndent(); + + if (isCollection) + { + writer.WriteLine("let results = self.base.request_adapter.send_collection(&request_info, &factory, None).await?;"); + writer.WriteLine($"let typed: Vec<{modelType}> = results.into_iter().filter_map(|r| {{"); + writer.IncreaseIndent(); + writer.WriteLine($"r.as_any().downcast::<{modelType}>().ok().map(|b| *b)"); + writer.DecreaseIndent(); + writer.WriteLine("}).collect();"); + writer.WriteLine("Ok(typed)"); + } + else if (method.ReturnType.IsNullable) + { + writer.WriteLine("let result = self.base.request_adapter.send(&request_info, &factory, None).await?;"); + writer.WriteLine($"Ok(result.and_then(|r| r.as_any().downcast::<{modelType}>().ok().map(|b| *b)))"); + } + else + { + writer.WriteLine("let result = self.base.request_adapter.send(&request_info, &factory, None).await?;"); + writer.WriteLine($"result.ok_or_else(|| KiotaError::General(\"empty response\".to_string()))"); + writer.WriteLine($" .and_then(|r| r.as_any().downcast::<{modelType}>().map(|b| *b)"); + writer.WriteLine($" .map_err(|_| KiotaError::Deserialization(\"type mismatch\".to_string())))"); + } + } + else + { + writer.WriteLine("todo!(\"unknown return type\")"); + } + } + + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + private void WriteGetterBody(CodeMethod method, LanguageWriter writer) + { + var prop = method.AccessedProperty; + if (prop == null) return; + var fieldName = prop.Name.ToSnakeCase(); + var cleanName = StripRawPrefix(fieldName); + var returnType = conventions.GetTypeString(prop.Type, method); + + writer.WriteLine($"pub fn get_{cleanName}(&self) -> &{returnType} {{"); + writer.IncreaseIndent(); + writer.WriteLine($"&self.{fieldName}"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + private void WriteSetterBody(CodeMethod method, LanguageWriter writer) + { + var prop = method.AccessedProperty; + if (prop == null) return; + var fieldName = prop.Name.ToSnakeCase(); + var cleanName = StripRawPrefix(fieldName); + var paramType = conventions.GetTypeString(prop.Type, method); + + writer.WriteLine($"pub fn set_{cleanName}(&mut self, value: {paramType}) {{"); + writer.IncreaseIndent(); + writer.WriteLine($"self.{fieldName} = value;"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + private static string GetQueryParamsTypeName(CodeParameter? configParam) + { + if (configParam?.Type is CodeType configType) + { + var genericArg = configType.GenericTypeParameterValues.FirstOrDefault(); + if (genericArg is CodeType qpType && qpType.TypeDefinition is CodeClass qpClass) + return qpClass.Name.ToFirstCharacterUpperCase(); + } + return "DefaultQueryParameters"; + } + + private static string StripRawPrefix(string name) + { + return name.StartsWith("r#", StringComparison.Ordinal) ? name[2..] : name; + } + + private void WriteIndexerBody(CodeMethod method, LanguageWriter writer) + { + var returnType = conventions.GetTypeString(method.ReturnType, method); + var name = method.Name.ToSnakeCase(); + + var sig = new List { "&self" }; + foreach (var p in method.Parameters.Where(static p => p.IsOfKind(CodeParameterKind.Custom, CodeParameterKind.Path))) + sig.Add($"{p.Name.ToSnakeCase()}: {conventions.GetTypeString(p.Type, method)}"); + + conventions.WriteShortDescription(method, writer); + writer.WriteLine($"pub fn {name}({string.Join(", ", sig)}) -> {returnType} {{"); + writer.IncreaseIndent(); + writer.WriteLine("let mut url_tpl_params = self.base.path_parameters.clone();"); + foreach (var p in method.Parameters.Where(static p => p.IsOfKind(CodeParameterKind.Custom, CodeParameterKind.Path))) + writer.WriteLine($"url_tpl_params.insert(\"{p.SerializationName ?? p.Name}\".to_string(), {p.Name.ToSnakeCase()}.to_string());"); + writer.WriteLine($"{returnType}::new(url_tpl_params, self.base.request_adapter.clone())"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + + private void WriteNavBody(CodeMethod method, LanguageWriter writer) + { + var returnType = conventions.GetTypeString(method.ReturnType, method); + var name = method.Name.ToSnakeCase(); + + conventions.WriteShortDescription(method, writer); + writer.WriteLine($"pub fn {name}(&self) -> {returnType} {{"); + writer.IncreaseIndent(); + writer.WriteLine($"{returnType}::new(self.base.path_parameters.clone(), self.base.request_adapter.clone())"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } +} diff --git a/src/Kiota.Builder/Writers/Rust/CodeNamespaceWriter.cs b/src/Kiota.Builder/Writers/Rust/CodeNamespaceWriter.cs new file mode 100644 index 0000000000..b2cb624ea7 --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodeNamespaceWriter.cs @@ -0,0 +1,45 @@ +using System; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.Writers.Rust; + +public class CodeNamespaceWriter(RustConventionService conventionService) : BaseElementWriter(conventionService) +{ + public override void WriteCodeElement(CodeNamespace codeElement, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(codeElement); + ArgumentNullException.ThrowIfNull(writer); + + writer.WriteLine("// Code generated by Microsoft Kiota - DO NOT EDIT."); + writer.WriteLine(); + + // Declare child namespaces as submodules + foreach (var childNs in codeElement.GetChildElements(true).OfType().OrderBy(static x => x.Name, StringComparer.OrdinalIgnoreCase)) + { + var segmentName = childNs.Name.Split('.')[^1].ToSnakeCase(); + writer.WriteLine($"pub mod {segmentName};"); + } + + // Declare child classes/enums as submodules and re-export + foreach (var childEnum in codeElement.GetChildElements(true).OfType().OrderBy(static x => x.Name, StringComparer.OrdinalIgnoreCase)) + { + var modName = childEnum.Name.ToSnakeCase(); + writer.WriteLine($"pub mod {modName};"); + } + + foreach (var childClass in codeElement.GetChildElements(true).OfType().OrderBy(static x => x.Name, StringComparer.OrdinalIgnoreCase)) + { + var modName = childClass.Name.ToSnakeCase(); + writer.WriteLine($"pub mod {modName};"); + } + + foreach (var childFile in codeElement.GetChildElements(true).OfType().OrderBy(static x => x.Name, StringComparer.OrdinalIgnoreCase)) + { + var modName = childFile.Name.ToSnakeCase(); + writer.WriteLine($"pub mod {modName};"); + } + } +} diff --git a/src/Kiota.Builder/Writers/Rust/CodePropertyWriter.cs b/src/Kiota.Builder/Writers/Rust/CodePropertyWriter.cs new file mode 100644 index 0000000000..05c35f320b --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/CodePropertyWriter.cs @@ -0,0 +1,35 @@ +using System; +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.Writers.Rust; + +public class CodePropertyWriter(RustConventionService conventionService) : BaseElementWriter(conventionService) +{ + public override void WriteCodeElement(CodeProperty codeElement, LanguageWriter writer) + { + ArgumentNullException.ThrowIfNull(codeElement); + ArgumentNullException.ThrowIfNull(writer); + + if (codeElement.Kind == CodePropertyKind.RequestBuilder) + { + WriteNavigationProperty(codeElement, writer); + return; + } + // Everything else is already written by CodeClassDeclarationWriter + } + + private void WriteNavigationProperty(CodeProperty property, LanguageWriter writer) + { + // navigation properties return the builder directly, not Option<> + var returnType = property.Type is CodeType ct ? ct.Name.ToFirstCharacterUpperCase() : conventions.GetTypeString(property.Type, property); + var methodName = property.Name.ToSnakeCase(); + + conventions.WriteShortDescription(property, writer); + writer.WriteLine($"pub fn {methodName}(&self) -> {returnType} {{"); + writer.IncreaseIndent(); + writer.WriteLine($"{returnType}::new(self.base.path_parameters.clone(), self.base.request_adapter.clone())"); + writer.DecreaseIndent(); + writer.WriteLine("}"); + } +} diff --git a/src/Kiota.Builder/Writers/Rust/RustConventionService.cs b/src/Kiota.Builder/Writers/Rust/RustConventionService.cs new file mode 100644 index 0000000000..103cccf1de --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/RustConventionService.cs @@ -0,0 +1,99 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; + +namespace Kiota.Builder.Writers.Rust; + +public class RustConventionService : CommonLanguageConventionService +{ + public override string StreamTypeName => "Vec"; + public override string VoidTypeName => "()"; + public override string DocCommentPrefix => "/// "; + public override string ParseNodeInterfaceName => "ParseNode"; + public override string TempDictionaryVarName => "url_tpl_params"; + public override string GetAccessModifier(AccessModifier access) => access switch + { + AccessModifier.Public => "pub ", + AccessModifier.Protected => "pub(crate) ", + _ => string.Empty, + }; + public override string GetParameterSignature(CodeParameter parameter, CodeElement targetElement, LanguageWriter? writer = null) + { + ArgumentNullException.ThrowIfNull(parameter); + var paramType = GetTypeString(parameter.Type, targetElement); + var paramName = parameter.Name.ToSnakeCase(); + return $"{paramName}: {paramType}"; + } + public override string GetTypeString(CodeTypeBase code, CodeElement targetElement, bool includeCollectionInformation = true, LanguageWriter? writer = null) + { + ArgumentNullException.ThrowIfNull(code); + if (code is CodeComposedTypeBase) + throw new InvalidOperationException($"Rust does not support union types directly; the union type {code.Name} should have been converted to a wrapper by the refiner"); + if (code is CodeType currentType) + { + var typeName = TranslateType(currentType); + var collectionPrefix = currentType.CollectionKind switch + { + CodeTypeBase.CodeTypeCollectionKind.Array or + CodeTypeBase.CodeTypeCollectionKind.Complex when includeCollectionInformation => "Vec<", + _ => string.Empty, + }; + var collectionSuffix = collectionPrefix.Length > 0 ? ">" : string.Empty; + var nullablePrefix = currentType.IsNullable && + currentType.CollectionKind == CodeTypeBase.CodeTypeCollectionKind.None + ? "Option<" : string.Empty; + var nullableSuffix = nullablePrefix.Length > 0 ? ">" : string.Empty; + return $"{nullablePrefix}{collectionPrefix}{typeName}{collectionSuffix}{nullableSuffix}"; + } + throw new InvalidOperationException($"type of type {code.GetType()} is not handled"); + } + public override string TranslateType(CodeType type) + { + ArgumentNullException.ThrowIfNull(type); + return type.Name.ToLowerInvariant() switch + { + "void" => "()", + "string" => "String", + "integer" or "int32" => "i32", + "int64" or "long" => "i64", + "float" or "float32" => "f32", + "double" or "float64" or "decimal" => "f64", + "byte" or "uint8" => "u8", + "sbyte" or "int8" => "i8", + "boolean" or "bool" => "bool", + "guid" or "uuid" => "uuid::Uuid", + "datetimeoffset" or "datetime" => "chrono::DateTime", + "dateonly" => "chrono::NaiveDate", + "timeonly" => "chrono::NaiveTime", + "isoduration" or "timespan" or "duration" => "IsoDuration", + "binary" or "base64" or "base64url" => "Vec", + "object" => "serde_json::Value", + "iparsenode" or "parsenode" => "dyn ParseNode", + "iserializationwriter" or "serializationwriter" => "dyn SerializationWriter", + "" or null => "serde_json::Value", + _ => type.Name.ToFirstCharacterUpperCase(), + }; + } + internal static HashSet PrimitiveTypes => new(StringComparer.OrdinalIgnoreCase) + { + "String", "bool", "i8", "u8", "i32", "i64", "f32", "f64", + "uuid::Uuid", "chrono::DateTime", "chrono::NaiveDate", "chrono::NaiveTime", "IsoDuration", + }; + public static bool IsPrimitiveType(string typeName) => PrimitiveTypes.Contains(typeName); + public override bool WriteShortDescription(IDocumentedElement element, LanguageWriter writer, string prefix = "", string suffix = "") + { + ArgumentNullException.ThrowIfNull(writer); + ArgumentNullException.ThrowIfNull(element); + if (element.Documentation is not { } documentation) return false; + var description = element.Documentation.GetDescription(static type => type.Name.ToFirstCharacterUpperCase()); + if (!string.IsNullOrEmpty(description)) + { + writer.WriteLine($"{DocCommentPrefix}{description.CleanupXMLString()}"); + return true; + } + return false; + } +} diff --git a/src/Kiota.Builder/Writers/Rust/RustWriter.cs b/src/Kiota.Builder/Writers/Rust/RustWriter.cs new file mode 100644 index 0000000000..9e881dc9e1 --- /dev/null +++ b/src/Kiota.Builder/Writers/Rust/RustWriter.cs @@ -0,0 +1,20 @@ +using Kiota.Builder.PathSegmenters; + +namespace Kiota.Builder.Writers.Rust; + +public class RustWriter : LanguageWriter +{ + public RustWriter(string rootPath, string clientNamespaceName) + { + PathSegmenter = new RustPathSegmenter(rootPath, clientNamespaceName); + var conventionService = new RustConventionService(); + AddOrReplaceCodeElementWriter(new CodeClassDeclarationWriter(conventionService)); + AddOrReplaceCodeElementWriter(new CodeBlockEndWriter()); + AddOrReplaceCodeElementWriter(new CodePropertyWriter(conventionService)); + AddOrReplaceCodeElementWriter(new CodeEnumWriter(conventionService)); + AddOrReplaceCodeElementWriter(new CodeMethodWriter(conventionService)); + AddOrReplaceCodeElementWriter(new CodeFileBlockEndWriter()); + AddOrReplaceCodeElementWriter(new CodeFileDeclarationWriter(conventionService)); + AddOrReplaceCodeElementWriter(new CodeNamespaceWriter(conventionService)); + } +} diff --git a/tests/Kiota.Builder.Tests/Writers/Rust/CodeClassDeclarationWriterTests.cs b/tests/Kiota.Builder.Tests/Writers/Rust/CodeClassDeclarationWriterTests.cs new file mode 100644 index 0000000000..6e9473e0a0 --- /dev/null +++ b/tests/Kiota.Builder.Tests/Writers/Rust/CodeClassDeclarationWriterTests.cs @@ -0,0 +1,150 @@ +using System; +using System.IO; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; +using Kiota.Builder.Writers; +using Kiota.Builder.Writers.Rust; + +using Xunit; + +namespace Kiota.Builder.Tests.Writers.Rust; + +public sealed class CodeClassDeclarationWriterTests : IDisposable +{ + private const string DefaultPath = "./"; + private const string DefaultName = "name"; + private readonly StringWriter tw; + private readonly LanguageWriter writer; + private readonly CodeClassDeclarationWriter codeElementWriter; + private readonly CodeNamespace root; + + public CodeClassDeclarationWriterTests() + { + codeElementWriter = new CodeClassDeclarationWriter(new RustConventionService()); + writer = LanguageWriter.GetLanguageWriter(GenerationLanguage.Rust, DefaultPath, DefaultName); + tw = new StringWriter(); + writer.SetTextWriter(tw); + root = CodeNamespace.InitRootNamespace(); + } + public void Dispose() + { + tw?.Dispose(); + GC.SuppressFinalize(this); + } + [Fact] + public void WritesModelStruct() + { + var modelClass = new CodeClass + { + Name = "TestModel", + Kind = CodeClassKind.Model, + }; + root.AddClass(modelClass); + modelClass.AddProperty(new CodeProperty + { + Name = "displayName", + Kind = CodePropertyKind.Custom, + Type = new CodeType { Name = "string" }, + }); + modelClass.AddProperty(new CodeProperty + { + Name = "age", + Kind = CodePropertyKind.Custom, + Type = new CodeType { Name = "integer" }, + }); + codeElementWriter.WriteCodeElement(modelClass.StartBlock, writer); + var result = tw.ToString(); + Assert.Contains("pub struct TestModel {", result); + Assert.Contains("#[derive(Debug, Clone, Default, PartialEq)]", result); + Assert.Contains("pub display_name:", result); + Assert.Contains("pub age:", result); + // Parsable impl + Assert.Contains("impl Parsable for TestModel {", result); + Assert.Contains("fn field_names(&self) -> Vec<&'static str>", result); + Assert.Contains("fn assign_field(&mut self, field: &str, node: &dyn ParseNode)", result); + Assert.Contains("fn serialize(&self, writer: &mut dyn SerializationWriter)", result); + } + [Fact] + public void WritesRequestBuilderStruct() + { + var rbClass = new CodeClass + { + Name = "UsersRequestBuilder", + Kind = CodeClassKind.RequestBuilder, + }; + root.AddClass(rbClass); + rbClass.StartBlock.Inherits = new CodeType + { + Name = "BaseRequestBuilder", + IsExternal = true, + }; + rbClass.AddProperty(new CodeProperty + { + Name = "pathParameters", + Kind = CodePropertyKind.PathParameters, + Type = new CodeType { Name = "string" }, + }); + rbClass.AddProperty(new CodeProperty + { + Name = "requestAdapter", + Kind = CodePropertyKind.RequestAdapter, + Type = new CodeType { Name = "RequestAdapter" }, + }); + rbClass.AddProperty(new CodeProperty + { + Name = "UrlTemplate", + Kind = CodePropertyKind.UrlTemplate, + Type = new CodeType { Name = "string" }, + }); + codeElementWriter.WriteCodeElement(rbClass.StartBlock, writer); + var result = tw.ToString(); + Assert.Contains("pub struct UsersRequestBuilder {", result); + Assert.Contains("pub base: BaseRequestBuilder,", result); + // base properties should not be duplicated as separate fields + Assert.DoesNotContain("pub path_parameters:", result); + Assert.DoesNotContain("pub request_adapter:", result); + } + [Fact] + public void WritesImports() + { + var modelClass = new CodeClass + { + Name = "Invoice", + Kind = CodeClassKind.Model, + }; + root.AddClass(modelClass); + modelClass.AddProperty(new CodeProperty + { + Name = "amount", + Kind = CodePropertyKind.Custom, + Type = new CodeType { Name = "string" }, + }); + codeElementWriter.WriteCodeElement(modelClass.StartBlock, writer); + var result = tw.ToString(); + Assert.Contains("use kiota_abstractions::", result); + Assert.Contains("Parsable", result); + Assert.Contains("ParseNode", result); + Assert.Contains("SerializationWriter", result); + } + [Fact] + public void WritesGeneratedCodeComment() + { + var modelClass = new CodeClass + { + Name = "Marker", + Kind = CodeClassKind.Model, + }; + root.AddClass(modelClass); + modelClass.AddProperty(new CodeProperty + { + Name = "id", + Kind = CodePropertyKind.Custom, + Type = new CodeType { Name = "string" }, + }); + codeElementWriter.WriteCodeElement(modelClass.StartBlock, writer); + var result = tw.ToString(); + Assert.Contains("DO NOT EDIT", result); + } +} diff --git a/tests/Kiota.Builder.Tests/Writers/Rust/CodeEnumWriterTests.cs b/tests/Kiota.Builder.Tests/Writers/Rust/CodeEnumWriterTests.cs new file mode 100644 index 0000000000..1b0d41f588 --- /dev/null +++ b/tests/Kiota.Builder.Tests/Writers/Rust/CodeEnumWriterTests.cs @@ -0,0 +1,106 @@ +using System; +using System.IO; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; +using Kiota.Builder.Writers; + +using Xunit; + +namespace Kiota.Builder.Tests.Writers.Rust; + +public sealed class CodeEnumWriterTests : IDisposable +{ + private const string DefaultPath = "./"; + private const string DefaultName = "name"; + private readonly StringWriter tw; + private readonly LanguageWriter writer; + private readonly CodeEnum currentEnum; + private const string EnumName = "someEnum"; + public CodeEnumWriterTests() + { + writer = LanguageWriter.GetLanguageWriter(GenerationLanguage.Rust, DefaultPath, DefaultName); + tw = new StringWriter(); + writer.SetTextWriter(tw); + var root = CodeNamespace.InitRootNamespace(); + currentEnum = root.AddEnum(new CodeEnum + { + Name = EnumName, + }).First(); + } + public void Dispose() + { + tw?.Dispose(); + GC.SuppressFinalize(this); + } + [Fact] + public void WritesEnum() + { + currentEnum.AddOption(new CodeEnumOption { Name = "option1" }); + currentEnum.AddOption(new CodeEnumOption { Name = "option2" }); + writer.Write(currentEnum); + var result = tw.ToString(); + Assert.Contains("pub enum SomeEnum {", result); + Assert.Contains("Option1,", result); + Assert.Contains("Option2,", result); + Assert.Contains("#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]", result); + // Display impl + Assert.Contains("impl std::fmt::Display for SomeEnum {", result); + Assert.Contains("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {", result); + Assert.Contains("Self::Option1 => write!(f, \"option1\")", result); + Assert.Contains("Self::Option2 => write!(f, \"option2\")", result); + // Parse method + Assert.Contains("impl SomeEnum {", result); + Assert.Contains("pub fn parse(s: &str) -> Option {", result); + Assert.Contains("\"option1\" => Some(Self::Option1)", result); + Assert.Contains("\"option2\" => Some(Self::Option2)", result); + Assert.Contains("_ => None,", result); + } + [Fact] + public void WritesEnumWithDocComments() + { + currentEnum.Documentation = new() + { + DescriptionTemplate = "Represents the available choices", + }; + currentEnum.AddOption(new CodeEnumOption + { + Name = "active", + Documentation = new() + { + DescriptionTemplate = "The active state", + }, + }); + writer.Write(currentEnum); + var result = tw.ToString(); + Assert.Contains("/// Represents the available choices", result); + Assert.Contains("/// The active state", result); + Assert.Contains("pub enum SomeEnum {", result); + Assert.Contains("Active,", result); + } + [Fact] + public void WritesFlagEnumWithoutCopy() + { + var root = CodeNamespace.InitRootNamespace(); + var flagEnum = root.AddEnum(new CodeEnum + { + Name = "permissions", + Flags = true, + }).First(); + flagEnum.AddOption(new CodeEnumOption { Name = "read" }); + flagEnum.AddOption(new CodeEnumOption { Name = "write" }); + writer.Write(flagEnum); + var result = tw.ToString(); + Assert.Contains("#[derive(Debug, Clone, PartialEq, Eq, Hash)]", result); + Assert.DoesNotContain("Copy", result); + } + [Fact] + public void WritesGeneratedCodeComment() + { + currentEnum.AddOption(new CodeEnumOption { Name = "val" }); + writer.Write(currentEnum); + var result = tw.ToString(); + Assert.Contains("DO NOT EDIT", result); + } +} diff --git a/tests/Kiota.Builder.Tests/Writers/Rust/CodeMethodWriterTests.cs b/tests/Kiota.Builder.Tests/Writers/Rust/CodeMethodWriterTests.cs new file mode 100644 index 0000000000..bca15ab635 --- /dev/null +++ b/tests/Kiota.Builder.Tests/Writers/Rust/CodeMethodWriterTests.cs @@ -0,0 +1,193 @@ +using System; +using System.IO; +using System.Linq; + +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Extensions; +using Kiota.Builder.Writers; +using Kiota.Builder.Writers.Rust; + +using Xunit; + +namespace Kiota.Builder.Tests.Writers.Rust; + +public sealed class CodeMethodWriterTests : IDisposable +{ + private const string DefaultPath = "./"; + private const string DefaultName = "name"; + private readonly StringWriter tw; + private readonly LanguageWriter writer; + private readonly CodeNamespace root; + + public CodeMethodWriterTests() + { + writer = LanguageWriter.GetLanguageWriter(GenerationLanguage.Rust, DefaultPath, DefaultName); + tw = new StringWriter(); + writer.SetTextWriter(tw); + root = CodeNamespace.InitRootNamespace(); + } + public void Dispose() + { + tw?.Dispose(); + GC.SuppressFinalize(this); + } + [Fact] + public void WritesConstructor() + { + var modelClass = new CodeClass + { + Name = "Payment", + Kind = CodeClassKind.Model, + }; + root.AddClass(modelClass); + var method = new CodeMethod + { + Name = "constructor", + Kind = CodeMethodKind.Constructor, + ReturnType = new CodeType { Name = "void" }, + }; + modelClass.AddMethod(method); + writer.Write(method); + var result = tw.ToString(); + Assert.Contains("pub fn new() -> Self {", result); + Assert.Contains("Self::default()", result); + } + [Fact] + public void WritesRequestBuilderConstructor() + { + var rbClass = new CodeClass + { + Name = "PaymentsRequestBuilder", + Kind = CodeClassKind.RequestBuilder, + }; + root.AddClass(rbClass); + rbClass.StartBlock.Inherits = new CodeType + { + Name = "BaseRequestBuilder", + IsExternal = true, + }; + rbClass.AddProperty(new CodeProperty + { + Name = "pathParameters", + Kind = CodePropertyKind.PathParameters, + Type = new CodeType { Name = "string" }, + }); + rbClass.AddProperty(new CodeProperty + { + Name = "requestAdapter", + Kind = CodePropertyKind.RequestAdapter, + Type = new CodeType { Name = "RequestAdapter" }, + }); + rbClass.AddProperty(new CodeProperty + { + Name = "UrlTemplate", + Kind = CodePropertyKind.UrlTemplate, + Type = new CodeType { Name = "string" }, + DefaultValue = "\"{+baseurl}/payments\"", + }); + var method = new CodeMethod + { + Name = "constructor", + Kind = CodeMethodKind.Constructor, + ReturnType = new CodeType { Name = "void" }, + }; + method.AddParameter(new CodeParameter + { + Name = "pathParameters", + Kind = CodeParameterKind.PathParameters, + Type = new CodeType { Name = "string" }, + }); + method.AddParameter(new CodeParameter + { + Name = "requestAdapter", + Kind = CodeParameterKind.RequestAdapter, + Type = new CodeType { Name = "RequestAdapter" }, + }); + rbClass.AddMethod(method); + writer.Write(method); + var result = tw.ToString(); + Assert.Contains("pub fn new(path_parameters: std::collections::HashMap, request_adapter: std::sync::Arc) -> Self {", result); + Assert.Contains("base: BaseRequestBuilder::new(request_adapter,", result); + Assert.Contains("{+baseurl}/payments", result); + } + [Fact] + public void WritesGetter() + { + var modelClass = new CodeClass + { + Name = "Invoice", + Kind = CodeClassKind.Model, + }; + root.AddClass(modelClass); + var prop = new CodeProperty + { + Name = "amount", + Kind = CodePropertyKind.Custom, + Type = new CodeType { Name = "string" }, + }; + modelClass.AddProperty(prop); + var getter = new CodeMethod + { + Name = "GetAmount", + Kind = CodeMethodKind.Getter, + ReturnType = new CodeType { Name = "string" }, + AccessedProperty = prop, + }; + modelClass.AddMethod(getter); + writer.Write(getter); + var result = tw.ToString(); + Assert.Contains("pub fn get_amount(&self) ->", result); + Assert.Contains("&self.amount", result); + } + [Fact] + public void WritesSetter() + { + var modelClass = new CodeClass + { + Name = "Invoice", + Kind = CodeClassKind.Model, + }; + root.AddClass(modelClass); + var prop = new CodeProperty + { + Name = "amount", + Kind = CodePropertyKind.Custom, + Type = new CodeType { Name = "string" }, + }; + modelClass.AddProperty(prop); + var setter = new CodeMethod + { + Name = "SetAmount", + Kind = CodeMethodKind.Setter, + ReturnType = new CodeType { Name = "void" }, + AccessedProperty = prop, + }; + modelClass.AddMethod(setter); + writer.Write(setter); + var result = tw.ToString(); + Assert.Contains("pub fn set_amount(&mut self, value:", result); + Assert.Contains("self.amount = value;", result); + } + [Fact] + public void WritesFactoryMethod() + { + var modelClass = new CodeClass + { + Name = "Payment", + Kind = CodeClassKind.Model, + }; + root.AddClass(modelClass); + var factory = new CodeMethod + { + Name = "createFromDiscriminatorValue", + Kind = CodeMethodKind.Factory, + ReturnType = new CodeType { Name = "Payment" }, + IsStatic = true, + }; + modelClass.AddMethod(factory); + writer.Write(factory); + var result = tw.ToString(); + Assert.Contains("pub fn create_from_discriminator_value(_parse_node: &dyn ParseNode) -> Result {", result); + Assert.Contains("Ok(Self::default())", result); + } +} diff --git a/tests/Kiota.Builder.Tests/Writers/Rust/RustConventionServiceTests.cs b/tests/Kiota.Builder.Tests/Writers/Rust/RustConventionServiceTests.cs new file mode 100644 index 0000000000..e0e7af5b64 --- /dev/null +++ b/tests/Kiota.Builder.Tests/Writers/Rust/RustConventionServiceTests.cs @@ -0,0 +1,88 @@ +using Kiota.Builder.CodeDOM; +using Kiota.Builder.Writers.Rust; + +using Xunit; + +namespace Kiota.Builder.Tests.Writers.Rust; + +public class RustConventionServiceTests +{ + private readonly RustConventionService sut = new(); + + [Fact] + public void TranslatesStringType() + { + var codeType = new CodeType { Name = "string" }; + var result = sut.TranslateType(codeType); + Assert.Equal("String", result); + } + [Fact] + public void TranslatesIntegerType() + { + var codeType = new CodeType { Name = "integer" }; + var result = sut.TranslateType(codeType); + Assert.Equal("i32", result); + } + [Fact] + public void TranslatesBooleanType() + { + var codeType = new CodeType { Name = "boolean" }; + var result = sut.TranslateType(codeType); + Assert.Equal("bool", result); + } + [Fact] + public void TranslatesDateTimeType() + { + var codeType = new CodeType { Name = "DateTimeOffset" }; + var result = sut.TranslateType(codeType); + Assert.Equal("chrono::DateTime", result); + } + [Fact] + public void GetAccessModifierPublic() + { + var result = sut.GetAccessModifier(AccessModifier.Public); + Assert.Equal("pub ", result); + } + [Fact] + public void GetAccessModifierPrivate() + { + var result = sut.GetAccessModifier(AccessModifier.Private); + Assert.Equal(string.Empty, result); + } + [Fact] + public void TranslatesVoidType() + { + var codeType = new CodeType { Name = "void" }; + var result = sut.TranslateType(codeType); + Assert.Equal("()", result); + } + [Fact] + public void TranslatesGuidType() + { + var codeType = new CodeType { Name = "guid" }; + var result = sut.TranslateType(codeType); + Assert.Equal("uuid::Uuid", result); + } + [Fact] + public void TranslatesBinaryType() + { + var codeType = new CodeType { Name = "binary" }; + var result = sut.TranslateType(codeType); + Assert.Equal("Vec", result); + } + [Fact] + public void StreamTypeNameIsCorrect() + { + Assert.Equal("Vec", sut.StreamTypeName); + } + [Fact] + public void VoidTypeNameIsCorrect() + { + Assert.Equal("()", sut.VoidTypeName); + } + [Fact] + public void DocCommentPrefixIsCorrect() + { + Assert.Equal("/// ", sut.DocCommentPrefix); + } +} diff --git a/tests/Kiota.Builder.Tests/Writers/Rust/RustWriterTests.cs b/tests/Kiota.Builder.Tests/Writers/Rust/RustWriterTests.cs new file mode 100644 index 0000000000..55e40eb7c1 --- /dev/null +++ b/tests/Kiota.Builder.Tests/Writers/Rust/RustWriterTests.cs @@ -0,0 +1,28 @@ +using System; + +using Kiota.Builder.Writers.Rust; + +using Xunit; + +namespace Kiota.Builder.Tests.Writers.Rust; + +public class RustWriterTests +{ + [Fact] + public void WriterExists() + { + var writer = new RustWriter("./", "graph"); + Assert.NotNull(writer); + Assert.NotNull(writer.PathSegmenter); + } + [Fact] + public void ThrowsOnNullRootPath() + { + Assert.Throws(() => new RustWriter(null, "graph")); + } + [Fact] + public void ThrowsOnNullClientNamespace() + { + Assert.Throws(() => new RustWriter("./", null)); + } +}