// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// ------------------------------------------------------------

#nullable enable

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
#if ACTORS
using Dapr.Actors;
#endif
using Dapr.Client;
using Google.Protobuf;
using Grpc.Net.Client;

namespace Dapr
{
    /// <summary>
    /// Factory methods and shared helpers for <see cref="TestClient{TClient}"/>. A test client routes
    /// every HTTP request (including gRPC-over-HTTP) through a <see cref="CapturingHandler"/>, so a test
    /// can inspect each request and decide how its response completes.
    /// </summary>
    public abstract class TestClient
    {
#if ACTORS
        /// <summary>
        /// Creates a test client wrapping a <c>DaprHttpInteractor</c> that sends through a capturing handler.
        /// </summary>
        /// <param name="apiToken">Optional API token passed to the <c>DaprHttpInteractor</c> constructor.</param>
        internal static TestClient<DaprHttpInteractor> CreateForDaprHttpInterator(string? apiToken = null)
        {
            var handler = new CapturingHandler();
            return new TestClient<DaprHttpInteractor>(new DaprHttpInteractor(handler, apiToken), handler);
        }
#endif

        /// <summary>
        /// Creates a test client whose inner client is the capturing handler itself.
        /// </summary>
        public static TestClient<HttpMessageHandler> CreateForMessageHandler()
        {
            var handler = new CapturingHandler();
            return new TestClient<HttpMessageHandler>(handler, handler);
        }

        /// <summary>
        /// Creates a test client wrapping a <see cref="DaprClient"/> whose HTTP and gRPC traffic both
        /// go through a single capturing <see cref="HttpClient"/>.
        /// </summary>
        /// <param name="configure">
        /// Optional callback invoked on the builder before the capturing HTTP client and gRPC channel
        /// options are applied.
        /// </param>
        public static TestClient<DaprClient> CreateForDaprClient(Action<DaprClientBuilder>? configure = default)
        {
            var handler = new CapturingHandler();
            var httpClient = new HttpClient(handler);

            var builder = new DaprClientBuilder();
            configure?.Invoke(builder);

            builder.UseHttpClientFactory(() => httpClient);
            builder.UseGrpcChannelOptions(new GrpcChannelOptions()
            {
                HttpClient = httpClient,
                ThrowOperationCanceledOnCancellation = true,
            });

            return new TestClient<DaprClient>(builder.Build(), handler);
        }

        /// <summary>
        /// Awaits <paramref name="task"/>. If it has not completed within <paramref name="timeout"/>,
        /// a <see cref="TimeoutException"/> carrying <paramref name="message"/> is thrown instead.
        /// The awaited task itself is not cancelled.
        /// </summary>
        private static async Task WithTimeout(Task task, TimeSpan timeout, string message)
        {
            var timeoutSource = new TaskCompletionSource<object>();
            using var cts = new CancellationTokenSource(timeout);
            using (cts.Token.Register(() => timeoutSource.TrySetException(new TimeoutException(message))))
            {
                // Await the winner so that its result or exception propagates.
                await (await Task.WhenAny(task, timeoutSource.Task));
            }
        }

        /// <summary>
        /// Awaits <paramref name="task"/> and returns its result. If it has not completed within
        /// <paramref name="timeout"/>, a <see cref="TimeoutException"/> carrying <paramref name="message"/>
        /// is thrown instead. The awaited task itself is not cancelled.
        /// </summary>
        private static async Task<T> WithTimeout<T>(Task<T> task, TimeSpan timeout, string message)
        {
            var timeoutSource = new TaskCompletionSource<T>();
            using var cts = new CancellationTokenSource(timeout);
            using (cts.Token.Register(() => timeoutSource.TrySetException(new TimeoutException(message))))
            {
                // Await the winner so that its result or exception propagates.
                return await (await Task.WhenAny(task, timeoutSource.Task));
            }
        }

        /// <summary>
        /// A captured HTTP request made by an operation that produces no result.
        /// </summary>
        public class TestHttpRequest
        {
            public TestHttpRequest(HttpRequestMessage request, CaptureToken capture, Task task)
            {
                this.Request = request;
                this.Capture = capture;
                this.Task = task;
            }

            /// <summary>The request message observed by the capturing handler.</summary>
            public HttpRequestMessage Request { get; }

            private CaptureToken Capture { get; }

            private Task Task { get; }

            /// <summary>
            /// Supplies <paramref name="response"/> to the pending send, then waits up to 10 seconds
            /// for the originating operation to finish.
            /// </summary>
            public async Task CompleteAsync(HttpResponseMessage response)
            {
                this.Capture.Response.SetResult(response);
                await WithTimeout(this.Task, TimeSpan.FromSeconds(10), "Waiting for response to be completed timed out.");
            }

            /// <summary>
            /// Faults the pending send with <paramref name="ex"/>, then waits up to 10 seconds for the
            /// originating operation; any exception that operation throws propagates to the caller.
            /// </summary>
            public async Task CompleteWithExceptionAsync(Exception ex)
            {
                this.Capture.Response.SetException(ex);
                await WithTimeout(this.Task, TimeSpan.FromSeconds(10), "Waiting for response to be completed timed out.");
            }

            /// <summary>
            /// Marks the request as intentionally left incomplete so disposal does not report it.
            /// </summary>
            public void Dismiss()
            {
                this.Capture.IsDismissed = true;
            }
        }

        /// <summary>
        /// A captured HTTP request made by an operation that produces a <typeparamref name="T"/>.
        /// </summary>
        public class TestHttpRequest<T>
        {
            public TestHttpRequest(HttpRequestMessage request, CaptureToken capture, Task<T> task)
            {
                this.Request = request;
                this.Capture = capture;
                this.Task = task;
            }

            /// <summary>The request message observed by the capturing handler.</summary>
            public HttpRequestMessage Request { get; }

            private CaptureToken Capture { get; }

            private Task<T> Task { get; }

            /// <summary>
            /// Completes the request with a 200 OK whose body is <paramref name="value"/> serialized as
            /// UTF-8 JSON, and returns the operation's result.
            /// </summary>
            public async Task<T> CompleteWithJsonAsync<TData>(TData value, JsonSerializerOptions options)
            {
                var bytes = JsonSerializer.SerializeToUtf8Bytes(value, options);
                var response = new HttpResponseMessage(HttpStatusCode.OK) { Content = new ByteArrayContent(bytes) };
                response.Content.Headers.ContentType = new MediaTypeHeaderValue("application/json") { CharSet = "UTF-8", };
                return await CompleteAsync(response);
            }

            /// <summary>
            /// Supplies <paramref name="response"/> to the pending send, then waits up to 10 seconds for
            /// the originating operation and returns its result.
            /// </summary>
            public async Task<T> CompleteAsync(HttpResponseMessage response)
            {
                this.Capture.Response.SetResult(response);
                return await WithTimeout(this.Task, TimeSpan.FromSeconds(10), "Waiting for response to be completed timed out.");
            }

            /// <summary>
            /// Faults the pending send with <paramref name="ex"/>, then waits up to 10 seconds for the
            /// originating operation; any exception that operation throws propagates to the caller.
            /// </summary>
            public async Task CompleteWithExceptionAsync(Exception ex)
            {
                this.Capture.Response.SetException(ex);
                await WithTimeout(this.Task, TimeSpan.FromSeconds(10), "Waiting for response to be completed timed out.");
            }

            /// <summary>
            /// Marks the request as intentionally left incomplete so disposal does not report it.
            /// </summary>
            public void Dismiss()
            {
                this.Capture.IsDismissed = true;
            }
        }

        /// <summary>
        /// A captured gRPC request made by an operation that produces no result.
        /// </summary>
        public class TestGrpcRequest
        {
            public TestGrpcRequest(HttpRequestMessage request, CaptureToken capture, Task task)
            {
                this.Request = request;
                this.Capture = capture;
                this.Task = task;
            }

            /// <summary>The HTTP message carrying the gRPC call, as observed by the capturing handler.</summary>
            public HttpRequestMessage Request { get; }

            private CaptureToken Capture { get; }

            private Task Task { get; }

            /// <summary>Decodes the gRPC request message from <see cref="Request"/>.</summary>
            public async Task<TRequestEnvelope> GetRequestEnvelopeAsync<TRequestEnvelope>()
                where TRequestEnvelope : IMessage<TRequestEnvelope>, new()
            {
                return await GrpcUtils.GetRequestFromRequestMessageAsync<TRequestEnvelope>(this.Request);
            }

            /// <summary>
            /// Completes the call with a 200 OK response carrying <paramref name="value"/> as the gRPC payload.
            /// </summary>
            public async Task CompleteWithMessageAsync<TResponseEnvelope>(TResponseEnvelope value)
                where TResponseEnvelope : IMessage<TResponseEnvelope>
            {
                var content = await GrpcUtils.CreateResponseContent<TResponseEnvelope>(value);
                var response = GrpcUtils.CreateResponse(HttpStatusCode.OK, content);
                await CompleteAsync(response);
            }

            /// <summary>
            /// Supplies <paramref name="response"/> to the pending send, then waits up to 10 seconds
            /// for the originating operation to finish.
            /// </summary>
            public async Task CompleteAsync(HttpResponseMessage response)
            {
                this.Capture.Response.SetResult(response);
                await WithTimeout(this.Task, TimeSpan.FromSeconds(10), "Waiting for response to be completed timed out.");
            }

            /// <summary>
            /// Faults the pending send with <paramref name="ex"/>, then waits up to 10 seconds for the
            /// originating operation; any exception that operation throws propagates to the caller.
            /// </summary>
            public async Task CompleteWithExceptionAsync(Exception ex)
            {
                this.Capture.Response.SetException(ex);
                await WithTimeout(this.Task, TimeSpan.FromSeconds(10), "Waiting for response to be completed timed out.");
            }

            /// <summary>
            /// Marks the request as intentionally left incomplete so disposal does not report it.
            /// </summary>
            public void Dismiss()
            {
                this.Capture.IsDismissed = true;
            }
        }

        /// <summary>
        /// A captured gRPC request made by an operation that produces a <typeparamref name="TResponse"/>.
        /// </summary>
        public class TestGrpcRequest<TResponse>
        {
            public TestGrpcRequest(HttpRequestMessage request, CaptureToken capture, Task<TResponse> task)
            {
                this.Request = request;
                this.Capture = capture;
                this.Task = task;
            }

            /// <summary>The HTTP message carrying the gRPC call, as observed by the capturing handler.</summary>
            public HttpRequestMessage Request { get; }

            private CaptureToken Capture { get; }

            private Task<TResponse> Task { get; }

            /// <summary>Decodes the gRPC request message from <see cref="Request"/>.</summary>
            public async Task<TRequestEnvelope> GetRequestEnvelopeAsync<TRequestEnvelope>()
                where TRequestEnvelope : IMessage<TRequestEnvelope>, new()
            {
                return await GrpcUtils.GetRequestFromRequestMessageAsync<TRequestEnvelope>(this.Request);
            }

            /// <summary>
            /// Completes the call with a 200 OK response carrying <paramref name="value"/> as the gRPC
            /// payload, and returns the operation's result.
            /// </summary>
            public async Task<TResponse> CompleteWithMessageAsync<TResponseEnvelope>(TResponseEnvelope value)
                where TResponseEnvelope : IMessage<TResponseEnvelope>
            {
                var content = await GrpcUtils.CreateResponseContent<TResponseEnvelope>(value);
                var response = GrpcUtils.CreateResponse(HttpStatusCode.OK, content);
                return await CompleteAsync(response);
            }

            /// <summary>
            /// Supplies <paramref name="response"/> to the pending send, then waits up to 10 seconds for
            /// the originating operation and returns its result.
            /// </summary>
            public async Task<TResponse> CompleteAsync(HttpResponseMessage response)
            {
                this.Capture.Response.SetResult(response);
                return await WithTimeout(this.Task, TimeSpan.FromSeconds(10), "Waiting for response to be completed timed out.");
            }

            /// <summary>
            /// Faults the pending send with <paramref name="ex"/>, then waits up to 10 seconds for the
            /// originating operation; any exception that operation throws propagates to the caller.
            /// </summary>
            public async Task CompleteWithExceptionAsync(Exception ex)
            {
                this.Capture.Response.SetException(ex);
                await WithTimeout(this.Task, TimeSpan.FromSeconds(10), "Waiting for response to be completed timed out.");
            }

            /// <summary>
            /// Marks the request as intentionally left incomplete so disposal does not report it.
            /// </summary>
            public void Dismiss()
            {
                this.Capture.IsDismissed = true;
            }
        }

        /// <summary>
        /// An <see cref="HttpMessageHandler"/> that never touches the network. Each request is handed to
        /// the currently armed <see cref="CaptureToken"/> (or a fresh one when none is armed), and the send
        /// completes only when that token's response is set.
        /// </summary>
        public class CapturingHandler : HttpMessageHandler
        {
            private readonly ConcurrentQueue<CaptureToken> requests = new ConcurrentQueue<CaptureToken>();
            private readonly object @lock = new object();
            private CaptureToken? current;

            /// <summary>
            /// Arms a new capture token that will receive the next request sent through this handler.
            /// </summary>
            /// <exception cref="InvalidOperationException">A capture is already armed.</exception>
            public CaptureToken Capture()
            {
                lock (this.@lock)
                {
                    if (this.current is CaptureToken)
                    {
                        throw new InvalidOperationException(
                            "Capture operation started while already capturing. " +
                            "Concurrent use of the test client is not supported.");
                    }

                    return (this.current = new CaptureToken());
                }
            }

            /// <summary>
            /// Disarms <paramref name="capture"/> if it is still waiting for a request. This has no effect
            /// if a request already claimed it or a different capture is armed.
            /// </summary>
            public void ReleaseCapture(CaptureToken capture)
            {
                lock (this.@lock)
                {
                    if (ReferenceEquals(this.current, capture))
                    {
                        this.current = default;
                    }
                }
            }

            /// <summary>
            /// Lazily enumerates every request seen by this handler that is not yet complete
            /// (see <see cref="CaptureToken.IsComplete"/>).
            /// </summary>
            public IEnumerable<CaptureToken> GetOutstandingRequests()
            {
                return this.requests.Where(request => !request.IsComplete);
            }

            protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
            {
                CaptureToken? capture;
                lock (this.@lock)
                {
                    // Claim the armed token (if any) so that the next Capture() call can arm a new one.
                    capture = this.current;
                    this.current = default;
                }

                capture ??= new CaptureToken();
                this.requests.Enqueue(capture);
                capture.Request.SetResult(request);

                // The send stays pending until the test supplies a response or an exception.
                return capture.Response.Task;
            }
        }

        /// <summary>
        /// Pairs a captured request with the response the test will eventually provide for it.
        /// </summary>
        public class CaptureToken
        {
            /// <summary>Completed by the handler with the request message that was sent.</summary>
            public TaskCompletionSource<HttpRequestMessage> Request { get; } = new TaskCompletionSource<HttpRequestMessage>();

            /// <summary>Completed by the test; its task is what the handler's send returns.</summary>
            public TaskCompletionSource<HttpResponseMessage> Response { get; } = new TaskCompletionSource<HttpResponseMessage>();

            /// <summary>Set when the test deliberately chose not to complete this request.</summary>
            public bool IsDismissed { get; set; }

            /// <summary>
            /// True when the request was dismissed, was never successfully captured, or has a response.
            /// </summary>
            public bool IsComplete
            {
                get
                {
                    return IsDismissed ||
                        // We assume that whoever started the work observed any exception thrown while making the request.
                        !Request.Task.IsCompletedSuccessfully ||
                        Response.Task.IsCompleted;
                }
            }

            /// <summary>
            /// Waits up to <paramref name="timeout"/> for the handler to receive a request for this token.
            /// </summary>
            /// <exception cref="TimeoutException">No request arrived in time.</exception>
            public Task<HttpRequestMessage> GetRequestAsync(TimeSpan timeout)
            {
                return WithTimeout(Request.Task, timeout, "Waiting for request to be queued timed out.");
            }
        }
    }

    /// <summary>
    /// Wraps a client of type <typeparamref name="TClient"/> whose requests flow through a
    /// <see cref="TestClient.CapturingHandler"/>. One request may be captured at a time.
    /// </summary>
    public class TestClient<TClient> : TestClient, IAsyncDisposable
    {
        public TestClient(TClient innerClient, CapturingHandler handler)
        {
            this.InnerClient = innerClient;
            this.Handler = handler;
        }

        /// <summary>The client under test.</summary>
        public TClient InnerClient { get; }

        private CapturingHandler Handler { get; }

        /// <summary>Starts <paramref name="operation"/> and captures the HTTP request it sends.</summary>
        public async Task<TestHttpRequest> CaptureHttpRequestAsync(Func<TClient, Task> operation)
        {
            var (request, capture, task) = await CaptureHttpRequestMessageAsync(operation);
            return new TestHttpRequest(request, capture, task);
        }

        /// <summary>Starts <paramref name="operation"/> and captures the HTTP request it sends.</summary>
        public async Task<TestHttpRequest<T>> CaptureHttpRequestAsync<T>(Func<TClient, Task<T>> operation)
        {
            var (request, capture, task) = await CaptureHttpRequestMessageAsync(operation);
            return new TestHttpRequest<T>(request, capture, (Task<T>)task);
        }

        /// <summary>Starts <paramref name="operation"/> and captures the gRPC request it sends.</summary>
        public async Task<TestGrpcRequest> CaptureGrpcRequestAsync(Func<TClient, Task> operation)
        {
            var (request, capture, task) = await CaptureHttpRequestMessageAsync(operation);
            return new TestGrpcRequest(request, capture, task);
        }

        /// <summary>Starts <paramref name="operation"/> and captures the gRPC request it sends.</summary>
        public async Task<TestGrpcRequest<TResponse>> CaptureGrpcRequestAsync<TResponse>(Func<TClient, Task<TResponse>> operation)
        {
            var (request, capture, task) = await CaptureHttpRequestMessageAsync(operation);
            return new TestGrpcRequest<TResponse>(request, capture, (Task<TResponse>)task);
        }

        /// <summary>
        /// Arms a capture, starts <paramref name="operation"/>, and waits for its request to reach the
        /// handler. If the operation fails before sending, its exception is rethrown. On any failure the
        /// armed capture is released, so the client remains usable for later captures.
        /// </summary>
        private async Task<(HttpRequestMessage, CaptureToken, Task)> CaptureHttpRequestMessageAsync(Func<TClient, Task> operation)
        {
            var capture = this.Handler.Capture();
            try
            {
                var task = operation(this.InnerClient);
                if (task.IsFaulted || task.IsCanceled)
                {
                    // If the operation already failed, we want to bubble that up eagerly.
                    await task;
                }

                HttpRequestMessage request;
                try
                {
                    // Apply a 10 second timeout to waiting for the request to be captured. This is a very
                    // generous timeout so if we hit it then it's likely a bug.
                    request = await capture.GetRequestAsync(TimeSpan.FromSeconds(10));
                }
                catch (TimeoutException) when (task.IsFaulted || task.IsCanceled)
                {
                    // If the original operation failed, report that instead of the timeout.
                    await task;
                    throw; // unreachable
                }

                return (request, capture, task);
            }
            catch
            {
                // No request claimed the token; disarm it so the next Capture() does not report
                // spurious concurrent use.
                this.Handler.ReleaseCapture(capture);
                throw;
            }
        }

        /// <summary>
        /// Disposes the inner client when it is <see cref="IDisposable"/>, then throws if any captured
        /// request was neither completed nor dismissed.
        /// </summary>
        /// <exception cref="InvalidOperationException">One or more requests are still outstanding.</exception>
        public ValueTask DisposeAsync()
        {
            if (this.InnerClient is IDisposable disposable)
            {
                disposable.Dispose();
            }

            if (this.Handler.GetOutstandingRequests().Any())
            {
                throw new InvalidOperationException(
                    "The client has 1 or more incomplete requests. " +
                    "Use 'request.Dismiss()' if the test is uninterested in the response.");
            }

            return default;
        }
    }
}