This commit is contained in:
Whit Waldo 2025-08-02 02:55:38 +02:00 committed by GitHub
commit 01b6f6cac2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 370 additions and 0 deletions

View File

@ -0,0 +1,141 @@
// ------------------------------------------------------------------------
// Copyright 2023 The Dapr Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ------------------------------------------------------------------------
using System.Collections.Immutable;
using System.Text;
using Dapr.Actors.Generators.Extensions;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
namespace Dapr.Actors.Generators;
/// <summary>
/// Generates an extension method that can be used during dependency injection to register all actor types.
/// </summary>
[Generator]
public sealed class ActorRegistrationGenerator : IIncrementalGenerator
{
private const string DaprActorType = "Dapr.Actors.Runtime.Actor";
/// <summary>
/// Initializes the generator and registers the syntax receiver.
/// </summary>
/// <param name="context">The <see cref="T:Microsoft.CodeAnalysis.IncrementalGeneratorInitializationContext" /> to register callbacks on</param>
public void Initialize(IncrementalGeneratorInitializationContext context)
{
var classDeclarations = context.SyntaxProvider
.CreateSyntaxProvider(
predicate: static (s, _) => IsClassDeclaration(s),
transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx))
.Where(static m => m is not null);
var compilationAndClasses = context.CompilationProvider.Combine(classDeclarations.Collect());
context.RegisterSourceOutput(compilationAndClasses, static (spc, source) => Execute(source.Left, source.Right, spc));
}
private static bool IsClassDeclaration(SyntaxNode node) => node is ClassDeclarationSyntax;
private static INamedTypeSymbol? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
{
var classDeclaration = (ClassDeclarationSyntax)context.Node;
var model = context.SemanticModel;
if (model.GetDeclaredSymbol(classDeclaration) is not INamedTypeSymbol classSymbol)
{
return null;
}
var actorClass = context.SemanticModel.Compilation.GetTypeByMetadataName(DaprActorType);
return classSymbol.BaseType != null && classSymbol.BaseType.Equals(actorClass, SymbolEqualityComparer.Default) ? classSymbol : null;
}
private static void Execute(Compilation compilation, ImmutableArray<INamedTypeSymbol?> actorTypes,
SourceProductionContext context)
{
var validActorTypes = actorTypes.Where(static t => t is not null).Cast<INamedTypeSymbol>().ToList();
var source = GenerateActorRegistrationSource(compilation, validActorTypes);
context.AddSource("ActorRegistrationExtensions.g.cs", SourceText.From(source, Encoding.UTF8));
}
/// <summary>
/// Generates the source code for the actor registration method.
/// </summary>
/// <param name="compilation">The current compilation context.</param>
/// <param name="actorTypes">The list of actor types to register.</param>
/// <returns>The generated source code as a string.</returns>
private static string GenerateActorRegistrationSource(Compilation compilation, IReadOnlyList<INamedTypeSymbol> actorTypes)
{
#pragma warning disable RS1035
var registrations = string.Join(Environment.NewLine,
#pragma warning restore RS1035
actorTypes.Select(t => $"options.Actors.RegisterActor<{t.ToDisplayString()}>();"));
return $@"
using Microsoft.Extensions.DependencyInjection;
using Dapr.Actors.AspNetCore;
using Dapr.Actors.Runtime;
using Dapr.Actors;
using Dapr.Actors.AspNetCore;
/// <summary>
/// Extension methods for registering Dapr actors.
/// </summary>
public static class ActorRegistrationExtensions
{{
/// <summary>
/// Registers all discovered actor types with the Dapr actor runtime.
/// </summary>
/// <param name=""services"">The service collection to add the actors to.</param>
/// <param name=""includeTransientReferences"">Whether to include actor types from referenced assemblies.</param>
public static void RegisterAllActors(this IServiceCollection services, bool includeTransientReferences = false)
{{
services.AddActors(options =>
{{
{registrations}
if (includeTransientReferences)
{{
{GenerateTransientActorRegistrations(compilation)}
}}
}});
}}
}}";
}
/// <summary>
/// Generates the registration code for actor types in referenced assemblies.
/// </summary>
/// <param name="compilation">The current compilation context.</param>
/// <returns>The generated registration code as a string.</returns>
private static string GenerateTransientActorRegistrations(Compilation compilation)
{
var actorRegistrations = new List<string>();
foreach (var reference in compilation.References)
{
if (compilation.GetAssemblyOrModuleSymbol(reference) is IAssemblySymbol referencedCompilation)
{
actorRegistrations.AddRange(from type in referencedCompilation.GlobalNamespace.GetNamespaceTypes()
where type.BaseType?.ToDisplayString() == DaprActorType
select $"options.Actors.RegisterActor<{type.ToDisplayString()}>();");
}
}
#pragma warning disable RS1035
return string.Join(Environment.NewLine, actorRegistrations);
#pragma warning restore RS1035
}
}

View File

@ -0,0 +1,33 @@
using Microsoft.CodeAnalysis;
namespace Dapr.Actors.Generators.Extensions;
internal static class INamespaceSymbolExtensions
{
/// <summary>
/// Recursively gets all the types in a namespace.
/// </summary>
/// <param name="namespaceSymbol">The namespace symbol to search.</param>
/// <returns>A collection of the named type symbols.</returns>
public static IEnumerable<INamedTypeSymbol> GetNamespaceTypes(this INamespaceSymbol namespaceSymbol)
{
foreach (var member in namespaceSymbol.GetMembers())
{
switch (member)
{
case INamespaceSymbol nestedNamespace:
{
foreach (var nestedType in nestedNamespace.GetNamespaceTypes())
{
yield return nestedType;
}
break;
}
case INamedTypeSymbol namedType:
yield return namedType;
break;
}
}
}
}

View File

@ -0,0 +1,156 @@
using System.Text;
using Dapr.Actors.Runtime;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Text;
using Microsoft.Extensions.DependencyInjection;
namespace Dapr.Actors.Generators.Test;
public class ActorRegistrationGeneratorTests
{
[Fact]
public void TestActorRegistrationGenerator_WithoutTransientReference()
{
const string source = @"
using Dapr.Actors.Runtime;
public class MyActor : Actor, IMyActor
{
public MyActor(ActorHost host) : base(host) { }
}
public interface IMyActor : IActor
{
}
";
const string expectedGeneratedCode = @"
using Microsoft.Extensions.DependencyInjection;
using Dapr.Actors.Runtime;
/// <summary>
/// Extension methods for registering Dapr actors.
/// </summary>
public static class ActorRegistrationExtensions
{
/// <summary>
/// Registers all discovered actor types with the Dapr actor runtime.
/// </summary>
/// <param name=""services"">The service collection to add the actors to.</param>
/// <param name=""includeTransientReferences"">Whether to include actor types from referenced assemblies.</param>
public static void RegisterAllActors(this IServiceCollection services, bool includeTransientReferences = false)
{
services.AddActors(options =>
{
options.Actors.RegisterActor<MyActor>();
if (includeTransientReferences)
{
}
});
}
}";
var generatedCode = GetGeneratedCode(source);
Assert.Equal(expectedGeneratedCode.Trim(), generatedCode.Trim());
}
[Fact]
public void TestActorRegistrationGenerator_WithTransientReference()
{
const string source = @"
using Dapr.Actors.Runtime;
public class MyActor : Actor, IMyActor
{
public MyActor(ActorHost host) : base(host) { }
}
public interface IMyActor : IActor
{
}
";
const string referencedSource = @"
using Dapr.Actors.Runtime;
public class TransientActor : Actor, ITransientActor
{
public TransientActor(ActorHost host) : base(host) { }
}
public interface ITransientActor : IActor
{
}
";
const string expectedGeneratedCode = @"
using Microsoft.Extensions.DependencyInjection;
using Dapr.Actors.Runtime;
/// <summary>
/// Extension methods for registering Dapr actors.
/// </summary>
public static class ActorRegistrationExtensions
{
/// <summary>
/// Registers all discovered actor types with the Dapr actor runtime.
/// </summary>
/// <param name=""services"">The service collection to add the actors to.</param>
/// <param name=""includeTransientReferences"">Whether to include actor types from referenced assemblies.</param>
public static void RegisterAllActors(this IServiceCollection services, bool includeTransientReferences = false)
{
services.AddActors(options =>
{
options.Actors.RegisterActor<MyActor>();
if (includeTransientReferences)
{
options.Actors.RegisterActor<TransientActor>();
}
});
}
}";
var generatedCode = GetGeneratedCode(source, referencedSource);
Assert.Equal(expectedGeneratedCode.Trim(), generatedCode.Trim());
}
private static string GetGeneratedCode(string source, string? referencedSource = null)
{
var syntaxTree = CSharpSyntaxTree.ParseText(SourceText.From(source, Encoding.UTF8));
var references = new List<MetadataReference>
{
MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location),
MetadataReference.CreateFromFile(typeof(Actor).Assembly.Location),
MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location)
};
var compilation = CSharpCompilation.Create("TestCompilation",
new[] { syntaxTree },
references,
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
if (referencedSource != null)
{
var referencedSyntaxTree = CSharpSyntaxTree.ParseText(SourceText.From(referencedSource, Encoding.UTF8));
var referencedCompilation = CSharpCompilation.Create("ReferencedCompilation",
new[] { referencedSyntaxTree },
references,
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
compilation = compilation.AddReferences(referencedCompilation.ToMetadataReference());
}
var generator = new ActorRegistrationGenerator();
var driver = CSharpGeneratorDriver.Create(generator);
driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out var diagnostics);
var generatedTrees = outputCompilation.SyntaxTrees.Skip(1).ToList();
Assert.Single(generatedTrees);
var generatedCode = generatedTrees[0].ToString();
return generatedCode;
}
}

View File

@ -0,0 +1,40 @@
using Dapr.Actors.Generators.Extensions;
using Microsoft.CodeAnalysis.CSharp;
namespace Dapr.Actors.Generators.Test.Extensions;
public class INamespaceExtensionsTests
{
[Fact]
public void GetNamespaceTypes_ReturnsAllTypesInNamespace()
{
// Arrange
const string source = @"
namespace TestNamespace
{
public class ClassA { }
public class ClassB { }
namespace NestedNamespace
{
public class ClassC { }
}
}";
var syntaxTree = CSharpSyntaxTree.ParseText(source);
var compilation = CSharpCompilation.Create("TestCompilation", new[] { syntaxTree });
var namespaceSymbol = compilation.GlobalNamespace.GetNamespaceMembers().FirstOrDefault(n => n.Name == "TestNamespace");
// Act
if (namespaceSymbol != null)
{
var types = namespaceSymbol.GetNamespaceTypes().ToList();
// Assert
Assert.NotNull(namespaceSymbol);
Assert.Equal(3, types.Count);
Assert.Contains(types, t => t.Name == "ClassA");
Assert.Contains(types, t => t.Name == "ClassB");
Assert.Contains(types, t => t.Name == "ClassC");
}
}
}