Conversation builder consistency changes (#1423)

* Corrected several unit tests

Signed-off-by: Whit Waldo <whit.waldo@innovian.net>

* Updated extension name for consistency

Signed-off-by: Whit Waldo <whit.waldo@innovian.net>

* Updated registration name for consistency

Signed-off-by: Whit Waldo <whit.waldo@innovian.net>

---------

Signed-off-by: Whit Waldo <whit.waldo@innovian.net>
This commit is contained in:
Whit Waldo 2024-12-11 13:42:23 -06:00 committed by GitHub
parent 8bc031887e
commit ccf2bfdce3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 121 additions and 14 deletions

View File

@ -3,7 +3,7 @@ using Dapr.AI.Conversation.Extensions;
var builder = WebApplication.CreateBuilder(args);
builder.Services.AddDaprAiConversation();
builder.Services.AddDaprConversationClient();
var app = builder.Build();

View File

@ -26,7 +26,7 @@ public static class DaprAiConversationBuilderExtensions
/// Registers the necessary functionality for the Dapr AI conversation functionality.
/// </summary>
/// <returns></returns>
public static IDaprAiConversationBuilder AddDaprAiConversation(this IServiceCollection services, Action<IServiceProvider, DaprConversationClientBuilder>? configure = null, ServiceLifetime lifetime = ServiceLifetime.Singleton)
public static IDaprAiConversationBuilder AddDaprConversationClient(this IServiceCollection services, Action<IServiceProvider, DaprConversationClientBuilder>? configure = null, ServiceLifetime lifetime = ServiceLifetime.Singleton)
{
ArgumentNullException.ThrowIfNull(services, nameof(services));

View File

@ -13,7 +13,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using Dapr.AI.Conversation;
using Dapr.AI.Conversation.Extensions;
using Microsoft.Extensions.Configuration;
@ -34,7 +36,7 @@ public class DaprAiConversationBuilderExtensionsTest
var services = new ServiceCollection();
services.AddSingleton<IConfiguration>(configuration);
services.AddDaprAiConversation();
services.AddDaprConversationClient();
var app = services.BuildServiceProvider();
@ -45,18 +47,66 @@ public class DaprAiConversationBuilderExtensionsTest
}
[Fact]
public void AddDaprAiConversation_WithoutConfigure_ShouldAddServices()
public void AddDaprConversationClient_RegistersDaprClientOnlyOnce()
{
var services = new ServiceCollection();
var builder = services.AddDaprAiConversation();
var clientBuilder = new Action<IServiceProvider, DaprConversationClientBuilder>((sp, builder) =>
{
builder.UseDaprApiToken("abc");
});
services.AddDaprConversationClient(); //Sets a default API token value of an empty string
services.AddDaprConversationClient(clientBuilder); //Sets the API token value
var serviceProvider = services.BuildServiceProvider();
var daprConversationClient = serviceProvider.GetService<DaprConversationClient>();
Assert.NotNull(daprConversationClient!.HttpClient);
Assert.False(daprConversationClient.HttpClient.DefaultRequestHeaders.TryGetValues("dapr-api-token", out var _));
}
[Fact]
public void AddDaprConversationClient_RegistersUsingDependencyFromIServiceProvider()
{
var services = new ServiceCollection();
services.AddSingleton<TestSecretRetriever>();
services.AddDaprConversationClient((provider, builder) =>
{
var configProvider = provider.GetRequiredService<TestSecretRetriever>();
var apiToken = configProvider.GetApiTokenValue();
builder.UseDaprApiToken(apiToken);
});
var serviceProvider = services.BuildServiceProvider();
var client = serviceProvider.GetRequiredService<DaprConversationClient>();
//Validate it's set on the GrpcClient - note that it doesn't get set on the HttpClient
Assert.NotNull(client);
Assert.NotNull(client.DaprApiToken);
Assert.Equal("abcdef", client.DaprApiToken);
Assert.NotNull(client.HttpClient);
if (!client.HttpClient.DefaultRequestHeaders.TryGetValues("dapr-api-token", out var daprApiToken))
{
Assert.Fail();
}
Assert.Equal("abcdef", daprApiToken.FirstOrDefault());
}
[Fact]
public void AddDaprConversationClient_WithoutConfigure_ShouldAddServices()
{
var services = new ServiceCollection();
var builder = services.AddDaprConversationClient();
Assert.NotNull(builder);
}
[Fact]
public void AddDaprAiConversation_RegistersIHttpClientFactory()
public void AddDaprConversationClient_RegistersIHttpClientFactory()
{
var services = new ServiceCollection();
services.AddDaprAiConversation();
services.AddDaprConversationClient();
var serviceProvider = services.BuildServiceProvider();
var httpClientFactory = serviceProvider.GetService<IHttpClientFactory>();
@ -67,9 +117,66 @@ public class DaprAiConversationBuilderExtensionsTest
}
[Fact]
public void AddDaprAiConversation_NullServices_ShouldThrowException()
public void AddDaprConversationClient_NullServices_ShouldThrowException()
{
IServiceCollection services = null;
Assert.Throws<ArgumentNullException>(() => services.AddDaprAiConversation());
Assert.Throws<ArgumentNullException>(() => services.AddDaprConversationClient());
}
[Fact]
public void AddDaprConversationClient_ShouldRegisterSingleton_WhenLifetimeIsSingleton()
{
var services = new ServiceCollection();
services.AddDaprConversationClient((_, _) => { }, ServiceLifetime.Singleton);
var serviceProvider = services.BuildServiceProvider();
var daprConversationClient1 = serviceProvider.GetService<DaprConversationClient>();
var daprConversationClient2 = serviceProvider.GetService<DaprConversationClient>();
Assert.NotNull(daprConversationClient1);
Assert.NotNull(daprConversationClient2);
Assert.Same(daprConversationClient1, daprConversationClient2);
}
[Fact]
public async Task AddDaprConversationClient_ShouldRegisterScoped_WhenLifetimeIsScoped()
{
var services = new ServiceCollection();
services.AddDaprConversationClient((_, _) => { }, ServiceLifetime.Scoped);
var serviceProvider = services.BuildServiceProvider();
await using var scope1 = serviceProvider.CreateAsyncScope();
var daprConversationClient1 = scope1.ServiceProvider.GetService<DaprConversationClient>();
await using var scope2 = serviceProvider.CreateAsyncScope();
var daprConversationClient2 = scope2.ServiceProvider.GetService<DaprConversationClient>();
Assert.NotNull(daprConversationClient1);
Assert.NotNull(daprConversationClient2);
Assert.NotSame(daprConversationClient1, daprConversationClient2);
}
[Fact]
public void AddDaprConversationClient_ShouldRegisterTransient_WhenLifetimeIsTransient()
{
var services = new ServiceCollection();
services.AddDaprConversationClient((_, _) => { }, ServiceLifetime.Transient);
var serviceProvider = services.BuildServiceProvider();
var daprConversationClient1 = serviceProvider.GetService<DaprConversationClient>();
var daprConversationClient2 = serviceProvider.GetService<DaprConversationClient>();
Assert.NotNull(daprConversationClient1);
Assert.NotNull(daprConversationClient2);
Assert.NotSame(daprConversationClient1, daprConversationClient2);
}
private class TestSecretRetriever
{
public string GetApiTokenValue() => "abcdef";
}
}

View File

@ -89,7 +89,7 @@ public class DaprJobsServiceCollectionExtensionsTest
services.AddDaprJobsClient((provider, builder) =>
{
var configProvider = provider.GetRequiredService<TestSecretRetriever>();
var apiToken = TestSecretRetriever.GetApiTokenValue();
var apiToken = configProvider.GetApiTokenValue();
builder.UseDaprApiToken(apiToken);
});
@ -114,7 +114,7 @@ public class DaprJobsServiceCollectionExtensionsTest
{
var services = new ServiceCollection();
services.AddDaprJobsClient((serviceProvider, options) => { }, ServiceLifetime.Singleton);
services.AddDaprJobsClient((_, _) => { }, ServiceLifetime.Singleton);
var serviceProvider = services.BuildServiceProvider();
var daprJobsClient1 = serviceProvider.GetService<DaprJobsClient>();
@ -131,7 +131,7 @@ public class DaprJobsServiceCollectionExtensionsTest
{
var services = new ServiceCollection();
services.AddDaprJobsClient((serviceProvider, options) => { }, ServiceLifetime.Scoped);
services.AddDaprJobsClient((_, _) => { }, ServiceLifetime.Scoped);
var serviceProvider = services.BuildServiceProvider();
await using var scope1 = serviceProvider.CreateAsyncScope();
@ -150,7 +150,7 @@ public class DaprJobsServiceCollectionExtensionsTest
{
var services = new ServiceCollection();
services.AddDaprJobsClient((serviceProvider, options) => { }, ServiceLifetime.Transient);
services.AddDaprJobsClient((_, _) => { }, ServiceLifetime.Transient);
var serviceProvider = services.BuildServiceProvider();
var daprJobsClient1 = serviceProvider.GetService<DaprJobsClient>();
@ -163,6 +163,6 @@ public class DaprJobsServiceCollectionExtensionsTest
private class TestSecretRetriever
{
public static string GetApiTokenValue() => "abcdef";
public string GetApiTokenValue() => "abcdef";
}
}