diff --git a/clientconn.go b/clientconn.go index a2afb4668..b50c698a0 100644 --- a/clientconn.go +++ b/clientconn.go @@ -1537,6 +1537,9 @@ func (c *channelzChannel) ChannelzMetric() *channelz.ChannelInternalMetric { // referenced by users. var ErrClientConnTimeout = errors.New("grpc: timed out when dialing") +// getResolver finds the scheme in the cc's resolvers or the global registry. +// scheme should always be lowercase (typically by virtue of url.Parse() +// performing proper RFC3986 behavior). func (cc *ClientConn) getResolver(scheme string) resolver.Builder { for _, rb := range cc.dopts.resolvers { if scheme == rb.Scheme() { diff --git a/resolver/resolver.go b/resolver/resolver.go index eb6a46909..6215e5ef2 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -41,8 +41,9 @@ var ( // TODO(bar) install dns resolver in init(){}. -// Register registers the resolver builder to the resolver map. b.Scheme will be -// used as the scheme registered with this builder. +// Register registers the resolver builder to the resolver map. b.Scheme will +// be used as the scheme registered with this builder. The registry is case +// sensitive, and schemes should not contain any uppercase characters. // // NOTE: this function must only be called during initialization time (i.e. in // an init() function), and is not thread-safe. If multiple Resolvers are @@ -289,8 +290,10 @@ type Builder interface { // gRPC dial calls Build synchronously, and fails if the returned error is // not nil. Build(target Target, cc ClientConn, opts BuildOptions) (Resolver, error) - // Scheme returns the scheme supported by this resolver. - // Scheme is defined at https://github.com/grpc/grpc/blob/master/doc/naming.md. + // Scheme returns the scheme supported by this resolver. Scheme is defined + // at https://github.com/grpc/grpc/blob/master/doc/naming.md. The returned + // string should not contain uppercase characters, as they will not match + // the parsed target's scheme as defined in RFC 3986. Scheme() string } diff --git a/resolver_test.go b/resolver_test.go new file mode 100644 index 000000000..5b1e40c2a --- /dev/null +++ b/resolver_test.go @@ -0,0 +1,93 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "context" + "fmt" + "net" + "testing" + + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/resolver" +) + +type wrapResolverBuilder struct { + resolver.Builder + scheme string +} + +func (w *wrapResolverBuilder) Scheme() string { + return w.scheme +} + +func init() { + resolver.Register(&wrapResolverBuilder{Builder: resolver.Get("passthrough"), scheme: "casetest"}) + resolver.Register(&wrapResolverBuilder{Builder: resolver.Get("dns"), scheme: "caseTest"}) +} + +func (s) TestResolverCaseSensitivity(t *testing.T) { + // This should find the "casetest" resolver instead of the "caseTest" + // resolver, even though the latter was registered later. "casetest" is + // "passthrough" and "caseTest" is "dns". With "passthrough" the dialer + // should see the target's address directly, but "dns" would be converted + // into a loopback IP (v4 or v6) address. + target := "caseTest:///localhost:1234" + addrCh := make(chan string, 1) + customDialer := func(ctx context.Context, addr string) (net.Conn, error) { + select { + case addrCh <- addr: + default: + } + return nil, fmt.Errorf("not dialing with custom dialer") + } + + cc, err := Dial(target, WithContextDialer(customDialer), WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("Unexpected Dial(%q) error: %v", target, err) + } + cc.Connect() + if got, want := <-addrCh, "localhost:1234"; got != want { + cc.Close() + t.Fatalf("Dialer got address %q; wanted %q", got, want) + } + cc.Close() + + // Clear addrCh for future use. + select { + case <-addrCh: + default: + } + + res := &wrapResolverBuilder{Builder: resolver.Get("dns"), scheme: "caseTest2"} + // This should not find the injected resolver due to the case not matching. + // This results in "passthrough" being used with the address as the whole + // target. + target = "caseTest2:///localhost:1234" + cc, err = Dial(target, WithContextDialer(customDialer), WithResolvers(res), WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("Unexpected Dial(%q) error: %v", target, err) + } + cc.Connect() + if got, want := <-addrCh, target; got != want { + cc.Close() + t.Fatalf("Dialer got address %q; wanted %q", got, want) + } + cc.Close() +}