Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 109 additions & 10 deletions GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
namespace GFramework.Cqrs.Tests.Cqrs;

/// <summary>
/// 验证 CQRS dispatcher 会缓存热路径中的服务类型构造结果
/// 验证 CQRS dispatcher 会缓存热路径中的服务类型与调用委托
/// </summary>
[TestFixture]
internal sealed class CqrsDispatcherCacheTests
{
private MicrosoftDiContainer? _container;
private ArchitectureContext? _context;

/// <summary>
/// 初始化测试上下文。
/// </summary>
Expand All @@ -29,6 +32,7 @@ public void SetUp()

_container.Freeze();
_context = new ArchitectureContext(_container);
ClearDispatcherCaches();
}

/// <summary>
Expand All @@ -41,20 +45,17 @@ public void TearDown()
_container = null;
}

private MicrosoftDiContainer? _container;
private ArchitectureContext? _context;

/// <summary>
/// 验证相同消息类型重复分发时,不会重复扩张服务类型缓存
/// 验证相同消息类型重复分发时,不会重复扩张服务类型与调用委托缓存
/// </summary>
[Test]
public async Task Dispatcher_Should_Cache_Service_Types_After_First_Dispatch()
{
var notificationServiceTypes = GetCacheField("NotificationHandlerServiceTypes");
var requestServiceTypes = GetCacheField("RequestServiceTypes");
var streamServiceTypes = GetCacheField("StreamHandlerServiceTypes");
var requestInvokers = GetCacheField("RequestInvokers");
var requestPipelineInvokers = GetCacheField("RequestPipelineInvokers");
var requestInvokers = GetGenericCacheField("RequestInvokerCache`1", typeof(int), "Invokers");
var requestPipelineInvokers = GetGenericCacheField("RequestPipelineInvokerCache`1", typeof(int), "Invokers");
var notificationInvokers = GetCacheField("NotificationInvokers");
var streamInvokers = GetCacheField("StreamInvokers");

Expand Down Expand Up @@ -104,14 +105,42 @@ public async Task Dispatcher_Should_Cache_Service_Types_After_First_Dispatch()
});
}

/// <summary>
/// 验证 request 调用委托会按响应类型分别缓存,避免不同响应类型共用 object 结果桥接。
/// </summary>
[Test]
public async Task Dispatcher_Should_Cache_Request_Invokers_Per_Response_Type()
{
var intRequestInvokers = GetGenericCacheField("RequestInvokerCache`1", typeof(int), "Invokers");
var stringRequestInvokers = GetGenericCacheField("RequestInvokerCache`1", typeof(string), "Invokers");

var intBefore = intRequestInvokers.Count;
var stringBefore = stringRequestInvokers.Count;

await _context!.SendRequestAsync(new DispatcherCacheRequest());
await _context.SendRequestAsync(new DispatcherStringCacheRequest());

var intAfterFirstDispatch = intRequestInvokers.Count;
var stringAfterFirstDispatch = stringRequestInvokers.Count;

await _context.SendRequestAsync(new DispatcherCacheRequest());
await _context.SendRequestAsync(new DispatcherStringCacheRequest());

Assert.Multiple(() =>
{
Assert.That(intAfterFirstDispatch, Is.EqualTo(intBefore + 1));
Assert.That(stringAfterFirstDispatch, Is.EqualTo(stringBefore + 1));
Assert.That(intRequestInvokers.Count, Is.EqualTo(intAfterFirstDispatch));
Assert.That(stringRequestInvokers.Count, Is.EqualTo(stringAfterFirstDispatch));
});
}

/// <summary>
/// 通过反射读取 dispatcher 的静态缓存字典。
/// </summary>
private static IDictionary GetCacheField(string fieldName)
{
var dispatcherType = typeof(CqrsReflectionFallbackAttribute).Assembly
.GetType("GFramework.Cqrs.Internal.CqrsDispatcher", throwOnError: true)!;

var dispatcherType = GetDispatcherType();
var field = dispatcherType.GetField(
fieldName,
BindingFlags.NonPublic | BindingFlags.Static);
Expand All @@ -123,6 +152,57 @@ private static IDictionary GetCacheField(string fieldName)
$"Dispatcher cache field {fieldName} does not implement IDictionary.");
}

/// <summary>
/// 清空本测试依赖的 dispatcher 静态缓存,避免跨用例共享进程级状态导致断言漂移。
/// </summary>
private static void ClearDispatcherCaches()
{
GetCacheField("NotificationHandlerServiceTypes").Clear();
GetCacheField("RequestServiceTypes").Clear();
GetCacheField("StreamHandlerServiceTypes").Clear();
GetCacheField("NotificationInvokers").Clear();
GetCacheField("StreamInvokers").Clear();
GetGenericCacheField("RequestInvokerCache`1", typeof(int), "Invokers").Clear();
GetGenericCacheField("RequestInvokerCache`1", typeof(string), "Invokers").Clear();
GetGenericCacheField("RequestPipelineInvokerCache`1", typeof(int), "Invokers").Clear();
GetGenericCacheField("RequestPipelineInvokerCache`1", typeof(string), "Invokers").Clear();
}

/// <summary>
/// 通过反射读取 dispatcher 嵌套泛型缓存类型上的静态缓存字典。
/// </summary>
private static IDictionary GetGenericCacheField(string nestedTypeName, Type genericTypeArgument, string fieldName)
{
var nestedGenericType = GetDispatcherType().GetNestedType(
nestedTypeName,
BindingFlags.NonPublic);

Assert.That(nestedGenericType, Is.Not.Null, $"Missing dispatcher nested cache type {nestedTypeName}.");

var closedNestedType = nestedGenericType!.MakeGenericType(genericTypeArgument);
var field = closedNestedType.GetField(
fieldName,
BindingFlags.NonPublic | BindingFlags.Static);

Assert.That(
field,
Is.Not.Null,
$"Missing dispatcher nested cache field {nestedTypeName}.{fieldName} for {genericTypeArgument.FullName}.");

return field!.GetValue(null) as IDictionary
?? throw new InvalidOperationException(
$"Dispatcher nested cache field {nestedTypeName}.{fieldName} does not implement IDictionary.");
}

/// <summary>
/// 获取 CQRS dispatcher 运行时类型。
/// </summary>
private static Type GetDispatcherType()
{
return typeof(CqrsReflectionFallbackAttribute).Assembly
.GetType("GFramework.Cqrs.Internal.CqrsDispatcher", throwOnError: true)!;
}

/// <summary>
/// 消费整个异步流,确保建流路径被真实执行。
/// </summary>
Expand Down Expand Up @@ -154,6 +234,11 @@ internal sealed record DispatcherCacheStreamRequest : IStreamRequest<int>;
/// </summary>
internal sealed record DispatcherPipelineCacheRequest : IRequest<int>;

/// <summary>
/// 用于验证按响应类型分层 request invoker 缓存的测试请求。
/// </summary>
internal sealed record DispatcherStringCacheRequest : IRequest<string>;

/// <summary>
/// 处理 <see cref="DispatcherCacheRequest" />。
/// </summary>
Expand Down Expand Up @@ -213,6 +298,20 @@ public ValueTask<int> Handle(DispatcherPipelineCacheRequest request, Cancellatio
}
}

/// <summary>
/// 处理 <see cref="DispatcherStringCacheRequest" />。
/// </summary>
internal sealed class DispatcherStringCacheRequestHandler : IRequestHandler<DispatcherStringCacheRequest, string>
{
/// <summary>
/// 返回固定字符串,供按响应类型缓存测试验证 string 路径。
/// </summary>
public ValueTask<string> Handle(DispatcherStringCacheRequest request, CancellationToken cancellationToken)
{
return ValueTask.FromResult("dispatcher-cache");
}
}

/// <summary>
/// 为 <see cref="DispatcherPipelineCacheRequest" /> 提供最小 pipeline 行为,
/// 用于命中 dispatcher 的 pipeline invoker 缓存分支。
Expand Down
76 changes: 42 additions & 34 deletions GFramework.Cqrs/Internal/CqrsDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,6 @@ internal sealed class CqrsDispatcher(
IIocContainer container,
ILogger logger) : ICqrsRuntime
{
// 进程级缓存:按请求/响应类型缓存直接处理器调用委托,避免热路径重复反射。
// 线程安全依赖 ConcurrentDictionary;缓存与进程同寿命,默认假设请求类型集合有限且稳定。
private static readonly ConcurrentDictionary<(Type RequestType, Type ResponseType), RequestInvoker>
RequestInvokers = new();

// 进程级缓存:缓存带 pipeline 的请求调用委托,减少每次分发时的反射与表达式重建开销。
// 若后续引入动态生成请求类型,需要重新评估该缓存的增长边界。
private static readonly ConcurrentDictionary<(Type RequestType, Type ResponseType), RequestPipelineInvoker>
RequestPipelineInvokers = new();

// 进程级缓存:缓存通知调用委托,复用并发安全字典以支撑多线程发布路径。
private static readonly ConcurrentDictionary<Type, NotificationInvoker> NotificationInvokers = new();

Expand Down Expand Up @@ -131,20 +121,18 @@ public async ValueTask<TResponse> SendAsync<TResponse>(

if (behaviors.Count == 0)
{
var invoker = RequestInvokers.GetOrAdd(
(requestType, typeof(TResponse)),
static key => CreateRequestInvoker(key.RequestType, key.ResponseType));
var invoker = RequestInvokerCache<TResponse>.Invokers.GetOrAdd(
requestType,
CreateRequestInvoker<TResponse>);

var result = await invoker(handler, request, cancellationToken);
return result is null ? default! : (TResponse)result;
return await invoker(handler, request, cancellationToken);
}

var pipelineInvoker = RequestPipelineInvokers.GetOrAdd(
(requestType, typeof(TResponse)),
static key => CreateRequestPipelineInvoker(key.RequestType, key.ResponseType));
var pipelineInvoker = RequestPipelineInvokerCache<TResponse>.Invokers.GetOrAdd(
requestType,
CreateRequestPipelineInvoker<TResponse>);

var pipelineResult = await pipelineInvoker(handler, behaviors, request, cancellationToken);
return pipelineResult is null ? default! : (TResponse)pipelineResult;
return await pipelineInvoker(handler, behaviors, request, cancellationToken);
}

/// <summary>
Expand Down Expand Up @@ -200,21 +188,23 @@ private static void PrepareHandler(object handler, ICqrsContext context)
/// <summary>
/// 生成请求处理器调用委托,避免每次发送都重复反射。
/// </summary>
private static RequestInvoker CreateRequestInvoker(Type requestType, Type responseType)
private static RequestInvoker<TResponse> CreateRequestInvoker<TResponse>(Type requestType)
{
var method = RequestHandlerInvokerMethodDefinition
.MakeGenericMethod(requestType, responseType);
return (RequestInvoker)Delegate.CreateDelegate(typeof(RequestInvoker), method);
.MakeGenericMethod(requestType, typeof(TResponse));
return (RequestInvoker<TResponse>)Delegate.CreateDelegate(typeof(RequestInvoker<TResponse>), method);
}

/// <summary>
/// 生成带管道行为的请求处理委托,避免每次发送都重复反射。
/// </summary>
private static RequestPipelineInvoker CreateRequestPipelineInvoker(Type requestType, Type responseType)
private static RequestPipelineInvoker<TResponse> CreateRequestPipelineInvoker<TResponse>(Type requestType)
{
var method = RequestPipelineInvokerMethodDefinition
.MakeGenericMethod(requestType, responseType);
return (RequestPipelineInvoker)Delegate.CreateDelegate(typeof(RequestPipelineInvoker), method);
.MakeGenericMethod(requestType, typeof(TResponse));
return (RequestPipelineInvoker<TResponse>)Delegate.CreateDelegate(
typeof(RequestPipelineInvoker<TResponse>),
method);
}

/// <summary>
Expand All @@ -240,22 +230,21 @@ private static StreamInvoker CreateStreamInvoker(Type requestType, Type response
/// <summary>
/// 执行已强类型化的请求处理器调用。
/// </summary>
private static async ValueTask<object?> InvokeRequestHandlerAsync<TRequest, TResponse>(
private static ValueTask<TResponse> InvokeRequestHandlerAsync<TRequest, TResponse>(
object handler,
object request,
CancellationToken cancellationToken)
where TRequest : IRequest<TResponse>
{
var typedHandler = (IRequestHandler<TRequest, TResponse>)handler;
var typedRequest = (TRequest)request;
var result = await typedHandler.Handle(typedRequest, cancellationToken);
return result;
return typedHandler.Handle(typedRequest, cancellationToken);
}

/// <summary>
/// 执行包含管道行为链的请求处理。
/// </summary>
private static async ValueTask<object?> InvokeRequestPipelineAsync<TRequest, TResponse>(
private static ValueTask<TResponse> InvokeRequestPipelineAsync<TRequest, TResponse>(
object handler,
IReadOnlyList<object> behaviors,
object request,
Expand All @@ -275,8 +264,7 @@ private static StreamInvoker CreateStreamInvoker(Type requestType, Type response
next = (message, token) => behavior.Handle(message, currentNext, token);
}

var result = await next(typedRequest, cancellationToken);
return result;
return next(typedRequest, cancellationToken);
}

/// <summary>
Expand Down Expand Up @@ -307,10 +295,12 @@ private static object InvokeStreamHandler<TRequest, TResponse>(
return typedHandler.Handle(typedRequest, cancellationToken);
}

private delegate ValueTask<object?> RequestInvoker(object handler, object request,
private delegate ValueTask<TResponse> RequestInvoker<TResponse>(
object handler,
object request,
CancellationToken cancellationToken);

private delegate ValueTask<object?> RequestPipelineInvoker(
private delegate ValueTask<TResponse> RequestPipelineInvoker<TResponse>(
object handler,
IReadOnlyList<object> behaviors,
object request,
Expand All @@ -321,5 +311,23 @@ private delegate ValueTask NotificationInvoker(object handler, object notificati

private delegate object StreamInvoker(object handler, object request, CancellationToken cancellationToken);

/// <summary>
/// 按响应类型分层缓存 request 处理器调用委托,避免 value-type 响应在 object 桥接中产生装箱。
/// </summary>
/// <typeparam name="TResponse">请求响应类型。</typeparam>
private static class RequestInvokerCache<TResponse>
{
internal static readonly ConcurrentDictionary<Type, RequestInvoker<TResponse>> Invokers = new();
}

/// <summary>
/// 按响应类型分层缓存带 pipeline 的 request 调用委托,避免 pipeline 热路径上的额外装箱。
/// </summary>
/// <typeparam name="TResponse">请求响应类型。</typeparam>
private static class RequestPipelineInvokerCache<TResponse>
{
internal static readonly ConcurrentDictionary<Type, RequestPipelineInvoker<TResponse>> Invokers = new();
}

private readonly record struct RequestServiceTypeSet(Type HandlerType, Type BehaviorType);
}
23 changes: 22 additions & 1 deletion GFramework.SourceGenerators.Tests/Core/GeneratorTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@ public static class GeneratorTest<TGenerator>
public static async Task RunAsync(
string source,
params (string filename, string content)[] generatedSources)
{
await RunAsync(
source,
additionalReferences: [],
generatedSources);
}

/// <summary>
/// 运行源代码生成器测试,并为测试编译显式追加元数据引用。
/// </summary>
/// <param name="source">输入的源代码。</param>
/// <param name="additionalReferences">附加元数据引用,用于构造多程序集场景。</param>
/// <param name="generatedSources">期望生成的源文件集合,包含文件名和内容的元组。</param>
/// <returns>异步操作任务。</returns>
public static async Task RunAsync(
string source,
IEnumerable<MetadataReference> additionalReferences,
params (string filename, string content)[] generatedSources)
{
var test = new CSharpSourceGeneratorTest<TGenerator, DefaultVerifier>
{
Expand All @@ -31,6 +49,9 @@ public static async Task RunAsync(
test.TestState.GeneratedSources.Add(
(typeof(TGenerator), filename, NormalizeLineEndings(content)));

foreach (var additionalReference in additionalReferences)
test.TestState.AdditionalReferences.Add(additionalReference);

await test.RunAsync();
}

Expand All @@ -46,4 +67,4 @@ private static string NormalizeLineEndings(string content)
.Replace("\r", "\n", StringComparison.Ordinal)
.Replace("\n", Environment.NewLine, StringComparison.Ordinal);
}
}
}
Loading
Loading