components-contrib/nameresolution/mdns/mdns_test.go

745 lines
17 KiB
Go

/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mdns
import (
"context"
"fmt"
"math"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nr "github.com/dapr/components-contrib/nameresolution"
"github.com/dapr/kit/logger"
)
const (
// localhost is the loopback address the tests in this file use as the
// address of every instance they register.
localhost = "127.0.0.1"
// numConcurrency is the number of goroutines started by the
// ResolverConcurrencyFound and ResolverConcurrencyNotFound test cases.
numConcurrency = 100
)
// TestInitMetadata checks that Init returns an error whenever the instance
// metadata lacks an app ID, an address, or a non-zero Dapr internal port.
func TestInitMetadata(t *testing.T) {
	// testCase pairs the name of the absent property with the instance
	// metadata that omits it.
	type testCase struct {
		missingProp string
		instance    nr.Instance
	}

	cases := []testCase{
		{
			missingProp: "name",
			instance:    nr.Instance{Address: localhost, DaprInternalPort: 30003},
		},
		{
			missingProp: "address",
			instance:    nr.Instance{AppID: "testAppID", DaprInternalPort: 30003},
		},
		{
			// Port left unset.
			missingProp: "port",
			instance:    nr.Instance{AppID: "testAppID", Address: localhost},
		},
		{
			// Port explicitly set to zero.
			missingProp: "port",
			instance:    nr.Instance{AppID: "testAppID", Address: localhost, DaprInternalPort: 0},
		},
	}

	// arrange
	resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
	defer resolver.Close()

	for idx := range cases {
		tc := cases[idx]
		t.Run(tc.missingProp+" is missing", func(t *testing.T) {
			// act
			initErr := resolver.Init(context.Background(), nr.Metadata{Instance: tc.instance})

			// assert
			require.Error(t, initErr)
		})
	}
}
// TestInitRegister checks that Init succeeds when the instance metadata
// carries an app ID, an address and a port.
func TestInitRegister(t *testing.T) {
	// arrange
	resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
	defer resolver.Close()

	instance := nr.Instance{
		AppID:            "testAppID",
		Address:          localhost,
		DaprInternalPort: 1234,
	}

	// act
	initErr := resolver.Init(context.Background(), nr.Metadata{Instance: instance})

	// assert
	require.NoError(t, initErr)
}
// TestInitRegisterDuplicate checks that a second Init with the same app ID
// and port fails with the "already registered" error.
func TestInitRegisterDuplicate(t *testing.T) {
	// arrange
	resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
	defer resolver.Close()

	// Both registrations use identical, independently built metadata.
	newMetadata := func() nr.Metadata {
		return nr.Metadata{Instance: nr.Instance{
			AppID:            "testAppID",
			Address:          localhost,
			DaprInternalPort: 1234,
		}}
	}
	first, second := newMetadata(), newMetadata()

	// act
	require.NoError(t, resolver.Init(context.Background(), first))
	dupErr := resolver.Init(context.Background(), second)

	// assert
	expectedError := "app id testAppID already registered for port 1234"
	require.EqualErrorf(t, dupErr, expectedError, "Error should be: %v, got %v", expectedError, dupErr)
}
// TestResolver registers a single instance through Init and checks that
// ResolveID returns its "address:port" pair.
func TestResolver(t *testing.T) {
	// arrange
	resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
	defer resolver.Close()

	instance := nr.Instance{
		AppID:            "testAppID",
		Address:          localhost,
		DaprInternalPort: 1234,
	}

	// act
	require.NoError(t, resolver.Init(context.Background(), nr.Metadata{Instance: instance}))
	resolved, resolveErr := resolver.ResolveID(context.Background(), nr.ResolveRequest{ID: "testAppID"})

	// assert
	require.NoError(t, resolveErr)
	assert.Equal(t, fmt.Sprintf("%s:%d", localhost, 1234), resolved)
}
// TestResolverClose registers an instance, resolves it once, and then
// checks that Close returns without an error.
func TestResolverClose(t *testing.T) {
	// arrange
	resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
	instance := nr.Instance{
		AppID:            "testAppID",
		Address:          localhost,
		DaprInternalPort: 1234,
	}

	// act
	require.NoError(t, resolver.Init(context.Background(), nr.Metadata{Instance: instance}))
	resolved, resolveErr := resolver.ResolveID(context.Background(), nr.ResolveRequest{ID: "testAppID"})

	// assert
	require.NoError(t, resolveErr)
	assert.Equal(t, fmt.Sprintf("%s:1234", localhost), resolved)

	// act again
	closeErr := resolver.Close()

	// assert
	require.NoError(t, closeErr)
}
// TestResolverMultipleInstances registers two instances under the same app
// ID and checks that repeated resolutions only ever return those two
// instances, each of them more than 45 times out of 100.
func TestResolverMultipleInstances(t *testing.T) {
	// arrange
	resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
	defer resolver.Close()

	const (
		sharedAppID = "testAppID"
		portA       = 1234
		portB       = 5678
		attempts    = 100
	)
	pqdnA := fmt.Sprintf("%s:%d", localhost, portA)
	pqdnB := fmt.Sprintf("%s:%d", localhost, portB)
	candidates := []string{pqdnA, pqdnB}

	// register instance A
	require.NoError(t, resolver.registerMDNS("A", sharedAppID, []string{localhost}, portA))

	// register instance B
	require.NoError(t, resolver.registerMDNS("B", sharedAppID, []string{localhost}, portB))

	go resolver.startRefreshers()

	// act...
	request := nr.ResolveRequest{ID: sharedAppID}

	// first resolution will return the first responder's address and trigger a cache refresh.
	firstAddr, firstErr := resolver.ResolveID(context.Background(), request)
	require.NoError(t, firstErr)
	require.Contains(t, candidates, firstAddr)

	// we want the resolution to be served from the cache so that it can
	// be load balanced rather than received by a subscription. Therefore,
	// we must sleep here long enough to allow the first browser to populate
	// the cache.
	time.Sleep(time.Second)

	// assert that when we resolve the test app id n times, we see only
	// instance A and instance B and we see them each at least m times.
	var hitsA, hitsB atomic.Uint32
	for remaining := attempts; remaining > 0; remaining-- {
		addr, err := resolver.ResolveID(context.Background(), request)
		require.NoError(t, err)
		require.Contains(t, candidates, addr)

		switch addr {
		case pqdnA:
			hitsA.Add(1)
		case pqdnB:
			hitsB.Add(1)
		}
	}

	// 45 allows some variation in distribution.
	const minHits = uint32(45)
	require.Greater(t, hitsA.Load(), minHits)
	require.Greater(t, hitsB.Load(), minHits)
}
// TestResolverNotFound checks that resolving an app ID nobody registered
// yields the "couldn't find service" error and an empty address.
func TestResolverNotFound(t *testing.T) {
	// arrange
	resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
	defer resolver.Close()

	// act
	resolved, resolveErr := resolver.ResolveID(
		context.Background(),
		nr.ResolveRequest{ID: "testAppIDNotFound"},
	)

	// assert
	expectedError := "couldn't find service: testAppIDNotFound"
	require.EqualErrorf(t, resolveErr, expectedError, "Error should be: %v, got %v", expectedError, resolveErr)
	assert.Empty(t, resolved)
}
// TestResolverConcurrency runs the concurrent test cases one after the
// other as subtests, because they rely on a shared mDNS server on the
// host machine.
func TestResolverConcurrency(t *testing.T) {
	type subtest struct {
		name string
		run  func(t *testing.T)
	}

	// The subtests run in the order listed here.
	subtests := []subtest{
		{name: "ResolverConcurrencyNotFound", run: ResolverConcurrencyNotFound},
		{name: "ResolverConcurrencyFound", run: ResolverConcurrencyFound},
		{name: "ResolverConcurrencySubscriberClear", run: ResolverConcurrencySubsriberClear},
	}

	for idx := range subtests {
		t.Run(subtests[idx].name, subtests[idx].run)
	}
}
// ResolverConcurrencySubsriberClear resolves one registered app ID from
// several goroutines at once and then checks that resolver.subs is empty
// after giving the background clear time to run.
//
// WARN: This is deliberately not a test function.
// This test case must be run in serial and is executed
// by the TestResolverConcurrency test function.
func ResolverConcurrencySubsriberClear(t *testing.T) {
	// arrange
	resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
	defer resolver.Close()

	md := nr.Metadata{Instance: nr.Instance{
		AppID:            "testAppID",
		Address:          localhost,
		DaprInternalPort: 1234,
	}}

	// act
	err := resolver.Init(context.Background(), md)
	require.NoError(t, err)

	const resolvers = 10
	request := nr.ResolveRequest{ID: "testAppID"}
	expected := fmt.Sprintf("%s:1234", localhost)

	var wg sync.WaitGroup
	wg.Add(resolvers)
	for i := 0; i < resolvers; i++ {
		go func() {
			defer wg.Done()

			pt, err := resolver.ResolveID(context.Background(), request)

			// Use assert rather than require here: require stops the test
			// via t.FailNow, which may only be called from the goroutine
			// running the test function.
			if assert.NoError(t, err) {
				assert.Equal(t, expected, pt)
			}
		}()
	}
	wg.Wait()

	// Abort from the test goroutine if any resolution failed, instead of
	// sleeping for the clear below.
	if t.Failed() {
		t.FailNow()
	}

	// Wait long enough for the background clear to occur.
	time.Sleep(3 * time.Second)
	require.Empty(t, resolver.subs)
}
// ResolverConcurrencyFound registers three apps and resolves them from
// numConcurrency goroutines, checking that every resolution returns the
// matching "address:port" pair in under three seconds.
//
// WARN: This is deliberately not a test function.
// This test case must be run in serial and is executed
// by the TestResolverConcurrency test function.
func ResolverConcurrencyFound(t *testing.T) {
	// arrange
	resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
	defer resolver.Close()

	apps := []struct {
		instanceID string
		appID      string
		port       int
	}{
		{instanceID: "A", appID: "testAppA", port: 1234},
		{instanceID: "B", appID: "testAppB", port: 5678},
		{instanceID: "C", appID: "testAppC", port: 3456},
	}

	// register instances A, B and C
	for _, app := range apps {
		err := resolver.registerMDNS(app.instanceID, app.appID, []string{localhost}, app.port)
		require.NoError(t, err)
	}

	go resolver.startRefreshers()

	// act...
	var wg sync.WaitGroup
	for i := 0; i < numConcurrency; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()

			// Spread the goroutines evenly over the registered apps.
			app := apps[i%len(apps)]
			expected := fmt.Sprintf("%s:%d", localhost, app.port)
			request := nr.ResolveRequest{ID: app.appID}

			start := time.Now()
			pt, err := resolver.ResolveID(context.Background(), request)
			elapsed := time.Since(start)

			// assert
			// Use assert rather than require here: require stops the test
			// via t.FailNow, which may only be called from the goroutine
			// running the test function.
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, expected, pt)

			// It should take a maximum of 3 seconds to resolve an address.
			assert.Less(t, elapsed, 3*time.Second)
		}(i)
	}
	wg.Wait()
}
// ResolverConcurrencyNotFound resolves three unregistered app IDs from
// numConcurrency goroutines, checking that every resolution fails with the
// "couldn't find service" error and an empty address in under two seconds.
//
// WARN: This is deliberately not a test function.
// This test case must be run in serial and is executed
// by the TestResolverConcurrency test function.
func ResolverConcurrencyNotFound(t *testing.T) {
	// arrange
	resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
	defer resolver.Close()

	appIDs := []string{"testAppA", "testAppB", "testAppC"}

	// act...
	var wg sync.WaitGroup
	for i := 0; i < numConcurrency; i++ {
		wg.Add(1)
		// Spread the goroutines evenly over the app IDs.
		go func(appID string) {
			defer wg.Done()

			request := nr.ResolveRequest{ID: appID}

			// act
			start := time.Now()
			pt, err := resolver.ResolveID(context.Background(), request)
			elapsed := time.Since(start)

			// assert
			// Use assert rather than require here: require stops the test
			// via t.FailNow, which may only be called from the goroutine
			// running the test function.
			expectedError := "couldn't find service: " + appID
			assert.EqualErrorf(t, err, expectedError, "Error should be: %v, got %v", expectedError, err)
			assert.Equal(t, "", pt)
			assert.Less(t, elapsed, 2*time.Second) // browse timeout is 1 second, so we expect an error shortly after.
		}(appIDs[i%len(appIDs)])
	}
	wg.Wait()
}
// TestAddressListExpire checks that expire leaves only one of three
// addresses when two of them have an expiry time in the past.
func TestAddressListExpire(t *testing.T) {
	// arrange
	now := time.Now()
	past := now.Add(-60 * time.Second)
	future := now.Add(60 * time.Second)

	list := &addressList{
		addresses: []address{
			{ip: "expired0", expiresAt: past},
			{ip: "expired1", expiresAt: past},
			{ip: "notExpired0", expiresAt: future},
		},
	}

	// act
	list.expire()

	// assert
	require.Len(t, list.addresses, 1)
}
// TestAddressListAddNewAddress checks that adding an unknown address grows
// the list by one and places the new address last.
func TestAddressListAddNewAddress(t *testing.T) {
	// arrange
	expiry := time.Now().Add(60 * time.Second)
	list := &addressList{
		addresses: []address{
			{ip: "addr0", expiresAt: expiry},
			{ip: "addr1", expiresAt: expiry},
		},
	}

	// act
	list.add("addr2")

	// assert
	require.Len(t, list.addresses, 3)
	last := list.addresses[2]
	require.Equal(t, "addr2", last.ip)
}
// TestAddressListAddExisitingAddress checks that adding an address already
// in the list keeps the list length unchanged and pushes that address's
// expiry later by at least one whole second.
func TestAddressListAddExisitingAddress(t *testing.T) {
	// arrange
	originalExpiry := time.Now().Add(10 * time.Second)
	list := &addressList{
		addresses: []address{
			{ip: "addr0", expiresAt: originalExpiry},
			{ip: "addr1", expiresAt: originalExpiry},
		},
	}

	// act
	list.add("addr1")
	extension := list.addresses[1].expiresAt.Sub(originalExpiry)
	deltaSec := int(extension.Seconds())

	// assert
	require.Len(t, list.addresses, 2)
	require.Greater(t, deltaSec, 0) // Ensures expiry has been extended for existing address.
}
func TestAddressListNext(t *testing.T) {
// arrange
expiry := time.Now().Add(10 * time.Second)
addressList := &addressList{
addresses: []address{
{
ip: "addr0",
expiresAt: expiry,
},
{
ip: "addr1",
expiresAt: expiry,
},
{
ip: "addr2",
expiresAt: expiry,
},
{
ip: "addr3",
expiresAt: expiry,
},
{
ip: "addr4",
expiresAt: expiry,
},
{
ip: "addr5",
expiresAt: expiry,
},
},
}
// act & assert
require.Equal(t, "addr0", *addressList.next())
require.Equal(t, "addr1", *addressList.next())
require.Equal(t, "addr2", *addressList.next())
require.Equal(t, "addr3", *addressList.next())
require.Equal(t, "addr4", *addressList.next())
require.Equal(t, "addr5", *addressList.next())
require.Equal(t, "addr0", *addressList.next())
require.Equal(t, "addr1", *addressList.next())
}
func TestAddressListNextMaxCounter(t *testing.T) {
// arrange
expiry := time.Now().Add(10 * time.Second)
addressList := &addressList{
addresses: []address{
{
ip: "addr0",
expiresAt: expiry,
},
{
ip: "addr1",
expiresAt: expiry,
},
{
ip: "addr2",
expiresAt: expiry,
},
{
ip: "addr3",
expiresAt: expiry,
},
{
ip: "addr4",
expiresAt: expiry,
},
{
ip: "addr5",
expiresAt: expiry,
},
},
}
// act & assert
require.Equal(t, "addr0", *addressList.next())
require.Equal(t, "addr1", *addressList.next())
require.Equal(t, "addr2", *addressList.next())
require.Equal(t, "addr3", *addressList.next())
addressList.counter.Store(math.MaxUint32)
require.Equal(t, "addr0", *addressList.next())
require.Equal(t, "addr1", *addressList.next())
require.Equal(t, "addr2", *addressList.next())
}
// TestAddressListNextNoAddress checks that next returns nil when the list
// holds no addresses.
func TestAddressListNextNoAddress(t *testing.T) {
	// arrange
	list := &addressList{addresses: []address{}}

	// act
	got := list.next()

	// assert
	require.Nil(t, got)
}
func TestAddressListNextWithAdd(t *testing.T) {
// arrange
expiry := time.Now().Add(10 * time.Second)
addressList := &addressList{
addresses: []address{
{
ip: "addr0",
expiresAt: expiry,
},
{
ip: "addr1",
expiresAt: expiry,
},
{
ip: "addr2",
expiresAt: expiry,
},
{
ip: "addr3",
expiresAt: expiry,
},
{
ip: "addr4",
expiresAt: expiry,
},
{
ip: "addr5",
expiresAt: expiry,
},
},
}
// act & assert
require.Equal(t, "addr0", *addressList.next())
require.Equal(t, "addr1", *addressList.next())
require.Equal(t, "addr2", *addressList.next())
require.Equal(t, "addr3", *addressList.next())
addressList.add("addr6")
require.Equal(t, "addr4", *addressList.next())
require.Equal(t, "addr5", *addressList.next())
require.Equal(t, "addr6", *addressList.next())
require.Equal(t, "addr0", *addressList.next())
require.Equal(t, "addr1", *addressList.next())
}
func TestAddressListNextWithExpiration(t *testing.T) {
// arrange
expiry := time.Now().Add(10 * time.Second)
expired := time.Now().Add(-60 * time.Second)
addressList := &addressList{
addresses: []address{
{
ip: "addr0",
expiresAt: expired,
},
{
ip: "addr1",
expiresAt: expiry,
},
{
ip: "addr2",
expiresAt: expired,
},
{
ip: "addr3",
expiresAt: expiry,
},
{
ip: "addr4",
expiresAt: expired,
},
{
ip: "addr5",
expiresAt: expiry,
},
},
}
// act & assert
require.Equal(t, "addr0", *addressList.next())
require.Equal(t, "addr1", *addressList.next())
require.Equal(t, "addr2", *addressList.next())
require.Equal(t, "addr3", *addressList.next())
addressList.expire()
require.Equal(t, "addr3", *addressList.next())
require.Equal(t, "addr5", *addressList.next())
require.Equal(t, "addr1", *addressList.next())
require.Equal(t, "addr3", *addressList.next())
}
// TestUnion checks that union returns every distinct element of its two
// inputs exactly once, in any order.
func TestUnion(t *testing.T) {
	// arrange
	testCases := []struct {
		first    []string
		second   []string
		expected []string
	}{
		{
			// partially overlapping inputs
			first:    []string{"a", "b", "c"},
			second:   []string{"a", "c", "d", "e"},
			expected: []string{"a", "b", "c", "d", "e"},
		},
		{
			// disjoint inputs
			first:    []string{"a", "b", "c"},
			second:   []string{"d", "e", "f", "g"},
			expected: []string{"a", "b", "c", "d", "e", "f", "g"},
		},
		{
			// identical inputs
			first:    []string{"a", "b"},
			second:   []string{"a", "b"},
			expected: []string{"a", "b"},
		},
		{
			// identical inputs in a different order
			first:    []string{"a", "b"},
			second:   []string{"b", "a"},
			expected: []string{"a", "b"},
		},
	}

	for _, tc := range testCases {
		// act
		got := union(tc.first, tc.second)

		// assert
		// ElementsMatch ignores order but fails on missing, duplicated or
		// unexpected elements, so a result with extra entries is rejected.
		require.ElementsMatch(t, tc.expected, got)
	}
}