diff --git a/Makefile b/Makefile index 1fd73dc73..a8d211f41 100644 --- a/Makefile +++ b/Makefile @@ -61,7 +61,7 @@ endif ################################################################################ .PHONY: test test: - CGO_ENABLED=$(CGO) go test ./... $(COVERAGE_OPTS) $(BUILDMODE) + CGO_ENABLED=$(CGO) go test ./... $(COVERAGE_OPTS) $(BUILDMODE) --timeout=15m ################################################################################ # Target: lint # diff --git a/nameresolution/mdns/mdns.go b/nameresolution/mdns/mdns.go index f51f215d9..26893428c 100644 --- a/nameresolution/mdns/mdns.go +++ b/nameresolution/mdns/mdns.go @@ -31,9 +31,16 @@ import ( ) const ( - // firstOnlyTimeout is the timeout used when + // browseOneTimeout is the timeout used when // browsing for the first response to a single app id. - firstOnlyTimeout = time.Second * 1 + browseOneTimeout = time.Second * 1 + // subscriberTimeout is the timeout used when + // subscribing to the first browser returning a response. + subscriberTimeout = time.Second * 2 + // subscriberCleanupWait is the time to wait before + // performing a cleanup of a subscriber pool. This + // MUST be greater than subscriberTimeout. + subscriberCleanupWait = time.Millisecond * 2500 // refreshTimeout is the timeout used when // browsing for any responses to a single app id. refreshTimeout = time.Second * 3 @@ -122,20 +129,88 @@ func (a *addressList) next() *string { return &addr.ip } +// SubscriberPool is used to manage +// a pool of subscribers for a given app id. +// 'Once' belongs to the first subscriber as +// it is their responsibility to fetch the +// address associated with the app id and +// publish it to the other subscribers. +// WARN: pools are not thread safe and are intended +// to be accessed only while holding the subMu lock. +type SubscriberPool struct { + Once *sync.Once + Subscribers []Subscriber +} + +func NewSubscriberPool(w Subscriber) *SubscriberPool { + return &SubscriberPool{ + Once: new(sync.Once), + Subscribers: []Subscriber{w}, + } +} + +func (p *SubscriberPool) Add(sub Subscriber) { + p.Subscribers = append(p.Subscribers, sub) +} + +type Subscriber struct { + AddrChan chan string + ErrChan chan error +} + +func (s *Subscriber) Close() { + close(s.AddrChan) + close(s.ErrChan) +} + +func NewSubscriber() Subscriber { + return Subscriber{ + // buffered so that a publish never blocks. + AddrChan: make(chan string, 1), + ErrChan: make(chan error, 1), + } +} +
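+// Illustrative flow for three concurrent resolvers of the same app id: +// +// pool := NewSubscriberPool(s1) // s1 arrives first and owns pool.Once +// pool.Add(s2) // s2 and s3 only wait on their channels +// pool.Add(s3) +// // s1 browses the network, then publishes the result to s1, s2 and s3. +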
// NewResolver creates the instance of mDNS name resolver. -func NewResolver(logger logger.Logger) nameresolution.Resolver { - r := &resolver{ +func NewResolver(logger logger.Logger) *Resolver { + r := &Resolver{ + subs: make(map[string]*SubscriberPool), appAddressesIPv4: make(map[string]*addressList), appAddressesIPv6: make(map[string]*addressList), - refreshChan: make(chan string), - logger: logger, + // the app id of every app that is resolved can be pushed + // onto this channel. We don't want to block the sender as + // they are resolving the app id as part of the service invocation. + // Instead we use a buffered channel to balance the load whilst + // the background refreshes are being performed. Once this buffer + // becomes full, the sender will block and service invocation + // will be delayed. We don't expect a single app to resolve + // too many other app IDs so we set this to a sensible value + // to avoid over-allocating the buffer. + refreshChan: make(chan string, 36), + // registrations channel to signal the resolver to + // stop serving queries for registered app ids. + registrations: make(map[string]chan struct{}), + // shutdownRefresh channel to signal to stop on-demand refreshes. + shutdownRefresh: make(chan struct{}, 1), + // shutdownRefreshPeriodic channel to signal to stop periodic refreshes. + shutdownRefreshPeriodic: make(chan struct{}, 1), + logger: logger, } // refresh app addresses on demand. go func() { - for appID := range r.refreshChan { - if err := r.refreshApp(context.Background(), appID); err != nil { - r.logger.Warnf(err.Error()) + for { + refreshCtx, cancel := context.WithCancel(context.Background()) + + select { + case appID := <-r.refreshChan: + if err := r.refreshApp(refreshCtx, appID); err != nil { + r.logger.Warnf(err.Error()) + } + case <-r.shutdownRefresh: + cancel() + r.logger.Debug("stopping on demand cache refreshes.") + return + } + // release this iteration's context before looping; a deferred + // cancel would accumulate until the goroutine exits. + cancel() } }() @@ -143,10 +218,17 @@ func NewResolver(logger logger.Logger) nameresolution.Resolver { // refresh all app addresses periodically. go func() { for { - time.Sleep(refreshInterval) + refreshCtx, cancel := context.WithCancel(context.Background()) - if err := r.refreshAllApps(context.Background()); err != nil { - r.logger.Warnf(err.Error()) + select { + case <-time.After(refreshInterval): + if err := r.refreshAllApps(refreshCtx); err != nil { + r.logger.Warnf(err.Error()) + } + case <-r.shutdownRefreshPeriodic: + cancel() + r.logger.Debug("stopping periodic cache refreshes.") + return + } + // release this iteration's context before the next tick. + cancel() } }() @@ -154,17 +236,37 @@ func NewResolver(logger logger.Logger) nameresolution.Resolver { return r } -type resolver struct { +type Resolver struct { + // subscribers are used when multiple callers + // request the same app ID before it is cached. + // Only one will fetch the address, the rest will + // subscribe for the address or an error. + subs map[string]*SubscriberPool + subMu sync.RWMutex + // IPv4 cache is used to store IPv4 addresses. ipv4Mu sync.RWMutex appAddressesIPv4 map[string]*addressList + // IPv6 cache is used to store IPv6 addresses. ipv6Mu sync.RWMutex appAddressesIPv6 map[string]*addressList - refreshChan chan string - logger logger.Logger + // refreshChan is used to trigger background refreshes + // of app IDs in case there is more than one server + // hosting that app id. + refreshChan chan string + // registrations are the app ids that have been + // registered with this resolver. A single resolver + // may serve multiple app ids - although this is + // expected to be one when initialized by the dapr runtime. + registrationMu sync.RWMutex + registrations map[string]chan struct{} + // shutdown refreshes. + shutdownRefresh chan struct{} + shutdownRefreshPeriodic chan struct{} + logger logger.Logger }
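+ +// Typical lifecycle, for illustration: +// +// r := NewResolver(logger) // starts the background refresh loops +// err := r.Init(metadata) // announces this app over mDNS +// addr, err := r.ResolveID(req) // resolves a peer app id +// err = r.Close() // stops mDNS servers and refresh loops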
// Init registers service for mDNS. -func (m *resolver) Init(metadata nameresolution.Metadata) error { +func (m *Resolver) Init(metadata nameresolution.Metadata) error { var appID string var hostAddress string var ok bool @@ -194,17 +296,51 @@ func (m *resolver) Init(metadata nameresolution.Metadata) error { } err = m.registerMDNS(instanceID, appID, []string{hostAddress}, int(port)) - if err == nil { - m.logger.Infof("local service entry announced: %s -> %s:%d", appID, hostAddress, port) + if err != nil { + return err } - return err + m.logger.Infof("local service entry announced: %s -> %s:%d", appID, hostAddress, port) + return nil } -func (m *resolver) registerMDNS(instanceID string, appID string, ips []string, port int) error { +// Close is not formally part of the name resolution interface as proposed +// in https://github.com/dapr/components-contrib/issues/1472 but this is +// used in the tests to clean up the mDNS registration. +func (m *Resolver) Close() error { + // stop all app ids currently being served from this resolver. + m.registrationMu.Lock() + defer m.registrationMu.Unlock() + for _, doneChan := range m.registrations { + doneChan <- struct{}{} + close(doneChan) + } + // clear the registrations map. + m.registrations = make(map[string]chan struct{}) + + // stop all refresh loops. + m.shutdownRefresh <- struct{}{} + m.shutdownRefreshPeriodic <- struct{}{} + + return nil +} + +func (m *Resolver) registerMDNS(instanceID string, appID string, ips []string, port int) error { started := make(chan bool, 1) var err error + // Register the app id with the resolver. + done := make(chan struct{}) + key := fmt.Sprintf("%s:%d", appID, port) // WARN: the key ignores ips; unique ips are not supported. + m.registrationMu.Lock() + _, exists := m.registrations[key] + if exists { + m.registrationMu.Unlock() + return fmt.Errorf("app id %s already registered for port %d", appID, port) + } + m.registrations[key] = done + m.registrationMu.Unlock() + go func() { var server *zeroconf.Server @@ -230,11 +366,18 @@ func (m *resolver) registerMDNS(instanceID string, appID string, ips []string, p } started <- true - // Wait until it gets SIGTERM event. sig := make(chan os.Signal, 1) signal.Notify(sig, os.Interrupt, syscall.SIGTERM) - <-sig + // wait until either a SIGTERM or done is received. + select { + case <-sig: + m.logger.Debugf("received SIGTERM signal, shutting down...") + case <-done: + m.logger.Debugf("received done signal, shutting down...") + } + + m.logger.Info("stopping mDNS server for app id: ", appID) server.Shutdown() }() @@ -244,7 +387,7 @@ } // ResolveID resolves name to address via mDNS. -func (m *resolver) ResolveID(req nameresolution.ResolveRequest) (string, error) { +func (m *Resolver) ResolveID(req nameresolution.ResolveRequest) (string, error) { // check for cached IPv4 addresses for this app id first. if addr := m.nextIPv4Address(req.ID); addr != nil { return *addr, nil } @@ -258,83 +401,226 @@ func (m *resolver) ResolveID(req nameresolution.ResolveRequest) (string, error) // cache miss, fallback to browsing the network for addresses. m.logger.Debugf("no mDNS address found in cache, browsing network for app id %s", req.ID) - // get the first address we receive... - addr, err := m.browseFirstOnly(context.Background(), req.ID) - if err == nil { - // ...and trigger a background refresh for any additional addresses. - m.refreshChan <- req.ID + // create a new sub which will wait for an address or error.
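+ // Its channels are buffered (size one) so a publisher never blocks + // on a subscriber that has already timed out and returned.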
+ sub := NewSubscriber() + + // add the sub to the pool of subs for this app id. + m.subMu.Lock() + appIDSubs, exists := m.subs[req.ID] + if !exists { + // WARN: must set appIDSubs variable for use below. + appIDSubs = NewSubscriberPool(sub) + m.subs[req.ID] = appIDSubs + } else { + appIDSubs.Add(sub) + } + m.subMu.Unlock() + + // only one subscriber per pool will perform the first browse for the + // requested app id. The rest will subscribe for an address or error. + var once *sync.Once + var published chan struct{} + ctx, cancel := context.WithTimeout(context.Background(), browseOneTimeout) + defer cancel() + appIDSubs.Once.Do(func() { + published = make(chan struct{}) + m.browseOne(ctx, req.ID, published) + + // once will only be set for the first browser. + once = new(sync.Once) + }) + + // if we subscribed to an existing pool whose first browser had already + // read the subscribers and published, we may never receive the + // address or error. The first browser will always + // update the cache before reading the subscribers so we can + // recheck the cache here to make sure we get the address or error + // regardless. This block should not be executed by the first + // browser as they must wait on the published channel and perform + // the cleanup before returning. + if once == nil { + if addr := m.nextIPv4Address(req.ID); addr != nil { + return *addr, nil + } + if addr := m.nextIPv6Address(req.ID); addr != nil { + return *addr, nil + } } - return addr, err + select { + case addr := <-sub.AddrChan: + // only one subscriber should have set the once var so + // this block should only be invoked once too. + if once != nil { + once.Do(func() { + // trigger the background refresh for additional addresses. + // WARN: this can block if refreshChan is full. + m.refreshChan <- req.ID + + // block on the published channel as this signals that we have + // published the address to all other subscribers before we return. + <-published + + // AddrChan is a buffered channel and we cannot guarantee that + // all subscribers will read the value even though we have published. + // Therefore it is not safe to remove the subscribers until after + // any subscribers would have timed out so we run a delayed background + // cleanup. + go func() { + time.Sleep(subscriberCleanupWait) + m.subMu.Lock() + delete(m.subs, req.ID) + m.subMu.Unlock() + }() + }) + } + return addr, nil + case err := <-sub.ErrChan: + if once != nil { + once.Do(func() { + // block on the published channel as this signals that we have + // published the error to all other subscribers before we return. + <-published + + // ErrChan is a buffered channel and we cannot guarantee that + // all subscribers will read the value even though we have published. + // Therefore it is not safe to remove the subscribers until after + // any subscribers would have timed out so we run a delayed background + // cleanup. + go func() { + time.Sleep(subscriberCleanupWait) + m.subMu.Lock() + delete(m.subs, req.ID) + m.subMu.Unlock() + }() + }) + } + return "", err + case <-time.After(subscriberTimeout): + // If no address or error has been received + // within the timeout, check the cache once more and, + // if it is still empty, return an error.
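+ // A publish may also have raced with our subscription, in which + // case the cache was populated even though we saw no message.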
+ if addr := m.nextIPv4Address(req.ID); addr != nil { + return *addr, nil + } + if addr := m.nextIPv6Address(req.ID); addr != nil { + return *addr, nil + } + return "", fmt.Errorf("timeout waiting for address for app id %s", req.ID) + } } -// browseFirstOnly will perform a mDNS network browse for an address +// browseOne will perform a mDNS network browse for an address // matching the provided app id. It will return the first address it // receives and stop browsing for any more. -func (m *resolver) browseFirstOnly(ctx context.Context, appID string) (string, error) { - var addr string +// This must be called in a sync.Once block to avoid concurrency issues. +func (m *Resolver) browseOne(ctx context.Context, appID string, published chan struct{}) { + go func() { + var addr string - ctx, cancel := context.WithTimeout(ctx, firstOnlyTimeout) - defer cancel() + browseCtx, cancel := context.WithCancel(ctx) + defer cancel() - // onFirst will be invoked on the first address received. - // Due to the asynchronous nature of cancel() there - // is no guarantee that this will ONLY be invoked on the - // first address. Ensure that multiple invocations of this - // function are safe. - onFirst := func(ip string) { - addr = ip - cancel() // cancel to stop browsing. + // onFirst will be invoked on the first address received. + // Due to the asynchronous nature of cancel() there + // is no guarantee that this will ONLY be invoked on the + // first address. Ensure that multiple invocations of this + // function are safe. + onFirst := func(ip string) { + addr = ip + cancel() // cancel to stop browsing. + } + + m.logger.Debugf("Browsing for first mDNS address for app id %s", appID) + + err := m.browse(browseCtx, appID, onFirst) + if err != nil { + m.pubErrToSubs(appID, err) + + published <- struct{}{} // signal that all subscribers have been notified. + return + } + + // wait for the context to be canceled or time out. + <-browseCtx.Done() + + if errors.Is(browseCtx.Err(), context.Canceled) { + // expect this when we've found an address and canceled the browse. + m.logger.Debugf("Browsing for first mDNS address for app id %s canceled.", appID) + } else if errors.Is(browseCtx.Err(), context.DeadlineExceeded) { + // expect this when we've been unable to find the first address before the timeout. + m.logger.Debugf("Browsing for first mDNS address for app id %s timed out.", appID) + } + + // if onFirst has been invoked then we should have an address. + if addr == "" { + m.pubErrToSubs(appID, fmt.Errorf("couldn't find service: %s", appID)) + + published <- struct{}{} // signal that all subscribers have been notified. + return + } + + m.pubAddrToSubs(appID, addr) + + published <- struct{}{} // signal that all subscribers have been notified. + }() +} + +// pubErrToSubs publishes err to every subscriber in the app id's pool +// and closes their channels. +func (m *Resolver) pubErrToSubs(reqID string, err error) { + m.subMu.RLock() + defer m.subMu.RUnlock() + pool, ok := m.subs[reqID] + if !ok { + // we would always expect at least 1 subscriber for this reqID. + m.logger.Warnf("no subscribers found for app id %s", reqID) + return } - - m.logger.Debugf("Browsing for first mDNS address for app id %s", appID) - - err := m.browse(ctx, appID, onFirst) - if err != nil { - return "", err + for _, subscriber := range pool.Subscribers { + // ErrChan is a buffered channel so this is non-blocking unless full. + subscriber.ErrChan <- err + subscriber.Close() } +} - // wait for the context to be canceled or time out. - <-ctx.Done() - - if errors.Is(ctx.Err(), context.Canceled) { - // expect this when we've found an address and canceled the browse. - m.logger.Debugf("Browsing for first mDNS address for app id %s canceled.", appID) - } else if errors.Is(ctx.Err(), context.DeadlineExceeded) { - // expect this when we've been unable to find the first address before the timeout. - m.logger.Debugf("Browsing for first mDNS address for app id %s timed out.", appID)
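+// pubAddrToSubs publishes the resolved address to every subscriber in +// the app id's pool and closes their channels.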
+func (m *Resolver) pubAddrToSubs(reqID string, addr string) { + m.subMu.RLock() + defer m.subMu.RUnlock() + pool, ok := m.subs[reqID] + if !ok { + // we would always expect at least 1 subscriber for this reqID. + m.logger.Warnf("no subscribers found for app id %s", reqID) + return } - - if addr == "" { - return "", fmt.Errorf("couldn't find service: %s", appID) + for _, subscriber := range pool.Subscribers { + // AddrChan is a buffered channel so this is non-blocking unless full. + subscriber.AddrChan <- addr + subscriber.Close() } - - return addr, nil } // refreshApp will perform a mDNS network browse for a provided // app id. This function is blocking. -func (m *resolver) refreshApp(ctx context.Context, appID string) error { +func (m *Resolver) refreshApp(refreshCtx context.Context, appID string) error { if appID == "" { return nil } m.logger.Debugf("Refreshing mDNS addresses for app id %s.", appID) - ctx, cancel := context.WithTimeout(ctx, refreshTimeout) + refreshCtx, cancel := context.WithTimeout(refreshCtx, refreshTimeout) defer cancel() - if err := m.browse(ctx, appID, nil); err != nil { + if err := m.browse(refreshCtx, appID, nil); err != nil { return err } // wait for the context to be canceled or time out. - <-ctx.Done() + <-refreshCtx.Done() - if errors.Is(ctx.Err(), context.Canceled) { + if errors.Is(refreshCtx.Err(), context.Canceled) { // this is not expected, investigate why context was canceled. m.logger.Warnf("Refreshing mDNS addresses for app id %s canceled.", appID) - } else if errors.Is(ctx.Err(), context.DeadlineExceeded) { + } else if errors.Is(refreshCtx.Err(), context.DeadlineExceeded) { // expect this when our browse has timed out. m.logger.Debugf("Refreshing mDNS addresses for app id %s timed out.", appID) } @@ -344,7 +630,7 @@ func (m *resolver) refreshApp(ctx context.Context, appID string) error { // refreshAllApps will perform a mDNS network browse for each address // currently in the cache. This function is blocking. -func (m *resolver) refreshAllApps(ctx context.Context) error { +func (m *Resolver) refreshAllApps(ctx context.Context) error { m.logger.Debug("Refreshing all mDNS addresses.") // check if we have any IPv4 or IPv6 addresses @@ -373,7 +659,10 @@ func (m *resolver) refreshAllApps(ctx context.Context) error { go func(a string) { defer wg.Done() - m.refreshApp(ctx, a) + err := m.refreshApp(ctx, a) + if err != nil { + m.logger.Warnf("error refreshing mDNS addresses for app id %s: %v", a, err) + } }(appID) } @@ -384,73 +673,81 @@ } // browse will perform a non-blocking mdns network browse for the provided app id.
-func (m *resolver) browse(ctx context.Context, appID string, onEach func(ip string)) error { +func (m *Resolver) browse(ctx context.Context, appID string, onEach func(ip string)) error { resolver, err := zeroconf.NewResolver(nil) if err != nil { - return fmt.Errorf("failed to initialize resolver: %e", err) + return fmt.Errorf("failed to initialize resolver: %w", err) } entries := make(chan *zeroconf.ServiceEntry) + handleEntry := func(entry *zeroconf.ServiceEntry) { + for _, text := range entry.Text { + if text != appID { + m.logger.Debugf("mDNS response doesn't match app id %s, skipping.", appID) + + break + } + + m.logger.Debugf("mDNS response for app id %s received.", appID) + + hasIPv4Address := len(entry.AddrIPv4) > 0 + hasIPv6Address := len(entry.AddrIPv6) > 0 + + if !hasIPv4Address && !hasIPv6Address { + m.logger.Debugf("mDNS response for app id %s doesn't contain any IPv4 or IPv6 addresses, skipping.", appID) + + break + } + + var addr string + port := entry.Port + + // TODO: we currently only use the first IPv4 and IPv6 address. + // We should understand the cases in which additional addresses + // are returned and whether we need to support them. + if hasIPv4Address { + addr = fmt.Sprintf("%s:%d", entry.AddrIPv4[0].String(), port) + m.addAppAddressIPv4(appID, addr) + } + if hasIPv6Address { + addr = fmt.Sprintf("%s:%d", entry.AddrIPv6[0].String(), port) + m.addAppAddressIPv6(appID, addr) + } + + if onEach != nil { + onEach(addr) // invoke callback. + } + } + } + // handle each service entry returned from the mDNS browse. go func(results <-chan *zeroconf.ServiceEntry) { for { select { + case entry := <-results: + if entry == nil { + break + } + handleEntry(entry) case <-ctx.Done(): + // best-effort drain of any in-flight result before exiting; + // the entries channel is unbuffered, so len() is always zero + // and a non-blocking receive is needed to drain it. + select { + case entry := <-results: + if entry != nil { + handleEntry(entry) + } + default: + } + if errors.Is(ctx.Err(), context.Canceled) { m.logger.Debugf("mDNS browse for app id %s canceled.", appID) } else if errors.Is(ctx.Err(), context.DeadlineExceeded) { m.logger.Debugf("mDNS browse for app id %s timed out.", appID) } - return - case entry := <-results: - if entry == nil { - break - } - - for _, text := range entry.Text { - if text != appID { - m.logger.Debugf("mDNS response doesn't match app id %s, skipping.", appID) - - break - } - - m.logger.Debugf("mDNS response for app id %s received.", appID) - - hasIPv4Address := len(entry.AddrIPv4) > 0 - hasIPv6Address := len(entry.AddrIPv6) > 0 - - if !hasIPv4Address && !hasIPv6Address { - m.logger.Debugf("mDNS response for app id %s doesn't contain any IPv4 or IPv6 addresses, skipping.", appID) - - break - } - - var addr string - port := entry.Port - - // TODO: We currently only use the first IPv4 and IPv6 address. - // We should understand the cases in which additional addresses - // are returned and whether we need to support them. - if hasIPv4Address { - addr = fmt.Sprintf("%s:%d", entry.AddrIPv4[0].String(), port) - m.addAppAddressIPv4(appID, addr) - } - if hasIPv6Address { - addr = fmt.Sprintf("%s:%d", entry.AddrIPv6[0].String(), port) - m.addAppAddressIPv6(appID, addr) - } - - if onEach != nil { - onEach(addr) // invoke callback. - } - } + return // stop listening for results. } } }(entries)
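+ // Browse returns without blocking; results are delivered asynchronously + // on the entries channel until ctx is canceled or times out.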
if err = resolver.Browse(ctx, appID, "local.", entries); err != nil { - return fmt.Errorf("failed to browse: %s", err.Error()) + return fmt.Errorf("failed to browse: %w", err) } return nil @@ -458,7 +755,7 @@ func (m *resolver) browse(ctx context.Context, appID string, onEach func(ip stri // addAppAddressIPv4 adds an IPv4 address to the // cache for the provided app id. -func (m *resolver) addAppAddressIPv4(appID string, addr string) { +func (m *Resolver) addAppAddressIPv4(appID string, addr string) { m.ipv4Mu.Lock() defer m.ipv4Mu.Unlock() @@ -472,7 +769,7 @@ // addAppAddressIPv6 adds an IPv6 address to the // cache for the provided app id. -func (m *resolver) addAppAddressIPv6(appID string, addr string) { +func (m *Resolver) addAppAddressIPv6(appID string, addr string) { m.ipv6Mu.Lock() defer m.ipv6Mu.Unlock() @@ -486,7 +783,7 @@ // getAppIDsIPv4 returns a list of the current IPv4 app IDs. // This method uses expire on read to evict expired addresses. -func (m *resolver) getAppIDsIPv4() []string { +func (m *Resolver) getAppIDsIPv4() []string { m.ipv4Mu.RLock() defer m.ipv4Mu.RUnlock() @@ -503,7 +800,7 @@ // getAppIDsIPv6 returns a list of the known IPv6 app IDs. // This method uses expire on read to evict expired addresses. -func (m *resolver) getAppIDsIPv6() []string { +func (m *Resolver) getAppIDsIPv6() []string { m.ipv6Mu.RLock() defer m.ipv6Mu.RUnlock() @@ -520,13 +817,13 @@ // getAppIDs returns a list of app ids currently in // the cache, ensuring expired addresses are evicted. -func (m *resolver) getAppIDs() []string { +func (m *Resolver) getAppIDs() []string { return union(m.getAppIDsIPv4(), m.getAppIDsIPv6()) } // nextIPv4Address returns the next IPv4 address for // the provided app id from the cache. -func (m *resolver) nextIPv4Address(appID string) *string { +func (m *Resolver) nextIPv4Address(appID string) *string { m.ipv4Mu.RLock() defer m.ipv4Mu.RUnlock() addrList, exists := m.appAddressesIPv4[appID] @@ -544,7 +841,7 @@ // nextIPv6Address returns the next IPv6 address for // the provided app id from the cache.
-func (m *resolver) nextIPv6Address(appID string) *string { +func (m *Resolver) nextIPv6Address(appID string) *string { m.ipv6Mu.RLock() defer m.ipv6Mu.RUnlock() addrList, exists := m.appAddressesIPv6[appID] diff --git a/nameresolution/mdns/mdns_test.go b/nameresolution/mdns/mdns_test.go index cb7206339..295e507c9 100644 --- a/nameresolution/mdns/mdns_test.go +++ b/nameresolution/mdns/mdns_test.go @@ -15,6 +15,7 @@ package mdns import ( "fmt" + "sync" "testing" "time" @@ -25,7 +26,12 @@ import ( "github.com/dapr/kit/logger" ) -func TestInit(t *testing.T) { +const ( + localhost = "127.0.0.1" + numConcurrency = 100 +) + +func TestInitMetadata(t *testing.T) { tests := []struct { missingProp string props map[string]string @@ -33,7 +39,7 @@ func TestInit(t *testing.T) { { "name", map[string]string{ - nr.MDNSInstanceAddress: "127.0.0.1", + nr.MDNSInstanceAddress: localhost, nr.MDNSInstancePort: "30003", }, }, @@ -48,14 +54,14 @@ func TestInit(t *testing.T) { "port", map[string]string{ nr.MDNSInstanceName: "testAppID", - nr.MDNSInstanceAddress: "127.0.0.1", + nr.MDNSInstanceAddress: localhost, }, }, { "port", map[string]string{ nr.MDNSInstanceName: "testAppID", - nr.MDNSInstanceAddress: "127.0.0.1", + nr.MDNSInstanceAddress: localhost, nr.MDNSInstancePort: "abcd", }, }, @@ -63,6 +69,7 @@ func TestInit(t *testing.T) { // arrange resolver := NewResolver(logger.NewLogger("test")) + defer resolver.Close() for _, tt := range tests { t.Run(tt.missingProp+" is missing", func(t *testing.T) { @@ -75,12 +82,51 @@ func TestInit(t *testing.T) { } } +func TestInitRegister(t *testing.T) { + // arrange + resolver := NewResolver(logger.NewLogger("test")) + defer resolver.Close() + md := nr.Metadata{Properties: map[string]string{ + nr.MDNSInstanceName: "testAppID", + nr.MDNSInstanceAddress: localhost, + nr.MDNSInstancePort: "1234", + }} + + // act + err := resolver.Init(md) + require.NoError(t, err) +} + +func TestInitRegisterDuplicate(t *testing.T) { + // arrange + resolver := NewResolver(logger.NewLogger("test")) + defer resolver.Close() + md := nr.Metadata{Properties: map[string]string{ + nr.MDNSInstanceName: "testAppID", + nr.MDNSInstanceAddress: localhost, + nr.MDNSInstancePort: "1234", + }} + md2 := nr.Metadata{Properties: map[string]string{ + nr.MDNSInstanceName: "testAppID", + nr.MDNSInstanceAddress: localhost, + nr.MDNSInstancePort: "1234", + }} + + // act + err := resolver.Init(md) + require.NoError(t, err) + err = resolver.Init(md2) + expectedError := "app id testAppID already registered for port 1234" + require.EqualErrorf(t, err, expectedError, "Error should be: %v, got %v", expectedError, err) +} + func TestResolver(t *testing.T) { // arrange resolver := NewResolver(logger.NewLogger("test")) + defer resolver.Close() md := nr.Metadata{Properties: map[string]string{ nr.MDNSInstanceName: "testAppID", - nr.MDNSInstanceAddress: "127.0.0.1", + nr.MDNSInstanceAddress: localhost, nr.MDNSInstancePort: "1234", }} @@ -93,17 +139,45 @@ func TestResolver(t *testing.T) { // assert require.NoError(t, err) - assert.Equal(t, "127.0.0.1:1234", pt) + assert.Equal(t, fmt.Sprintf("%s:1234", localhost), pt) +} + +func TestResolverClose(t *testing.T) { + // arrange + resolver := NewResolver(logger.NewLogger("test")) + md := nr.Metadata{Properties: map[string]string{ + nr.MDNSInstanceName: "testAppID", + nr.MDNSInstanceAddress: localhost, + nr.MDNSInstancePort: "1234", + }} + + // act + err := resolver.Init(md) + require.NoError(t, err) + + request := nr.ResolveRequest{ID: "testAppID"} + pt, err := 
resolver.ResolveID(request) + + // assert + require.NoError(t, err) + assert.Equal(t, fmt.Sprintf("%s:1234", localhost), pt) + + // act again + err = resolver.Close() + + // assert + require.NoError(t, err) } func TestResolverMultipleInstances(t *testing.T) { - // arrange... + // arrange resolver := NewResolver(logger.NewLogger("test")) + defer resolver.Close() // register instance A instanceAID := "A" instanceAName := "testAppID" - instanceAAddress := "127.0.0.1" + instanceAAddress := localhost instanceAPort := "1234" instanceAPQDN := fmt.Sprintf("%s:%s", instanceAAddress, instanceAPort) @@ -119,7 +193,7 @@ func TestResolverMultipleInstances(t *testing.T) { // register instance B instanceBID := "B" instanceBName := "testAppID" - instanceBAddress := "127.0.0.1" + instanceBAddress := localhost instanceBPort := "5678" instanceBPQDN := fmt.Sprintf("%s:%s", instanceBAddress, instanceBPort) @@ -140,7 +214,10 @@ func TestResolverMultipleInstances(t *testing.T) { require.NoError(t, err) require.Contains(t, []string{instanceAPQDN, instanceBPQDN}, addr1) - // delay long enough for the background address cache to populate. + // we want the resolution to be served from the cache so that it can + // be load balanced rather than received by a subscription. Therefore, + // we must sleep here long enough to allow the first browser to populate + // the cache. time.Sleep(1 * time.Second) // assert that when we resolve the test app id n times, we see only @@ -162,6 +239,223 @@ func TestResolverMultipleInstances(t *testing.T) { require.Greater(t, instanceBCount, 45) } +func TestResolverNotFound(t *testing.T) { + // arrange + resolver := NewResolver(logger.NewLogger("test")) + defer resolver.Close() + + // act + request := nr.ResolveRequest{ID: "testAppIDNotFound"} + pt, err := resolver.ResolveID(request) + + // assert + expectedError := "couldn't find service: testAppIDNotFound" + require.EqualErrorf(t, err, expectedError, "Error should be: %v, got %v", expectedError, err) + assert.Equal(t, "", pt) +} + +// TestResolverConcurrency is used to run concurrent tests in +// series as they rely on a shared mDNS server on the host +// machine. +func TestResolverConcurrency(t *testing.T) { + tt := []struct { + name string + test func(t *testing.T) + }{ + { + name: "ResolverConcurrencyNotFound", + test: ResolverConcurrencyNotFound, + }, + { + name: "ResolverConcurrencyFound", + test: ResolverConcurrencyFound, + }, + { + name: "ResolverConcurrencySubscriberClear", + test: ResolverConcurrencySubscriberClear, + }, + } + + for _, tc := range tt { + t.Run(tc.name, tc.test) + } +} + +func ResolverConcurrencySubscriberClear(t *testing.T) { + // arrange + resolver := NewResolver(logger.NewLogger("test")) + defer resolver.Close() + md := nr.Metadata{Properties: map[string]string{ + nr.MDNSInstanceName: "testAppID", + nr.MDNSInstanceAddress: localhost, + nr.MDNSInstancePort: "1234", + }} + + // act + err := resolver.Init(md) + require.NoError(t, err) + + request := nr.ResolveRequest{ID: "testAppID"} + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + pt, err := resolver.ResolveID(request) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("%s:1234", localhost), pt) + }() + } + + wg.Wait() + + // Wait long enough for the background clear to occur. + time.Sleep(3 * time.Second) + require.Equal(t, 0, len(resolver.subs)) +} + +// WARN: This is deliberately not a test function.
+// This test case must be run in series and is executed +// by the TestResolverConcurrency test function. +func ResolverConcurrencyFound(t *testing.T) { + // arrange + resolver := NewResolver(logger.NewLogger("test")) + defer resolver.Close() + + // register instance A + appAID := "A" + appAName := "testAppA" + appAAddress := localhost + appAPort := "1234" + appAPQDN := fmt.Sprintf("%s:%s", appAAddress, appAPort) + + appA := nr.Metadata{Properties: map[string]string{ + nr.MDNSInstanceName: appAName, + nr.MDNSInstanceAddress: appAAddress, + nr.MDNSInstancePort: appAPort, + nr.MDNSInstanceID: appAID, + }} + err1 := resolver.Init(appA) + require.NoError(t, err1) + + // register instance B + appBID := "B" + appBName := "testAppB" + appBAddress := localhost + appBPort := "5678" + appBPQDN := fmt.Sprintf("%s:%s", appBAddress, appBPort) + + appB := nr.Metadata{Properties: map[string]string{ + nr.MDNSInstanceName: appBName, + nr.MDNSInstanceAddress: appBAddress, + nr.MDNSInstancePort: appBPort, + nr.MDNSInstanceID: appBID, + }} + err2 := resolver.Init(appB) + require.NoError(t, err2) + + // register instance C + appCID := "C" + appCName := "testAppC" + appCAddress := localhost + appCPort := "3456" + appCPQDN := fmt.Sprintf("%s:%s", appCAddress, appCPort) + + appC := nr.Metadata{Properties: map[string]string{ + nr.MDNSInstanceName: appCName, + nr.MDNSInstanceAddress: appCAddress, + nr.MDNSInstancePort: appCPort, + nr.MDNSInstanceID: appCID, + }} + err3 := resolver.Init(appC) + require.NoError(t, err3) + + // act... + wg := sync.WaitGroup{} + for i := 0; i < numConcurrency; i++ { + idx := i + wg.Add(1) + go func() { + defer wg.Done() + + var appID string + r := idx % 3 + if r == 0 { + appID = "testAppA" + } else if r == 1 { + appID = "testAppB" + } else { + appID = "testAppC" + } + request := nr.ResolveRequest{ID: appID} + + start := time.Now() + pt, err := resolver.ResolveID(request) + elapsed := time.Since(start) + // assert + require.NoError(t, err) + if r == 0 { + assert.Equal(t, appAPQDN, pt) + } else if r == 1 { + assert.Equal(t, appBPQDN, pt) + } else if r == 2 { + assert.Equal(t, appCPQDN, pt) + } + + // It should take a maximum of 3 seconds to + // resolve an address. + assert.Less(t, elapsed, 3*time.Second) + }() + } + + wg.Wait() +} + +// WARN: This is deliberately not a test function. +// This test case must be run in series and is executed +// by the TestResolverConcurrency test function. +func ResolverConcurrencyNotFound(t *testing.T) { + // arrange + resolver := NewResolver(logger.NewLogger("test")) + defer resolver.Close() + + // act... + wg := sync.WaitGroup{} + for i := 0; i < numConcurrency; i++ { + idx := i + wg.Add(1) + go func() { + defer wg.Done() + + var appID string + r := idx % 3 + if r == 0 { + appID = "testAppA" + } else if r == 1 { + appID = "testAppB" + } else { + appID = "testAppC" + } + request := nr.ResolveRequest{ID: appID} + + // act + start := time.Now() + pt, err := resolver.ResolveID(request) + elapsed := time.Since(start) + + // assert + expectedError := "couldn't find service: " + appID + require.EqualErrorf(t, err, expectedError, "Error should be: %v, got %v", expectedError, err) + assert.Equal(t, "", pt) + assert.Less(t, elapsed, 2*time.Second) // browse timeout is 1 second, so we expect an error shortly after. + }() + } + + wg.Wait() +} + func TestAddressListExpire(t *testing.T) { // arrange base := time.Now()