grpc-go/internal/testutils/blocking_context_dialer.go

140 lines
4.0 KiB
Go

/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package testutils
import (
"context"
"net"
"sync"
"google.golang.org/grpc/grpclog"
)
// logger is the package-level component logger; the dialer and its holds use
// it to trace intercepted, resumed, failed and expired connection attempts.
var logger = grpclog.Component("testutils")
// BlockingDialer is a context dialer that can pause individual connection
// attempts. A connection attempt to an address with a pending Hold (registered
// through the Hold method) blocks until that Hold is resumed or failed, or
// until the dial context is done. Attempts to addresses without a pending
// hold are dialed immediately.
type BlockingDialer struct {
// mu protects holds.
mu sync.Mutex
// holds maps network addresses to the list of pending holds for that
// address. Holds are consumed from the front, one per connection attempt.
holds map[string][]*Hold
}
// NewBlockingDialer creates a BlockingDialer with no holds registered. Until
// Hold is called for an address, connection attempts to it are dialed without
// blocking.
func NewBlockingDialer() *BlockingDialer {
	d := new(BlockingDialer)
	d.holds = map[string][]*Hold{}
	return d
}
// DialContext is the context dialer function to pass to
// grpc.WithContextDialer.
//
// If no hold is pending for addr, it dials addr over TCP right away.
// Otherwise it consumes the oldest pending hold for addr, signals the hold
// that a connection attempt has started, and then blocks until one of:
//   - the hold is resumed: addr is dialed over TCP and the result returned;
//   - the hold is failed: the error given to Fail is returned;
//   - ctx is done: ctx.Err() is returned.
func (d *BlockingDialer) DialContext(ctx context.Context, addr string) (net.Conn, error) {
	var (
		h     *Hold
		found bool
	)

	// Pop the first pending hold for addr, if there is one. The lock is only
	// held for the bookkeeping, never while dialing or waiting.
	d.mu.Lock()
	if queue := d.holds[addr]; len(queue) > 0 {
		h, found = queue[0], true
		d.holds[addr] = queue[1:]
	}
	d.mu.Unlock()

	if !found {
		// Nothing is holding this addr: behave like a plain TCP dialer.
		return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
	}

	logger.Infof("Hold %p: Intercepted connection attempt to addr %q", h, addr)
	// Wake up anyone blocked in Wait and make IsStarted report true.
	close(h.waitCh)

	select {
	case <-ctx.Done():
		// Covers both deadline expiry and explicit cancellation.
		logger.Infof("Hold %p: Connection attempt to addr %q timed out", h, addr)
		return nil, ctx.Err()
	case outcome := <-h.blockCh:
		if outcome != nil {
			return nil, outcome
		}
		return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
	}
}
// Hold is a handle to a single connection attempt. It can be used to wait for
// the attempt to start, and then to let it proceed (Resume) or make it fail
// (Fail).
type Hold struct {
// dialer is the dialer that created this hold.
dialer *BlockingDialer
// waitCh is closed by DialContext when a connection attempt that consumes
// this hold is received.
waitCh chan struct{}
// blockCh receives the value to return from DialContext for this connection
// attempt (nil on resume, an error on fail). It receives at most 1 value and
// is closed right after that value is sent.
blockCh chan error
// addr is the address that this hold is for.
addr string
}
// Hold registers a new hold for addr and returns its handle. The next
// connection attempt to addr that is not already covered by an earlier hold
// will block until the returned Hold is resumed or failed.
//
// Each hold applies to exactly one connection attempt. Several holds may be
// registered for the same addr; they are consumed in registration order as
// connection attempts arrive.
func (d *BlockingDialer) Hold(addr string) *Hold {
	h := &Hold{
		dialer: d,
		addr:   addr,
		waitCh: make(chan struct{}),
		// Capacity 1 lets Resume/Fail complete even when no dial is
		// currently waiting on the hold.
		blockCh: make(chan error, 1),
	}

	d.mu.Lock()
	defer d.mu.Unlock()
	d.holds[addr] = append(d.holds[addr], h)
	return h
}
// Wait blocks until a connection attempt has been intercepted by this Hold or
// ctx is done, whichever happens first. It returns true if a connection
// attempt was received and false if ctx was done first.
func (h *Hold) Wait(ctx context.Context) bool {
	logger.Infof("Hold %p: Waiting for a connection attempt to addr %q", h, h.addr)
	select {
	case <-h.waitCh:
		return true
	case <-ctx.Done():
		return false
	}
}
// Resume unblocks the connection attempt held by h, letting DialContext go on
// to dial the address. It may be called before the connection attempt
// arrives: blockCh is created with a buffer of one, so the send below does
// not block.
//
// Resume and Fail may be called at most once in total on a hold; a second
// call to either panics because blockCh has already been closed.
func (h *Hold) Resume() {
logger.Infof("Hold %p: Resuming connection attempt to addr %q", h, h.addr)
// A nil value tells DialContext to proceed with the real dial.
h.blockCh <- nil
close(h.blockCh)
}
// Fail fails the connection attempt held by h: DialContext returns err
// instead of dialing. Passing a nil err is equivalent to calling Resume,
// since DialContext only treats a non-nil value as a failure.
//
// Resume and Fail may be called at most once in total on a hold; a second
// call to either panics because blockCh has already been closed.
func (h *Hold) Fail(err error) {
logger.Infof("Hold %p: Failing connection attempt to addr %q", h, h.addr)
// blockCh is created with a buffer of one, so this send does not block even
// when no connection attempt is waiting yet.
h.blockCh <- err
close(h.blockCh)
}
// IsStarted reports, without blocking, whether a connection attempt has
// already been intercepted by this hold.
func (h *Hold) IsStarted() bool {
	started := false
	select {
	case <-h.waitCh:
		// waitCh is closed once a dial has consumed this hold.
		started = true
	default:
	}
	return started
}