// spire-plugin-sdk/plugintest/serve.go
//
// 192 lines
// 5.4 KiB
// Go

package plugintest
import (
"context"
"errors"
"sort"
"sync"
"testing"
"time"
"github.com/hashicorp/go-hclog"
goplugin "github.com/hashicorp/go-plugin"
"github.com/spiffe/spire-plugin-sdk/internal"
"github.com/spiffe/spire-plugin-sdk/pluginsdk"
"github.com/spiffe/spire-plugin-sdk/private"
"google.golang.org/grpc"
)
// Config is the test configuration for the plugin. It defines which plugin
// services and host services are wired up by ServeInBackground.
type Config struct {
// PluginServer is the plugin server implemented by the plugin. This field
// is required; ServeInBackground fails the test when it is nil.
PluginServer pluginsdk.PluginServer
// PluginClient is the plugin client that the test needs to have
// initialized to facilitate testing the plugin implementation. This field
// is required; ServeInBackground fails the test when it is nil.
PluginClient pluginsdk.PluginClient
// ServiceServers are the service servers implemented by the plugin. This
// field is optional.
ServiceServers []pluginsdk.ServiceServer
// ServiceClients are the service clients that the test needs to have
// initialized to facilitate testing service implementations on the plugin.
// ServeInBackground fails the test if the plugin does not provide the
// service for one of these clients. This field is optional.
ServiceClients []pluginsdk.ServiceClient
// HostServiceServers are the host services that the test wants to offer to
// the plugin. This field is optional.
HostServiceServers []pluginsdk.ServiceServer
// Logger is a logger for capturing log events from the plugin. This field
// is optional; ServeInBackground substitutes a null logger when it is nil.
Logger hclog.Logger
}
// ServeInBackground serves the plugin in the background with the provided
// configuration and initializes config.PluginClient, along with every client
// in config.ServiceClients, against the running plugin. All teardown is
// registered with t.Cleanup, so the plugin is unloaded when the test is over.
//
// The test is failed immediately when config.PluginServer or
// config.PluginClient is nil, when the plugin does not launch within ten
// seconds, when the plugin cannot be connected to or initialized, or when the
// plugin does not provide a service that one of the clients expects.
func ServeInBackground(t *testing.T, config Config) {
	t.Helper()

	// Registered first so that it runs last: cleanup functions are called in
	// last-added, first-called order, so this wait only happens after
	// everything registered below has been torn down.
	wg := new(sync.WaitGroup)
	t.Cleanup(wg.Wait)

	switch {
	case config.PluginServer == nil:
		t.Fatal("PluginServer is required")
	case config.PluginClient == nil:
		t.Fatal("PluginClient is required")
	}

	ctx, cancel := context.WithCancel(context.Background())
	closeCh := make(chan struct{})
	t.Cleanup(func() {
		// Stop serving, then block until closeCh is signaled.
		// NOTE(review): closeCh is handed to the server as
		// ServeTestConfig.CloseCh; presumably it is closed once serving has
		// stopped -- confirm against go-plugin.
		cancel()
		<-closeCh
	})

	serverLogger := config.Logger
	if serverLogger == nil {
		serverLogger = hclog.NewNullLogger()
	}

	// One slot of buffering so the serving goroutine is not left blocked on
	// delivering the reattach config if the wait below has already timed out.
	// NOTE(review): assumes the server side performs a plain channel send --
	// confirm against go-plugin.
	reattachConfigCh := make(chan *goplugin.ReattachConfig, 1)

	wg.Add(1)
	go func() {
		defer wg.Done()
		internal.Serve(serverLogger, hclog.NewNullLogger(), config.PluginServer, config.ServiceServers, &goplugin.ServeTestConfig{
			Context:          ctx,
			CloseCh:          closeCh,
			ReattachConfigCh: reattachConfigCh,
		})
	}()

	var reattachConfig *goplugin.ReattachConfig
	select {
	case reattachConfig = <-reattachConfigCh:
	case <-time.After(time.Second * 10):
		t.Fatal("timed out waiting for plugin to launch")
	}

	hcClient := &hcTestClient{wg: wg, hostServiceServers: config.HostServiceServers}
	client := goplugin.NewClient(&goplugin.ClientConfig{
		HandshakeConfig: internal.ClientHandshakeConfig(config.PluginClient),
		Plugins: map[string]goplugin.Plugin{
			"TEST": hcClient,
		},
		Logger:           hclog.NewNullLogger(),
		AllowedProtocols: []goplugin.Protocol{goplugin.ProtocolGRPC},
		Reattach:         reattachConfig,
	})
	t.Cleanup(client.Kill)

	grpcClient, err := client.Client()
	if err != nil {
		t.Fatalf("failed to open gRPC client to plugin: %v", err)
	}
	t.Cleanup(func() { grpcClient.Close() })

	rawConn, err := grpcClient.Dispense("TEST")
	if err != nil {
		t.Fatalf("failed to dispense plugin client: %v", err)
	}

	conn, ok := rawConn.(*grpc.ClientConn)
	if !ok {
		// This is purely defensive; the hcTestClient GRPCClient method
		// returns a *grpc.ClientConn.
		t.Fatalf("expected %T; got %T", conn, rawConn)
	}
	t.Cleanup(func() { conn.Close() })

	grpcServiceNames, err := private.Init(ctx, conn, serverGRPCServiceNames(config.HostServiceServers))
	if err != nil {
		t.Fatalf("failed to initialize plugin: %v", err)
	}
	t.Cleanup(func() {
		// ctx is still live here: this cleanup was registered after the one
		// that cancels ctx and therefore runs before it.
		if err := private.Deinit(ctx, conn); err != nil {
			t.Fatalf("failed to deinitialize plugin: %v", err)
		}
	})

	assertInitClient(t, conn, config.PluginClient, grpcServiceNames)
	for _, serviceClient := range config.ServiceClients {
		assertInitClient(t, conn, serviceClient, grpcServiceNames)
	}
}
// assertInitClient initializes client with conn when the client's gRPC
// service name appears in grpcServiceNames. When no entry matches, the test
// is failed and the client is left uninitialized.
func assertInitClient(t *testing.T, conn grpc.ClientConnInterface, client pluginsdk.ServiceClient, grpcServiceNames []string) {
	// Attribute a failure to the caller's line rather than to this helper.
	t.Helper()

	for _, grpcServiceName := range grpcServiceNames {
		if client.GRPCServiceName() == grpcServiceName {
			client.InitClient(conn)
			return
		}
	}
	t.Fatalf("%s was not provided by the plugin as expected", client.GRPCServiceName())
}
// hcTestClient is the go-plugin plugin that ServeInBackground registers under
// the "TEST" name on the client (host) side of the connection. Dispensing it
// yields the raw gRPC connection to the plugin and starts serving the
// configured host services back to the plugin.
type hcTestClient struct {
// NOTE(review): presumably supplies the net/rpc side of the go-plugin
// interface as unsupported, leaving only the gRPC methods to implement --
// confirm against go-plugin.
goplugin.NetRPCUnsupportedPlugin
// wg tracks the goroutine started by GRPCClient so the owner can wait for
// it to finish.
wg *sync.WaitGroup
// hostServiceServers are registered on the gRPC server that GRPCClient
// serves through the broker.
hostServiceServers []pluginsdk.ServiceServer
}
// Compile-time check that *hcTestClient implements goplugin.GRPCPlugin.
var _ goplugin.GRPCPlugin = (*hcTestClient)(nil)
// GRPCServer always returns an error and registers nothing on s:
// hcTestClient is only used on the host side of the connection, where
// serving the plugin is not implemented.
func (p *hcTestClient) GRPCServer(b *goplugin.GRPCBroker, s *grpc.Server) error {
return errors.New("not implemented host side")
}
// GRPCClient returns the plugin connection c unchanged as the dispensed
// value. Before returning it starts a goroutine, tracked by p.wg, that uses
// the broker to accept and serve a gRPC server carrying every configured host
// service under private.HostServiceProviderID.
func (p *hcTestClient) GRPCClient(ctx context.Context, b *goplugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
	// Builds the gRPC server handed to the broker, with all host services
	// registered on it.
	newHostServer := func(opts []grpc.ServerOption) *grpc.Server {
		server := grpc.NewServer(opts...)
		for i := range p.hostServiceServers {
			p.hostServiceServers[i].RegisterServer(server)
		}
		return server
	}

	p.wg.Add(1)
	go func() {
		defer p.wg.Done()
		b.AcceptAndServe(private.HostServiceProviderID, newHostServer)
	}()

	return c, nil
}
// serverGRPCServiceNames returns the gRPC service names of the given servers
// in ascending string order. The result is nil when servers is empty.
func serverGRPCServiceNames(servers []pluginsdk.ServiceServer) []string {
	if len(servers) == 0 {
		return nil
	}

	names := make([]string, len(servers))
	for i, server := range servers {
		names[i] = server.GRPCServiceName()
	}
	sort.Strings(names)
	return names
}