mirror of https://github.com/grpc/grpc-go.git
				
				
				
			xds: add ConfigSelector to support RouteAction timeouts (#3991)
This commit is contained in:
		
							parent
							
								
									20636e76a9
								
							
						
					
					
						commit
						b88744b832
					
				|  | @ -399,3 +399,29 @@ func runStructTypeAssertion(b *testing.B, fer interface{}) { | |||
| 		b.Fatal("error") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func BenchmarkWaitGroupAddDone(b *testing.B) { | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	b.RunParallel(func(pb *testing.PB) { | ||||
| 		i := 0 | ||||
| 		for ; pb.Next(); i++ { | ||||
| 			wg.Add(1) | ||||
| 		} | ||||
| 		for ; i > 0; i-- { | ||||
| 			wg.Done() | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func BenchmarkRLockUnlock(b *testing.B) { | ||||
| 	mu := sync.RWMutex{} | ||||
| 	b.RunParallel(func(pb *testing.PB) { | ||||
| 		i := 0 | ||||
| 		for ; pb.Next(); i++ { | ||||
| 			mu.RLock() | ||||
| 		} | ||||
| 		for ; i > 0; i-- { | ||||
| 			mu.RUnlock() | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  |  | |||
|  | @ -0,0 +1,114 @@ | |||
| /* | ||||
|  * | ||||
|  * Copyright 2017 gRPC authors. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  * | ||||
|  */ | ||||
| 
 | ||||
| // Benchmark options for safe config selector type.
 | ||||
| 
 | ||||
| package primitives_test | ||||
| 
 | ||||
| import ( | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 	"unsafe" | ||||
| ) | ||||
| 
 | ||||
| type safeUpdaterAtomicAndCounter struct { | ||||
| 	ptr unsafe.Pointer // *countingFunc
 | ||||
| } | ||||
| 
 | ||||
| type countingFunc struct { | ||||
| 	mu sync.RWMutex | ||||
| 	f  func() | ||||
| } | ||||
| 
 | ||||
| func (s *safeUpdaterAtomicAndCounter) call() { | ||||
| 	cfPtr := atomic.LoadPointer(&s.ptr) | ||||
| 	var cf *countingFunc | ||||
| 	for { | ||||
| 		cf = (*countingFunc)(cfPtr) | ||||
| 		cf.mu.RLock() | ||||
| 		cfPtr2 := atomic.LoadPointer(&s.ptr) | ||||
| 		if cfPtr == cfPtr2 { | ||||
| 			// Use cf with confidence!
 | ||||
| 			break | ||||
| 		} | ||||
| 		// cf changed; try to use the new one instead, because the old one is
 | ||||
| 		// no longer valid to use.
 | ||||
| 		cf.mu.RUnlock() | ||||
| 		cfPtr = cfPtr2 | ||||
| 	} | ||||
| 	defer cf.mu.RUnlock() | ||||
| 	cf.f() | ||||
| } | ||||
| 
 | ||||
| func (s *safeUpdaterAtomicAndCounter) update(f func()) { | ||||
| 	newCF := &countingFunc{f: f} | ||||
| 	oldCFPtr := atomic.SwapPointer(&s.ptr, unsafe.Pointer(newCF)) | ||||
| 	if oldCFPtr == nil { | ||||
| 		return | ||||
| 	} | ||||
| 	(*countingFunc)(oldCFPtr).mu.Lock() | ||||
| 	(*countingFunc)(oldCFPtr).mu.Unlock() //lint:ignore SA2001 necessary to unlock after locking to unblock any RLocks
 | ||||
| } | ||||
| 
 | ||||
| type safeUpdaterRWMutex struct { | ||||
| 	mu sync.RWMutex | ||||
| 	f  func() | ||||
| } | ||||
| 
 | ||||
| func (s *safeUpdaterRWMutex) call() { | ||||
| 	s.mu.RLock() | ||||
| 	defer s.mu.RUnlock() | ||||
| 	s.f() | ||||
| } | ||||
| 
 | ||||
| func (s *safeUpdaterRWMutex) update(f func()) { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
| 	s.f = f | ||||
| } | ||||
| 
 | ||||
| type updater interface { | ||||
| 	call() | ||||
| 	update(f func()) | ||||
| } | ||||
| 
 | ||||
| func benchmarkSafeUpdater(b *testing.B, u updater) { | ||||
| 	t := time.NewTicker(time.Second) | ||||
| 	go func() { | ||||
| 		for range t.C { | ||||
| 			u.update(func() {}) | ||||
| 		} | ||||
| 	}() | ||||
| 	b.RunParallel(func(pb *testing.PB) { | ||||
| 		u.update(func() {}) | ||||
| 		for pb.Next() { | ||||
| 			u.call() | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Stop() | ||||
| } | ||||
| 
 | ||||
| func BenchmarkSafeUpdaterAtomicAndCounter(b *testing.B) { | ||||
| 	benchmarkSafeUpdater(b, &safeUpdaterAtomicAndCounter{}) | ||||
| } | ||||
| 
 | ||||
| func BenchmarkSafeUpdaterRWMutex(b *testing.B) { | ||||
| 	benchmarkSafeUpdater(b, &safeUpdaterRWMutex{}) | ||||
| } | ||||
|  | @ -38,6 +38,7 @@ import ( | |||
| 	"google.golang.org/grpc/internal/channelz" | ||||
| 	"google.golang.org/grpc/internal/grpcsync" | ||||
| 	"google.golang.org/grpc/internal/grpcutil" | ||||
| 	iresolver "google.golang.org/grpc/internal/resolver" | ||||
| 	"google.golang.org/grpc/internal/transport" | ||||
| 	"google.golang.org/grpc/keepalive" | ||||
| 	"google.golang.org/grpc/resolver" | ||||
|  | @ -104,6 +105,17 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { | |||
| 	return DialContext(context.Background(), target, opts...) | ||||
| } | ||||
| 
 | ||||
| type defaultConfigSelector struct { | ||||
| 	sc *ServiceConfig | ||||
| } | ||||
| 
 | ||||
| func (dcs *defaultConfigSelector) SelectConfig(rpcInfo iresolver.RPCInfo) *iresolver.RPCConfig { | ||||
| 	return &iresolver.RPCConfig{ | ||||
| 		Context:      rpcInfo.Context, | ||||
| 		MethodConfig: getMethodConfig(dcs.sc, rpcInfo.Method), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // DialContext creates a client connection to the given target. By default, it's
 | ||||
| // a non-blocking dial (the function won't wait for connections to be
 | ||||
| // established, and connecting happens in the background). To make it a blocking
 | ||||
|  | @ -224,6 +236,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * | |||
| 		case sc, ok := <-cc.dopts.scChan: | ||||
| 			if ok { | ||||
| 				cc.sc = &sc | ||||
| 				cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{&sc}) | ||||
| 				scSet = true | ||||
| 			} | ||||
| 		default: | ||||
|  | @ -273,6 +286,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * | |||
| 		case sc, ok := <-cc.dopts.scChan: | ||||
| 			if ok { | ||||
| 				cc.sc = &sc | ||||
| 				cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{&sc}) | ||||
| 			} | ||||
| 		case <-ctx.Done(): | ||||
| 			return nil, ctx.Err() | ||||
|  | @ -479,6 +493,8 @@ type ClientConn struct { | |||
| 	balancerBuildOpts balancer.BuildOptions | ||||
| 	blockingpicker    *pickerWrapper | ||||
| 
 | ||||
| 	safeConfigSelector iresolver.SafeConfigSelector | ||||
| 
 | ||||
| 	mu              sync.RWMutex | ||||
| 	resolverWrapper *ccResolverWrapper | ||||
| 	sc              *ServiceConfig | ||||
|  | @ -539,6 +555,7 @@ func (cc *ClientConn) scWatcher() { | |||
| 			// TODO: load balance policy runtime change is ignored.
 | ||||
| 			// We may revisit this decision in the future.
 | ||||
| 			cc.sc = &sc | ||||
| 			cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{&sc}) | ||||
| 			cc.mu.Unlock() | ||||
| 		case <-cc.ctx.Done(): | ||||
| 			return | ||||
|  | @ -577,13 +594,13 @@ func init() { | |||
| 
 | ||||
| func (cc *ClientConn) maybeApplyDefaultServiceConfig(addrs []resolver.Address) { | ||||
| 	if cc.sc != nil { | ||||
| 		cc.applyServiceConfigAndBalancer(cc.sc, addrs) | ||||
| 		cc.applyServiceConfigAndBalancer(cc.sc, nil, addrs) | ||||
| 		return | ||||
| 	} | ||||
| 	if cc.dopts.defaultServiceConfig != nil { | ||||
| 		cc.applyServiceConfigAndBalancer(cc.dopts.defaultServiceConfig, addrs) | ||||
| 		cc.applyServiceConfigAndBalancer(cc.dopts.defaultServiceConfig, &defaultConfigSelector{cc.dopts.defaultServiceConfig}, addrs) | ||||
| 	} else { | ||||
| 		cc.applyServiceConfigAndBalancer(emptyServiceConfig, addrs) | ||||
| 		cc.applyServiceConfigAndBalancer(emptyServiceConfig, &defaultConfigSelector{emptyServiceConfig}, addrs) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -620,7 +637,15 @@ func (cc *ClientConn) updateResolverState(s resolver.State, err error) error { | |||
| 		// default, per the error handling design?
 | ||||
| 	} else { | ||||
| 		if sc, ok := s.ServiceConfig.Config.(*ServiceConfig); s.ServiceConfig.Err == nil && ok { | ||||
| 			cc.applyServiceConfigAndBalancer(sc, s.Addresses) | ||||
| 			configSelector := iresolver.GetConfigSelector(s) | ||||
| 			if configSelector != nil { | ||||
| 				if len(s.ServiceConfig.Config.(*ServiceConfig).Methods) != 0 { | ||||
| 					channelz.Infof(logger, cc.channelzID, "method configs in service config will be ignored due to presence of config selector") | ||||
| 				} | ||||
| 			} else { | ||||
| 				configSelector = &defaultConfigSelector{sc} | ||||
| 			} | ||||
| 			cc.applyServiceConfigAndBalancer(sc, configSelector, s.Addresses) | ||||
| 		} else { | ||||
| 			ret = balancer.ErrBadResolverState | ||||
| 			if cc.balancerWrapper == nil { | ||||
|  | @ -630,6 +655,7 @@ func (cc *ClientConn) updateResolverState(s resolver.State, err error) error { | |||
| 				} else { | ||||
| 					err = status.Errorf(codes.Unavailable, "illegal service config type: %T", s.ServiceConfig.Config) | ||||
| 				} | ||||
| 				cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{cc.sc}) | ||||
| 				cc.blockingpicker.updatePicker(base.NewErrPicker(err)) | ||||
| 				cc.csMgr.updateState(connectivity.TransientFailure) | ||||
| 				cc.mu.Unlock() | ||||
|  | @ -864,6 +890,20 @@ func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool { | |||
| 	return curAddrFound | ||||
| } | ||||
| 
 | ||||
| func getMethodConfig(sc *ServiceConfig, method string) MethodConfig { | ||||
| 	if sc == nil { | ||||
| 		return MethodConfig{} | ||||
| 	} | ||||
| 	if m, ok := sc.Methods[method]; ok { | ||||
| 		return m | ||||
| 	} | ||||
| 	i := strings.LastIndex(method, "/") | ||||
| 	if m, ok := sc.Methods[method[:i+1]]; ok { | ||||
| 		return m | ||||
| 	} | ||||
| 	return sc.Methods[""] | ||||
| } | ||||
| 
 | ||||
| // GetMethodConfig gets the method config of the input method.
 | ||||
| // If there's an exact match for input method (i.e. /service/method), we return
 | ||||
| // the corresponding MethodConfig.
 | ||||
|  | @ -876,17 +916,7 @@ func (cc *ClientConn) GetMethodConfig(method string) MethodConfig { | |||
| 	// TODO: Avoid the locking here.
 | ||||
| 	cc.mu.RLock() | ||||
| 	defer cc.mu.RUnlock() | ||||
| 	if cc.sc == nil { | ||||
| 		return MethodConfig{} | ||||
| 	} | ||||
| 	if m, ok := cc.sc.Methods[method]; ok { | ||||
| 		return m | ||||
| 	} | ||||
| 	i := strings.LastIndex(method, "/") | ||||
| 	if m, ok := cc.sc.Methods[method[:i+1]]; ok { | ||||
| 		return m | ||||
| 	} | ||||
| 	return cc.sc.Methods[""] | ||||
| 	return getMethodConfig(cc.sc, method) | ||||
| } | ||||
| 
 | ||||
| func (cc *ClientConn) healthCheckConfig() *healthCheckConfig { | ||||
|  | @ -909,12 +939,15 @@ func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method st | |||
| 	return t, done, nil | ||||
| } | ||||
| 
 | ||||
| func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, addrs []resolver.Address) { | ||||
| func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, configSelector iresolver.ConfigSelector, addrs []resolver.Address) { | ||||
| 	if sc == nil { | ||||
| 		// should never reach here.
 | ||||
| 		return | ||||
| 	} | ||||
| 	cc.sc = sc | ||||
| 	if configSelector != nil { | ||||
| 		cc.safeConfigSelector.UpdateConfigSelector(configSelector) | ||||
| 	} | ||||
| 
 | ||||
| 	if cc.sc.retryThrottling != nil { | ||||
| 		newThrottler := &retryThrottler{ | ||||
|  |  | |||
|  | @ -0,0 +1,93 @@ | |||
| /* | ||||
|  * | ||||
|  * Copyright 2020 gRPC authors. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  * | ||||
|  */ | ||||
| 
 | ||||
| // Package resolver provides internal resolver-related functionality.
 | ||||
| package resolver | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"google.golang.org/grpc/internal/serviceconfig" | ||||
| 	"google.golang.org/grpc/resolver" | ||||
| ) | ||||
| 
 | ||||
| // ConfigSelector controls what configuration to use for every RPC.
 | ||||
| type ConfigSelector interface { | ||||
| 	// Selects the configuration for the RPC.
 | ||||
| 	SelectConfig(RPCInfo) *RPCConfig | ||||
| } | ||||
| 
 | ||||
| // RPCInfo contains RPC information needed by a ConfigSelector.
 | ||||
| type RPCInfo struct { | ||||
| 	// Context is the user's context for the RPC and contains headers and
 | ||||
| 	// application timeout.  It is passed for interception purposes and for
 | ||||
| 	// efficiency reasons.  SelectConfig should not be blocking.
 | ||||
| 	Context context.Context | ||||
| 	Method  string // i.e. "/Service/Method"
 | ||||
| } | ||||
| 
 | ||||
| // RPCConfig describes the configuration to use for each RPC.
 | ||||
| type RPCConfig struct { | ||||
| 	// The context to use for the remainder of the RPC; can pass info to LB
 | ||||
| 	// policy or affect timeout or metadata.
 | ||||
| 	Context      context.Context | ||||
| 	MethodConfig serviceconfig.MethodConfig // configuration to use for this RPC
 | ||||
| 	OnCommitted  func()                     // Called when the RPC has been committed (retries no longer possible)
 | ||||
| } | ||||
| 
 | ||||
| type csKeyType string | ||||
| 
 | ||||
| const csKey = csKeyType("grpc.internal.resolver.configSelector") | ||||
| 
 | ||||
| // SetConfigSelector sets the config selector in state and returns the new
 | ||||
| // state.
 | ||||
| func SetConfigSelector(state resolver.State, cs ConfigSelector) resolver.State { | ||||
| 	state.Attributes = state.Attributes.WithValues(csKey, cs) | ||||
| 	return state | ||||
| } | ||||
| 
 | ||||
| // GetConfigSelector retrieves the config selector from state, if present, and
 | ||||
| // returns it or nil if absent.
 | ||||
| func GetConfigSelector(state resolver.State) ConfigSelector { | ||||
| 	cs, _ := state.Attributes.Value(csKey).(ConfigSelector) | ||||
| 	return cs | ||||
| } | ||||
| 
 | ||||
| // SafeConfigSelector allows for safe switching of ConfigSelector
 | ||||
| // implementations such that previous values are guaranteed to not be in use
 | ||||
| // when UpdateConfigSelector returns.
 | ||||
| type SafeConfigSelector struct { | ||||
| 	mu sync.RWMutex | ||||
| 	cs ConfigSelector | ||||
| } | ||||
| 
 | ||||
| // UpdateConfigSelector swaps to the provided ConfigSelector and blocks until
 | ||||
| // all uses of the previous ConfigSelector have completed.
 | ||||
| func (scs *SafeConfigSelector) UpdateConfigSelector(cs ConfigSelector) { | ||||
| 	scs.mu.Lock() | ||||
| 	defer scs.mu.Unlock() | ||||
| 	scs.cs = cs | ||||
| } | ||||
| 
 | ||||
| // SelectConfig defers to the current ConfigSelector in scs.
 | ||||
| func (scs *SafeConfigSelector) SelectConfig(r RPCInfo) *RPCConfig { | ||||
| 	scs.mu.RLock() | ||||
| 	defer scs.mu.RUnlock() | ||||
| 	return scs.cs.SelectConfig(r) | ||||
| } | ||||
|  | @ -0,0 +1,153 @@ | |||
| /* | ||||
|  * | ||||
|  * Copyright 2020 gRPC authors. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  * | ||||
|  */ | ||||
| 
 | ||||
| package resolver | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| 	"google.golang.org/grpc/internal/grpctest" | ||||
| 	"google.golang.org/grpc/internal/serviceconfig" | ||||
| ) | ||||
| 
 | ||||
| type s struct { | ||||
| 	grpctest.Tester | ||||
| } | ||||
| 
 | ||||
| func Test(t *testing.T) { | ||||
| 	grpctest.RunSubTests(t, s{}) | ||||
| } | ||||
| 
 | ||||
| type fakeConfigSelector struct { | ||||
| 	selectConfig func(RPCInfo) *RPCConfig | ||||
| } | ||||
| 
 | ||||
| func (f *fakeConfigSelector) SelectConfig(r RPCInfo) *RPCConfig { | ||||
| 	return f.selectConfig(r) | ||||
| } | ||||
| 
 | ||||
| func (s) TestSafeConfigSelector(t *testing.T) { | ||||
| 	testRPCInfo := RPCInfo{Method: "test method"} | ||||
| 
 | ||||
| 	retChan1 := make(chan *RPCConfig) | ||||
| 	retChan2 := make(chan *RPCConfig) | ||||
| 
 | ||||
| 	one := 1 | ||||
| 	two := 2 | ||||
| 
 | ||||
| 	resp1 := &RPCConfig{MethodConfig: serviceconfig.MethodConfig{MaxReqSize: &one}} | ||||
| 	resp2 := &RPCConfig{MethodConfig: serviceconfig.MethodConfig{MaxReqSize: &two}} | ||||
| 
 | ||||
| 	cs1Called := make(chan struct{}) | ||||
| 	cs2Called := make(chan struct{}) | ||||
| 
 | ||||
| 	cs1 := &fakeConfigSelector{ | ||||
| 		selectConfig: func(r RPCInfo) *RPCConfig { | ||||
| 			cs1Called <- struct{}{} | ||||
| 			if diff := cmp.Diff(r, testRPCInfo); diff != "" { | ||||
| 				t.Errorf("SelectConfig(%v) called; want %v\n  Diffs:\n%s", r, testRPCInfo, diff) | ||||
| 			} | ||||
| 			return <-retChan1 | ||||
| 		}, | ||||
| 	} | ||||
| 	cs2 := &fakeConfigSelector{ | ||||
| 		selectConfig: func(r RPCInfo) *RPCConfig { | ||||
| 			cs2Called <- struct{}{} | ||||
| 			if diff := cmp.Diff(r, testRPCInfo); diff != "" { | ||||
| 				t.Errorf("SelectConfig(%v) called; want %v\n  Diffs:\n%s", r, testRPCInfo, diff) | ||||
| 			} | ||||
| 			return <-retChan2 | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	scs := &SafeConfigSelector{} | ||||
| 	scs.UpdateConfigSelector(cs1) | ||||
| 
 | ||||
| 	cs1Returned := make(chan struct{}) | ||||
| 	go func() { | ||||
| 		got := scs.SelectConfig(testRPCInfo) // blocks until send to retChan1
 | ||||
| 		if got != resp1 { | ||||
| 			t.Errorf("SelectConfig(%v) = %v; want %v", testRPCInfo, got, resp1) | ||||
| 		} | ||||
| 		close(cs1Returned) | ||||
| 	}() | ||||
| 
 | ||||
| 	// cs1 is blocked but should be called
 | ||||
| 	select { | ||||
| 	case <-time.After(500 * time.Millisecond): | ||||
| 		t.Fatalf("timed out waiting for cs1 to be called") | ||||
| 	case <-cs1Called: | ||||
| 	} | ||||
| 
 | ||||
| 	// swap in cs2 now that cs1 is called
 | ||||
| 	csSwapped := make(chan struct{}) | ||||
| 	go func() { | ||||
| 		// wait awhile first to ensure cs1 could be called below.
 | ||||
| 		time.Sleep(50 * time.Millisecond) | ||||
| 		scs.UpdateConfigSelector(cs2) // Blocks until cs1 done
 | ||||
| 		close(csSwapped) | ||||
| 	}() | ||||
| 
 | ||||
| 	// Allow cs1 to return and cs2 to eventually be swapped in.
 | ||||
| 	retChan1 <- resp1 | ||||
| 
 | ||||
| 	cs1Done := false // set when cs2 is first called
 | ||||
| 	for dl := time.Now().Add(150 * time.Millisecond); !time.Now().After(dl); { | ||||
| 		gotConfigChan := make(chan *RPCConfig) | ||||
| 		go func() { | ||||
| 			gotConfigChan <- scs.SelectConfig(testRPCInfo) | ||||
| 		}() | ||||
| 		select { | ||||
| 		case <-time.After(500 * time.Millisecond): | ||||
| 			t.Fatalf("timed out waiting for cs1 or cs2 to be called") | ||||
| 		case <-cs1Called: | ||||
| 			// Initially, before swapping to cs2, cs1 should be called
 | ||||
| 			retChan1 <- resp1 | ||||
| 			go func() { <-gotConfigChan }() | ||||
| 			if cs1Done { | ||||
| 				t.Fatalf("cs1 called after cs2") | ||||
| 			} | ||||
| 		case <-cs2Called: | ||||
| 			// Success! the new config selector is being called
 | ||||
| 			if !cs1Done { | ||||
| 				select { | ||||
| 				case <-csSwapped: | ||||
| 				case <-time.After(50 * time.Millisecond): | ||||
| 					t.Fatalf("timed out waiting for UpdateConfigSelector to return") | ||||
| 				} | ||||
| 				select { | ||||
| 				case <-cs1Returned: | ||||
| 				case <-time.After(50 * time.Millisecond): | ||||
| 					t.Fatalf("timed out waiting for cs1 to return") | ||||
| 				} | ||||
| 				cs1Done = true | ||||
| 			} | ||||
| 			retChan2 <- resp2 | ||||
| 			got := <-gotConfigChan | ||||
| 			if diff := cmp.Diff(got, resp2); diff != "" { | ||||
| 				t.Fatalf("SelectConfig(%v) = %v; want %v\n  Diffs:\n%s", testRPCInfo, got, resp2, diff) | ||||
| 			} | ||||
| 		} | ||||
| 		time.Sleep(10 * time.Millisecond) | ||||
| 	} | ||||
| 	if !cs1Done { | ||||
| 		t.Fatalf("timed out waiting for cs2 to be called") | ||||
| 	} | ||||
| } | ||||
|  | @ -22,8 +22,10 @@ package serviceconfig | |||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"google.golang.org/grpc/balancer" | ||||
| 	"google.golang.org/grpc/codes" | ||||
| 	"google.golang.org/grpc/grpclog" | ||||
| 	externalserviceconfig "google.golang.org/grpc/serviceconfig" | ||||
| ) | ||||
|  | @ -104,3 +106,57 @@ func (bc *BalancerConfig) UnmarshalJSON(b []byte) error { | |||
| 	// case.
 | ||||
| 	return fmt.Errorf("invalid loadBalancingConfig: no supported policies found") | ||||
| } | ||||
| 
 | ||||
| // MethodConfig defines the configuration recommended by the service providers for a
 | ||||
| // particular method.
 | ||||
| type MethodConfig struct { | ||||
| 	// WaitForReady indicates whether RPCs sent to this method should wait until
 | ||||
| 	// the connection is ready by default (!failfast). The value specified via the
 | ||||
| 	// gRPC client API will override the value set here.
 | ||||
| 	WaitForReady *bool | ||||
| 	// Timeout is the default timeout for RPCs sent to this method. The actual
 | ||||
| 	// deadline used will be the minimum of the value specified here and the value
 | ||||
| 	// set by the application via the gRPC client API.  If either one is not set,
 | ||||
| 	// then the other will be used.  If neither is set, then the RPC has no deadline.
 | ||||
| 	Timeout *time.Duration | ||||
| 	// MaxReqSize is the maximum allowed payload size for an individual request in a
 | ||||
| 	// stream (client->server) in bytes. The size which is measured is the serialized
 | ||||
| 	// payload after per-message compression (but before stream compression) in bytes.
 | ||||
| 	// The actual value used is the minimum of the value specified here and the value set
 | ||||
| 	// by the application via the gRPC client API. If either one is not set, then the other
 | ||||
| 	// will be used.  If neither is set, then the built-in default is used.
 | ||||
| 	MaxReqSize *int | ||||
| 	// MaxRespSize is the maximum allowed payload size for an individual response in a
 | ||||
| 	// stream (server->client) in bytes.
 | ||||
| 	MaxRespSize *int | ||||
| 	// RetryPolicy configures retry options for the method.
 | ||||
| 	RetryPolicy *RetryPolicy | ||||
| } | ||||
| 
 | ||||
| // RetryPolicy defines the go-native version of the retry policy defined by the
 | ||||
| // service config here:
 | ||||
| // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#integration-with-service-config
 | ||||
| type RetryPolicy struct { | ||||
| 	// MaxAttempts is the maximum number of attempts, including the original RPC.
 | ||||
| 	//
 | ||||
| 	// This field is required and must be two or greater.
 | ||||
| 	MaxAttempts int | ||||
| 
 | ||||
| 	// Exponential backoff parameters. The initial retry attempt will occur at
 | ||||
| 	// random(0, initialBackoff). In general, the nth attempt will occur at
 | ||||
| 	// random(0,
 | ||||
| 	//   min(initialBackoff*backoffMultiplier**(n-1), maxBackoff)).
 | ||||
| 	//
 | ||||
| 	// These fields are required and must be greater than zero.
 | ||||
| 	InitialBackoff    time.Duration | ||||
| 	MaxBackoff        time.Duration | ||||
| 	BackoffMultiplier float64 | ||||
| 
 | ||||
| 	// The set of status codes which may be retried.
 | ||||
| 	//
 | ||||
| 	// Status codes are specified as strings, e.g., "UNAVAILABLE".
 | ||||
| 	//
 | ||||
| 	// This field is required and must be non-empty.
 | ||||
| 	// Note: a set is used to store this for easy lookup.
 | ||||
| 	RetryableStatusCodes map[codes.Code]bool | ||||
| } | ||||
|  |  | |||
|  | @ -30,17 +30,50 @@ type Channel struct { | |||
| } | ||||
| 
 | ||||
| // Send sends value on the underlying channel.
 | ||||
| func (cwt *Channel) Send(value interface{}) { | ||||
| 	cwt.ch <- value | ||||
| func (c *Channel) Send(value interface{}) { | ||||
| 	c.ch <- value | ||||
| } | ||||
| 
 | ||||
| // SendContext sends value on the underlying channel, or returns an error if
 | ||||
| // the context expires.
 | ||||
| func (c *Channel) SendContext(ctx context.Context, value interface{}) error { | ||||
| 	select { | ||||
| 	case c.ch <- value: | ||||
| 		return nil | ||||
| 	case <-ctx.Done(): | ||||
| 		return ctx.Err() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // SendOrFail attempts to send value on the underlying channel.  Returns true
 | ||||
| // if successful or false if the channel was full.
 | ||||
| func (c *Channel) SendOrFail(value interface{}) bool { | ||||
| 	select { | ||||
| 	case c.ch <- value: | ||||
| 		return true | ||||
| 	default: | ||||
| 		return false | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // ReceiveOrFail returns the value on the underlying channel and true, or nil
 | ||||
| // and false if the channel was empty.
 | ||||
| func (c *Channel) ReceiveOrFail() (interface{}, bool) { | ||||
| 	select { | ||||
| 	case got := <-c.ch: | ||||
| 		return got, true | ||||
| 	default: | ||||
| 		return nil, false | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Receive returns the value received on the underlying channel, or the error
 | ||||
| // returned by ctx if it is closed or cancelled.
 | ||||
| func (cwt *Channel) Receive(ctx context.Context) (interface{}, error) { | ||||
| func (c *Channel) Receive(ctx context.Context) (interface{}, error) { | ||||
| 	select { | ||||
| 	case <-ctx.Done(): | ||||
| 		return nil, ctx.Err() | ||||
| 	case got := <-cwt.ch: | ||||
| 	case got := <-c.ch: | ||||
| 		return got, nil | ||||
| 	} | ||||
| } | ||||
|  | @ -50,12 +83,12 @@ func (cwt *Channel) Receive(ctx context.Context) (interface{}, error) { | |||
| // It's expected to be used with a size-1 channel, to only keep the most
 | ||||
| // up-to-date item. This method is inherently racy when invoked concurrently
 | ||||
| // from multiple goroutines.
 | ||||
| func (cwt *Channel) Replace(value interface{}) { | ||||
| func (c *Channel) Replace(value interface{}) { | ||||
| 	for { | ||||
| 		select { | ||||
| 		case cwt.ch <- value: | ||||
| 		case c.ch <- value: | ||||
| 			return | ||||
| 		case <-cwt.ch: | ||||
| 		case <-c.ch: | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -41,29 +41,7 @@ const maxInt = int(^uint(0) >> 1) | |||
| // Deprecated: Users should not use this struct. Service config should be received
 | ||||
| // through name resolver, as specified here
 | ||||
| // https://github.com/grpc/grpc/blob/master/doc/service_config.md
 | ||||
| type MethodConfig struct { | ||||
| 	// WaitForReady indicates whether RPCs sent to this method should wait until
 | ||||
| 	// the connection is ready by default (!failfast). The value specified via the
 | ||||
| 	// gRPC client API will override the value set here.
 | ||||
| 	WaitForReady *bool | ||||
| 	// Timeout is the default timeout for RPCs sent to this method. The actual
 | ||||
| 	// deadline used will be the minimum of the value specified here and the value
 | ||||
| 	// set by the application via the gRPC client API.  If either one is not set,
 | ||||
| 	// then the other will be used.  If neither is set, then the RPC has no deadline.
 | ||||
| 	Timeout *time.Duration | ||||
| 	// MaxReqSize is the maximum allowed payload size for an individual request in a
 | ||||
| 	// stream (client->server) in bytes. The size which is measured is the serialized
 | ||||
| 	// payload after per-message compression (but before stream compression) in bytes.
 | ||||
| 	// The actual value used is the minimum of the value specified here and the value set
 | ||||
| 	// by the application via the gRPC client API. If either one is not set, then the other
 | ||||
| 	// will be used.  If neither is set, then the built-in default is used.
 | ||||
| 	MaxReqSize *int | ||||
| 	// MaxRespSize is the maximum allowed payload size for an individual response in a
 | ||||
| 	// stream (server->client) in bytes.
 | ||||
| 	MaxRespSize *int | ||||
| 	// RetryPolicy configures retry options for the method.
 | ||||
| 	retryPolicy *retryPolicy | ||||
| } | ||||
| type MethodConfig = internalserviceconfig.MethodConfig | ||||
| 
 | ||||
| type lbConfig struct { | ||||
| 	name string | ||||
|  | @ -127,34 +105,6 @@ type healthCheckConfig struct { | |||
| 	ServiceName string | ||||
| } | ||||
| 
 | ||||
| // retryPolicy defines the go-native version of the retry policy defined by the
 | ||||
| // service config here:
 | ||||
| // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#integration-with-service-config
 | ||||
| type retryPolicy struct { | ||||
| 	// MaxAttempts is the maximum number of attempts, including the original RPC.
 | ||||
| 	//
 | ||||
| 	// This field is required and must be two or greater.
 | ||||
| 	maxAttempts int | ||||
| 
 | ||||
| 	// Exponential backoff parameters. The initial retry attempt will occur at
 | ||||
| 	// random(0, initialBackoff). In general, the nth attempt will occur at
 | ||||
| 	// random(0,
 | ||||
| 	//   min(initialBackoff*backoffMultiplier**(n-1), maxBackoff)).
 | ||||
| 	//
 | ||||
| 	// These fields are required and must be greater than zero.
 | ||||
| 	initialBackoff    time.Duration | ||||
| 	maxBackoff        time.Duration | ||||
| 	backoffMultiplier float64 | ||||
| 
 | ||||
| 	// The set of status codes which may be retried.
 | ||||
| 	//
 | ||||
| 	// Status codes are specified as strings, e.g., "UNAVAILABLE".
 | ||||
| 	//
 | ||||
| 	// This field is required and must be non-empty.
 | ||||
| 	// Note: a set is used to store this for easy lookup.
 | ||||
| 	retryableStatusCodes map[codes.Code]bool | ||||
| } | ||||
| 
 | ||||
| type jsonRetryPolicy struct { | ||||
| 	MaxAttempts          int | ||||
| 	InitialBackoff       string | ||||
|  | @ -313,7 +263,7 @@ func parseServiceConfig(js string) *serviceconfig.ParseResult { | |||
| 			WaitForReady: m.WaitForReady, | ||||
| 			Timeout:      d, | ||||
| 		} | ||||
| 		if mc.retryPolicy, err = convertRetryPolicy(m.RetryPolicy); err != nil { | ||||
| 		if mc.RetryPolicy, err = convertRetryPolicy(m.RetryPolicy); err != nil { | ||||
| 			logger.Warningf("grpc: parseServiceConfig error unmarshaling %s due to %v", js, err) | ||||
| 			return &serviceconfig.ParseResult{Err: err} | ||||
| 		} | ||||
|  | @ -359,7 +309,7 @@ func parseServiceConfig(js string) *serviceconfig.ParseResult { | |||
| 	return &serviceconfig.ParseResult{Config: &sc} | ||||
| } | ||||
| 
 | ||||
| func convertRetryPolicy(jrp *jsonRetryPolicy) (p *retryPolicy, err error) { | ||||
| func convertRetryPolicy(jrp *jsonRetryPolicy) (p *internalserviceconfig.RetryPolicy, err error) { | ||||
| 	if jrp == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
|  | @ -381,19 +331,19 @@ func convertRetryPolicy(jrp *jsonRetryPolicy) (p *retryPolicy, err error) { | |||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	rp := &retryPolicy{ | ||||
| 		maxAttempts:          jrp.MaxAttempts, | ||||
| 		initialBackoff:       *ib, | ||||
| 		maxBackoff:           *mb, | ||||
| 		backoffMultiplier:    jrp.BackoffMultiplier, | ||||
| 		retryableStatusCodes: make(map[codes.Code]bool), | ||||
| 	rp := &internalserviceconfig.RetryPolicy{ | ||||
| 		MaxAttempts:          jrp.MaxAttempts, | ||||
| 		InitialBackoff:       *ib, | ||||
| 		MaxBackoff:           *mb, | ||||
| 		BackoffMultiplier:    jrp.BackoffMultiplier, | ||||
| 		RetryableStatusCodes: make(map[codes.Code]bool), | ||||
| 	} | ||||
| 	if rp.maxAttempts > 5 { | ||||
| 	if rp.MaxAttempts > 5 { | ||||
| 		// TODO(retry): Make the max maxAttempts configurable.
 | ||||
| 		rp.maxAttempts = 5 | ||||
| 		rp.MaxAttempts = 5 | ||||
| 	} | ||||
| 	for _, code := range jrp.RetryableStatusCodes { | ||||
| 		rp.retryableStatusCodes[code] = true | ||||
| 		rp.RetryableStatusCodes[code] = true | ||||
| 	} | ||||
| 	return rp, nil | ||||
| } | ||||
|  |  | |||
							
								
								
									
										34
									
								
								stream.go
								
								
								
								
							
							
						
						
									
										34
									
								
								stream.go
								
								
								
								
							|  | @ -36,6 +36,8 @@ import ( | |||
| 	"google.golang.org/grpc/internal/channelz" | ||||
| 	"google.golang.org/grpc/internal/grpcrand" | ||||
| 	"google.golang.org/grpc/internal/grpcutil" | ||||
| 	iresolver "google.golang.org/grpc/internal/resolver" | ||||
| 	"google.golang.org/grpc/internal/serviceconfig" | ||||
| 	"google.golang.org/grpc/internal/transport" | ||||
| 	"google.golang.org/grpc/metadata" | ||||
| 	"google.golang.org/grpc/peer" | ||||
|  | @ -170,7 +172,18 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth | |||
| 	if err := cc.waitForResolvedAddrs(ctx); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	mc := cc.GetMethodConfig(method) | ||||
| 
 | ||||
| 	var mc serviceconfig.MethodConfig | ||||
| 	var onCommit func() | ||||
| 	rpcConfig := cc.safeConfigSelector.SelectConfig(iresolver.RPCInfo{Context: ctx, Method: method}) | ||||
| 	if rpcConfig != nil { | ||||
| 		if rpcConfig.Context != nil { | ||||
| 			ctx = rpcConfig.Context | ||||
| 		} | ||||
| 		mc = rpcConfig.MethodConfig | ||||
| 		onCommit = rpcConfig.OnCommitted | ||||
| 	} | ||||
| 
 | ||||
| 	if mc.WaitForReady != nil { | ||||
| 		c.failFast = !*mc.WaitForReady | ||||
| 	} | ||||
|  | @ -272,6 +285,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth | |||
| 		cancel:       cancel, | ||||
| 		beginTime:    beginTime, | ||||
| 		firstAttempt: true, | ||||
| 		onCommit:     onCommit, | ||||
| 	} | ||||
| 	if !cc.dopts.disableRetry { | ||||
| 		cs.retryThrottler = cc.retryThrottler.Load().(*retryThrottler) | ||||
|  | @ -432,7 +446,8 @@ type clientStream struct { | |||
| 	// place where we need to check if the attempt is nil.
 | ||||
| 	attempt *csAttempt | ||||
| 	// TODO(hedging): hedging will have multiple attempts simultaneously.
 | ||||
| 	committed  bool                       // active attempt committed for retry?
 | ||||
| 	committed  bool // active attempt committed for retry?
 | ||||
| 	onCommit   func() | ||||
| 	buffer     []func(a *csAttempt) error // operations to replay on retry
 | ||||
| 	bufferSize int                        // current size of buffer
 | ||||
| } | ||||
|  | @ -461,6 +476,9 @@ type csAttempt struct { | |||
| } | ||||
| 
 | ||||
| func (cs *clientStream) commitAttemptLocked() { | ||||
| 	if !cs.committed && cs.onCommit != nil { | ||||
| 		cs.onCommit() | ||||
| 	} | ||||
| 	cs.committed = true | ||||
| 	cs.buffer = nil | ||||
| } | ||||
|  | @ -539,8 +557,8 @@ func (cs *clientStream) shouldRetry(err error) error { | |||
| 		code = status.Convert(err).Code() | ||||
| 	} | ||||
| 
 | ||||
| 	rp := cs.methodConfig.retryPolicy | ||||
| 	if rp == nil || !rp.retryableStatusCodes[code] { | ||||
| 	rp := cs.methodConfig.RetryPolicy | ||||
| 	if rp == nil || !rp.RetryableStatusCodes[code] { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
|  | @ -549,7 +567,7 @@ func (cs *clientStream) shouldRetry(err error) error { | |||
| 	if cs.retryThrottler.throttle() { | ||||
| 		return err | ||||
| 	} | ||||
| 	if cs.numRetries+1 >= rp.maxAttempts { | ||||
| 	if cs.numRetries+1 >= rp.MaxAttempts { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
|  | @ -558,9 +576,9 @@ func (cs *clientStream) shouldRetry(err error) error { | |||
| 		dur = time.Millisecond * time.Duration(pushback) | ||||
| 		cs.numRetriesSincePushback = 0 | ||||
| 	} else { | ||||
| 		fact := math.Pow(rp.backoffMultiplier, float64(cs.numRetriesSincePushback)) | ||||
| 		cur := float64(rp.initialBackoff) * fact | ||||
| 		if max := float64(rp.maxBackoff); cur > max { | ||||
| 		fact := math.Pow(rp.BackoffMultiplier, float64(cs.numRetriesSincePushback)) | ||||
| 		cur := float64(rp.InitialBackoff) * fact | ||||
| 		if max := float64(rp.MaxBackoff); cur > max { | ||||
| 			cur = max | ||||
| 		} | ||||
| 		dur = time.Duration(grpcrand.Int63n(int64(cur))) | ||||
|  |  | |||
|  | @ -5784,10 +5784,9 @@ func testGetMethodConfigTD(t *testing.T, e env) { | |||
| 	ch <- sc | ||||
| 	// Wait for the new service config to propagate.
 | ||||
| 	for { | ||||
| 		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) == codes.DeadlineExceeded { | ||||
| 			continue | ||||
| 		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded { | ||||
| 			break | ||||
| 		} | ||||
| 		break | ||||
| 	} | ||||
| 	// The following RPCs are expected to become fail-fast.
 | ||||
| 	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable { | ||||
|  |  | |||
|  | @ -0,0 +1,203 @@ | |||
| /* | ||||
|  * | ||||
|  * Copyright 2020 gRPC authors. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  * | ||||
|  */ | ||||
| 
 | ||||
| package test | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| 	"github.com/google/go-cmp/cmp/cmpopts" | ||||
| 	iresolver "google.golang.org/grpc/internal/resolver" | ||||
| 	"google.golang.org/grpc/internal/serviceconfig" | ||||
| 	"google.golang.org/grpc/internal/testutils" | ||||
| 	"google.golang.org/grpc/metadata" | ||||
| 	"google.golang.org/grpc/resolver" | ||||
| 	"google.golang.org/grpc/resolver/manual" | ||||
| 	testpb "google.golang.org/grpc/test/grpc_testing" | ||||
| ) | ||||
| 
 | ||||
| type funcConfigSelector struct { | ||||
| 	f func(iresolver.RPCInfo) *iresolver.RPCConfig | ||||
| } | ||||
| 
 | ||||
| func (f funcConfigSelector) SelectConfig(i iresolver.RPCInfo) *iresolver.RPCConfig { | ||||
| 	return f.f(i) | ||||
| } | ||||
| 
 | ||||
| func (s) TestConfigSelector(t *testing.T) { | ||||
| 	gotContextChan := testutils.NewChannelWithSize(1) | ||||
| 
 | ||||
| 	ss := &stubServer{ | ||||
| 		emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { | ||||
| 			gotContextChan.SendContext(ctx, ctx) | ||||
| 			return &testpb.Empty{}, nil | ||||
| 		}, | ||||
| 	} | ||||
| 	ss.r = manual.NewBuilderWithScheme("confSel") | ||||
| 
 | ||||
| 	if err := ss.Start(nil); err != nil { | ||||
| 		t.Fatalf("Error starting endpoint server: %v", err) | ||||
| 	} | ||||
| 	defer ss.Stop() | ||||
| 
 | ||||
| 	ctxDeadline := time.Now().Add(10 * time.Second) | ||||
| 	ctx, cancel := context.WithDeadline(context.Background(), ctxDeadline) | ||||
| 	defer cancel() | ||||
| 
 | ||||
| 	longCtxDeadline := time.Now().Add(30 * time.Second) | ||||
| 	longdeadlineCtx, cancel := context.WithDeadline(context.Background(), longCtxDeadline) | ||||
| 	defer cancel() | ||||
| 	shorterTimeout := 3 * time.Second | ||||
| 
 | ||||
| 	testMD := metadata.MD{"footest": []string{"bazbar"}} | ||||
| 	mdOut := metadata.MD{"handler": []string{"value"}} | ||||
| 
 | ||||
| 	var onCommittedCalled bool | ||||
| 
 | ||||
| 	testCases := []struct { | ||||
| 		name   string | ||||
| 		md     metadata.MD | ||||
| 		config *iresolver.RPCConfig | ||||
| 
 | ||||
| 		wantMD       metadata.MD | ||||
| 		wantDeadline time.Time | ||||
| 		wantTimeout  time.Duration | ||||
| 	}{{ | ||||
| 		name:         "basic", | ||||
| 		md:           testMD, | ||||
| 		config:       &iresolver.RPCConfig{}, | ||||
| 		wantMD:       testMD, | ||||
| 		wantDeadline: ctxDeadline, | ||||
| 	}, { | ||||
| 		name: "alter MD", | ||||
| 		md:   testMD, | ||||
| 		config: &iresolver.RPCConfig{ | ||||
| 			Context: metadata.NewOutgoingContext(ctx, mdOut), | ||||
| 		}, | ||||
| 		wantMD:       mdOut, | ||||
| 		wantDeadline: ctxDeadline, | ||||
| 	}, { | ||||
| 		name: "alter timeout; remove MD", | ||||
| 		md:   testMD, | ||||
| 		config: &iresolver.RPCConfig{ | ||||
| 			Context: longdeadlineCtx, // no metadata
 | ||||
| 		}, | ||||
| 		wantMD:       nil, | ||||
| 		wantDeadline: longCtxDeadline, | ||||
| 	}, { | ||||
| 		name:         "nil config", | ||||
| 		md:           metadata.MD{}, | ||||
| 		config:       nil, | ||||
| 		wantMD:       nil, | ||||
| 		wantDeadline: ctxDeadline, | ||||
| 	}, { | ||||
| 		name: "alter timeout via method config; remove MD", | ||||
| 		md:   testMD, | ||||
| 		config: &iresolver.RPCConfig{ | ||||
| 			MethodConfig: serviceconfig.MethodConfig{ | ||||
| 				Timeout: &shorterTimeout, | ||||
| 			}, | ||||
| 		}, | ||||
| 		wantMD:      nil, | ||||
| 		wantTimeout: shorterTimeout, | ||||
| 	}, { | ||||
| 		name: "onCommitted callback", | ||||
| 		md:   testMD, | ||||
| 		config: &iresolver.RPCConfig{ | ||||
| 			OnCommitted: func() { | ||||
| 				onCommittedCalled = true | ||||
| 			}, | ||||
| 		}, | ||||
| 		wantMD:       testMD, | ||||
| 		wantDeadline: ctxDeadline, | ||||
| 	}} | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			var gotInfo *iresolver.RPCInfo | ||||
| 			state := iresolver.SetConfigSelector(resolver.State{ | ||||
| 				Addresses:     []resolver.Address{{Addr: ss.address}}, | ||||
| 				ServiceConfig: parseCfg(ss.r, "{}"), | ||||
| 			}, funcConfigSelector{ | ||||
| 				f: func(i iresolver.RPCInfo) *iresolver.RPCConfig { | ||||
| 					gotInfo = &i | ||||
| 					cfg := tc.config | ||||
| 					if cfg != nil && cfg.Context == nil { | ||||
| 						cfg.Context = i.Context | ||||
| 					} | ||||
| 					return cfg | ||||
| 				}, | ||||
| 			}) | ||||
| 			ss.r.UpdateState(state) // Blocks until config selector is applied
 | ||||
| 
 | ||||
| 			onCommittedCalled = false | ||||
| 			ctx := metadata.NewOutgoingContext(ctx, tc.md) | ||||
| 			startTime := time.Now() | ||||
| 			if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); err != nil { | ||||
| 				t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, nil", err) | ||||
| 			} | ||||
| 
 | ||||
| 			if gotInfo == nil { | ||||
| 				t.Fatalf("no config selector data") | ||||
| 			} | ||||
| 
 | ||||
| 			if want := "/grpc.testing.TestService/EmptyCall"; gotInfo.Method != want { | ||||
| 				t.Errorf("gotInfo.Method = %q; want %q", gotInfo.Method, want) | ||||
| 			} | ||||
| 
 | ||||
| 			gotContextI, ok := gotContextChan.ReceiveOrFail() | ||||
| 			if !ok { | ||||
| 				t.Fatalf("no context received") | ||||
| 			} | ||||
| 			gotContext := gotContextI.(context.Context) | ||||
| 
 | ||||
| 			gotMD, _ := metadata.FromOutgoingContext(gotInfo.Context) | ||||
| 			if diff := cmp.Diff(tc.md, gotMD); diff != "" { | ||||
| 				t.Errorf("gotInfo.Context contains MD %v; want %v\nDiffs: %v", gotMD, tc.md, diff) | ||||
| 			} | ||||
| 
 | ||||
| 			gotMD, _ = metadata.FromIncomingContext(gotContext) | ||||
| 			// Remove entries from gotMD not in tc.wantMD (e.g. authority header).
 | ||||
| 			for k := range gotMD { | ||||
| 				if _, ok := tc.wantMD[k]; !ok { | ||||
| 					delete(gotMD, k) | ||||
| 				} | ||||
| 			} | ||||
| 			if diff := cmp.Diff(tc.wantMD, gotMD, cmpopts.EquateEmpty()); diff != "" { | ||||
| 				t.Errorf("received md = %v; want %v\nDiffs: %v", gotMD, tc.wantMD, diff) | ||||
| 			} | ||||
| 
 | ||||
| 			wantDeadline := tc.wantDeadline | ||||
| 			if wantDeadline == (time.Time{}) { | ||||
| 				wantDeadline = startTime.Add(tc.wantTimeout) | ||||
| 			} | ||||
| 			deadlineGot, _ := gotContext.Deadline() | ||||
| 			if diff := deadlineGot.Sub(wantDeadline); diff > time.Second || diff < -time.Second { | ||||
| 				t.Errorf("received deadline = %v; want ~%v", deadlineGot, wantDeadline) | ||||
| 			} | ||||
| 
 | ||||
| 			if tc.config != nil && tc.config.OnCommitted != nil && !onCommittedCalled { | ||||
| 				t.Errorf("OnCommitted callback not called") | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
		Loading…
	
		Reference in New Issue