167 lines
4.5 KiB
Go
167 lines
4.5 KiB
Go
// Copyright The OpenTelemetry Authors
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package configgrpc
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
|
|
"go.opentelemetry.io/collector/component"
|
|
"go.opentelemetry.io/collector/config/configmiddleware"
|
|
"go.opentelemetry.io/collector/config/confignet"
|
|
"go.opentelemetry.io/collector/config/configtls"
|
|
"go.opentelemetry.io/collector/extension"
|
|
"go.opentelemetry.io/collector/extension/extensionmiddleware"
|
|
"go.opentelemetry.io/collector/extension/extensionmiddleware/extensionmiddlewaretest"
|
|
)
|
|
|
|
// contextKey is an unexported key type for values this test stores in a
// context.Context. A dedicated type keeps these keys from colliding with
// keys defined by other packages.
type contextKey int

const (
	// middlewareCallsKey is the context key under which the ordered list of
	// middleware names ([]string) is stored.
	middlewareCallsKey contextKey = iota
)

// getMiddlewareCalls returns the middleware names recorded in ctx under
// middlewareCallsKey. When the key is absent, or holds a value that is not a
// []string, it returns a non-nil empty slice.
func getMiddlewareCalls(ctx context.Context) []string {
	if recorded, found := ctx.Value(middlewareCallsKey).([]string); found {
		return recorded
	}
	return []string{}
}
|
|
|
|
// testServerMiddleware is a test implementation of configmiddleware.Middleware
|
|
type testServerMiddleware struct {
|
|
extension.Extension
|
|
extensionmiddleware.GetGRPCServerOptionsFunc
|
|
}
|
|
|
|
func newTestServerMiddleware(name string) extension.Extension {
|
|
return &testServerMiddleware{
|
|
Extension: extensionmiddlewaretest.NewNop(),
|
|
GetGRPCServerOptionsFunc: func() ([]grpc.ServerOption, error) {
|
|
return []grpc.ServerOption{grpc.ChainUnaryInterceptor(
|
|
func(
|
|
ctx context.Context,
|
|
req any, _ *grpc.UnaryServerInfo,
|
|
handler grpc.UnaryHandler,
|
|
) (any, error) {
|
|
ctx = context.WithValue(ctx, middlewareCallsKey, append(getMiddlewareCalls(ctx), name))
|
|
return handler(ctx, req)
|
|
})}, nil
|
|
},
|
|
}
|
|
}
|
|
|
|
func TestGrpcServerUnaryInterceptor(t *testing.T) {
|
|
// Register two test extensions
|
|
host := &mockHost{
|
|
ext: map[component.ID]component.Component{
|
|
component.MustNewID("test1"): newTestServerMiddleware("test1"),
|
|
component.MustNewID("test2"): newTestServerMiddleware("test2"),
|
|
},
|
|
}
|
|
|
|
// Setup the server with both middleware options
|
|
server := &grpcTraceServer{}
|
|
var addr string
|
|
|
|
// Create the server with middleware interceptors
|
|
{
|
|
var srv *grpc.Server
|
|
srv, addr = server.startTestServerWithHost(t, ServerConfig{
|
|
NetAddr: confignet.AddrConfig{
|
|
Endpoint: "localhost:0",
|
|
Transport: confignet.TransportTypeTCP,
|
|
},
|
|
Middlewares: []configmiddleware.Config{
|
|
newTestMiddlewareConfig("test1"),
|
|
newTestMiddlewareConfig("test2"),
|
|
},
|
|
}, host)
|
|
defer srv.Stop()
|
|
}
|
|
|
|
// Send a request to trigger the interceptors
|
|
resp, errResp := sendTestRequest(t, ClientConfig{
|
|
Endpoint: addr,
|
|
TLSSetting: configtls.ClientConfig{
|
|
Insecure: true,
|
|
},
|
|
})
|
|
require.NoError(t, errResp)
|
|
require.NotNil(t, resp)
|
|
|
|
// Verify interceptors were called in the correct order
|
|
assert.Equal(t, []string{"test1", "test2"}, getMiddlewareCalls(server.recordedContext))
|
|
}
|
|
|
|
// TestServerMiddlewareToServerErrors tests failure cases for the ToServer method
|
|
// specifically related to middleware resolution and API calls.
|
|
func TestServerMiddlewareToServerErrors(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
host component.Host
|
|
config ServerConfig
|
|
errText string
|
|
}{
|
|
{
|
|
name: "extension_not_found",
|
|
host: &mockHost{
|
|
ext: map[component.ID]component.Component{},
|
|
},
|
|
config: ServerConfig{
|
|
NetAddr: confignet.AddrConfig{
|
|
Endpoint: "localhost:0",
|
|
Transport: confignet.TransportTypeTCP,
|
|
},
|
|
Middlewares: []configmiddleware.Config{
|
|
{
|
|
ID: component.MustNewID("nonexistent"),
|
|
},
|
|
},
|
|
},
|
|
errText: "failed to resolve middleware \"nonexistent\": middleware not found",
|
|
},
|
|
{
|
|
name: "get_server_options_fails",
|
|
host: &mockHost{
|
|
ext: map[component.ID]component.Component{
|
|
component.MustNewID("errormw"): extensionmiddlewaretest.NewErr(errors.New("get server options failed")),
|
|
},
|
|
},
|
|
config: ServerConfig{
|
|
NetAddr: confignet.AddrConfig{
|
|
Endpoint: "localhost:0",
|
|
Transport: confignet.TransportTypeTCP,
|
|
},
|
|
Middlewares: []configmiddleware.Config{
|
|
{
|
|
ID: component.MustNewID("errormw"),
|
|
},
|
|
},
|
|
},
|
|
errText: "get server options failed",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Test creating the server with middleware errors
|
|
server := &grpcTraceServer{}
|
|
srv, err := server.startTestServerWithHostError(t, tc.config, tc.host)
|
|
if srv != nil {
|
|
srv.Stop()
|
|
}
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), tc.errText)
|
|
})
|
|
}
|
|
}
|