mirror of https://github.com/grpc/grpc-go.git
140 lines
4.0 KiB
Go
140 lines
4.0 KiB
Go
/*
|
|
*
|
|
* Copyright 2024 gRPC authors.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*
|
|
*/
|
|
|
|
package testutils
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"sync"
|
|
|
|
"google.golang.org/grpc/grpclog"
|
|
)
|
|
|
|
var logger = grpclog.Component("testutils")
|
|
|
|
// BlockingDialer is a dialer that waits for Resume() to be called before
|
|
// dialing.
|
|
type BlockingDialer struct {
|
|
// mu protects holds.
|
|
mu sync.Mutex
|
|
// holds maps network addresses to a list of holds for that address.
|
|
holds map[string][]*Hold
|
|
}
|
|
|
|
// NewBlockingDialer returns a dialer that waits for Resume() to be called
|
|
// before dialing.
|
|
func NewBlockingDialer() *BlockingDialer {
|
|
return &BlockingDialer{
|
|
holds: make(map[string][]*Hold),
|
|
}
|
|
}
|
|
|
|
// DialContext implements a context dialer for use with grpc.WithContextDialer
|
|
// dial option for a BlockingDialer.
|
|
func (d *BlockingDialer) DialContext(ctx context.Context, addr string) (net.Conn, error) {
|
|
d.mu.Lock()
|
|
holds := d.holds[addr]
|
|
if len(holds) == 0 {
|
|
// No hold for this addr.
|
|
d.mu.Unlock()
|
|
return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
|
|
}
|
|
hold := holds[0]
|
|
d.holds[addr] = holds[1:]
|
|
d.mu.Unlock()
|
|
|
|
logger.Infof("Hold %p: Intercepted connection attempt to addr %q", hold, addr)
|
|
close(hold.waitCh)
|
|
select {
|
|
case err := <-hold.blockCh:
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
|
|
case <-ctx.Done():
|
|
logger.Infof("Hold %p: Connection attempt to addr %q timed out", hold, addr)
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
// Hold is a handle to a single connection attempt. It can be used to block,
|
|
// fail and succeed connection attempts.
|
|
type Hold struct {
|
|
// dialer is the dialer that created this hold.
|
|
dialer *BlockingDialer
|
|
// waitCh is closed when a connection attempt is received.
|
|
waitCh chan struct{}
|
|
// blockCh receives the value to return from DialContext for this connection
|
|
// attempt (nil on resume, an error on fail). It receives at most 1 value.
|
|
blockCh chan error
|
|
// addr is the address that this hold is for.
|
|
addr string
|
|
}
|
|
|
|
// Hold blocks the dialer when a connection attempt is made to the given addr.
|
|
// A hold is valid for exactly one connection attempt. Multiple holds for an
|
|
// addr can be added, and they will apply in the order that the connections are
|
|
// attempted.
|
|
func (d *BlockingDialer) Hold(addr string) *Hold {
|
|
d.mu.Lock()
|
|
defer d.mu.Unlock()
|
|
|
|
h := Hold{dialer: d, blockCh: make(chan error, 1), waitCh: make(chan struct{}), addr: addr}
|
|
d.holds[addr] = append(d.holds[addr], &h)
|
|
return &h
|
|
}
|
|
|
|
// Wait blocks until there is a connection attempt on this Hold, or the context
|
|
// expires. Return false if the context has expired, true otherwise.
|
|
func (h *Hold) Wait(ctx context.Context) bool {
|
|
logger.Infof("Hold %p: Waiting for a connection attempt to addr %q", h, h.addr)
|
|
select {
|
|
case <-ctx.Done():
|
|
return false
|
|
case <-h.waitCh:
|
|
return true
|
|
}
|
|
}
|
|
|
|
// Resume unblocks the dialer for the given addr. Either Resume or Fail must be
|
|
// called at most once on a hold. Otherwise, Resume panics.
|
|
func (h *Hold) Resume() {
|
|
logger.Infof("Hold %p: Resuming connection attempt to addr %q", h, h.addr)
|
|
h.blockCh <- nil
|
|
close(h.blockCh)
|
|
}
|
|
|
|
// Fail fails the connection attempt. Either Resume or Fail must be
|
|
// called at most once on a hold. Otherwise, Resume panics.
|
|
func (h *Hold) Fail(err error) {
|
|
logger.Infof("Hold %p: Failing connection attempt to addr %q", h, h.addr)
|
|
h.blockCh <- err
|
|
close(h.blockCh)
|
|
}
|
|
|
|
// IsStarted returns true if this hold has received a connection attempt.
|
|
func (h *Hold) IsStarted() bool {
|
|
select {
|
|
case <-h.waitCh:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|