diff --git a/core/src/main/java/io/grpc/internal/JndiResourceResolverFactory.java b/core/src/main/java/io/grpc/internal/JndiResourceResolverFactory.java index b095266a04..5e76438b9a 100644 --- a/core/src/main/java/io/grpc/internal/JndiResourceResolverFactory.java +++ b/core/src/main/java/io/grpc/internal/JndiResourceResolverFactory.java @@ -86,7 +86,7 @@ final class JndiResourceResolverFactory implements DnsNameResolver.ResourceResol if (unavailabilityCause() != null) { return null; } - return new JndiResourceResolver(); + return new JndiResourceResolver(new JndiRecordFetcher()); } @Nullable @@ -95,6 +95,11 @@ final class JndiResourceResolverFactory implements DnsNameResolver.ResourceResol return JNDI_UNAVAILABILITY_CAUSE; } + @VisibleForTesting + interface RecordFetcher { + List getAllRecords(String recordType, String name) throws NamingException; + } + @VisibleForTesting static final class JndiResourceResolver implements DnsNameResolver.ResourceResolver { private static final Logger logger = @@ -102,15 +107,20 @@ final class JndiResourceResolverFactory implements DnsNameResolver.ResourceResol private static final Pattern whitespace = Pattern.compile("\\s+"); + private final RecordFetcher recordFetcher; + + public JndiResourceResolver(RecordFetcher recordFetcher) { + this.recordFetcher = recordFetcher; + } + @Override public List resolveTxt(String serviceConfigHostname) throws NamingException { - checkAvailable(); if (logger.isLoggable(Level.FINER)) { logger.log( Level.FINER, "About to query TXT records for {0}", new Object[]{serviceConfigHostname}); } List serviceConfigRawTxtRecords = - getAllRecords("TXT", "dns:///" + serviceConfigHostname); + recordFetcher.getAllRecords("TXT", "dns:///" + serviceConfigHostname); if (logger.isLoggable(Level.FINER)) { logger.log( Level.FINER, "Found {0} TXT records", new Object[]{serviceConfigRawTxtRecords.size()}); @@ -126,13 +136,12 @@ final class JndiResourceResolverFactory implements DnsNameResolver.ResourceResol @Override public List resolveSrv( AddressResolver addressResolver, String grpclbHostname) throws Exception { - checkAvailable(); if (logger.isLoggable(Level.FINER)) { logger.log( Level.FINER, "About to query SRV records for {0}", new Object[]{grpclbHostname}); } List grpclbSrvRecords = - getAllRecords("SRV", "dns:///" + grpclbHostname); + recordFetcher.getAllRecords("SRV", "dns:///" + grpclbHostname); if (logger.isLoggable(Level.FINER)) { logger.log( Level.FINER, "Found {0} SRV records", new Object[]{grpclbSrvRecords.size()}); @@ -144,14 +153,23 @@ final class JndiResourceResolverFactory implements DnsNameResolver.ResourceResol for (String srvRecord : grpclbSrvRecords) { try { SrvRecord record = parseSrvRecord(srvRecord); + // SRV requires the host name to be absolute + if (!record.host.endsWith(".")) { + throw new RuntimeException("Returned SRV host does not end in period: " + record.host); + } + // Strip trailing dot for appearance's sake. It _should_ be fine either way, but most + // people expect to see it without the dot. + String authority = record.host.substring(0, record.host.length() - 1); + // But we want to use the trailing dot for the IP lookup. The dot makes the name absolute + // instead of relative and so will avoid the search list like that in resolv.conf. List addrs = addressResolver.resolveAddress(record.host); List sockaddrs = new ArrayList<>(addrs.size()); for (InetAddress addr : addrs) { sockaddrs.add(new InetSocketAddress(addr, record.port)); } Attributes attrs = Attributes.newBuilder() - .set(GrpcAttributes.ATTR_LB_ADDR_AUTHORITY, record.host) + .set(GrpcAttributes.ATTR_LB_ADDR_AUTHORITY, authority) .build(); balancerAddresses.add( new EquivalentAddressGroup(Collections.unmodifiableList(sockaddrs), attrs)); @@ -176,8 +194,7 @@ final class JndiResourceResolverFactory implements DnsNameResolver.ResourceResol return Collections.unmodifiableList(balancerAddresses); } - @VisibleForTesting - static final class SrvRecord { + private static final class SrvRecord { SrvRecord(String host, int port) { this.host = host; this.port = port; @@ -187,17 +204,50 @@ final class JndiResourceResolverFactory implements DnsNameResolver.ResourceResol final int port; } - @VisibleForTesting @SuppressWarnings("BetaApi") // Verify is only kinda beta - static SrvRecord parseSrvRecord(String rawRecord) { + private static SrvRecord parseSrvRecord(String rawRecord) { String[] parts = whitespace.split(rawRecord); Verify.verify(parts.length == 4, "Bad SRV Record: %s", rawRecord); return new SrvRecord(parts[3], Integer.parseInt(parts[2])); } - @IgnoreJRERequirement - private static List getAllRecords(String recordType, String name) - throws NamingException { + /** + * Undo the quoting done in {@link com.sun.jndi.dns.ResourceRecord#decodeTxt}. + */ + @VisibleForTesting + static String unquote(String txtRecord) { + StringBuilder sb = new StringBuilder(txtRecord.length()); + boolean inquote = false; + for (int i = 0; i < txtRecord.length(); i++) { + char c = txtRecord.charAt(i); + if (!inquote) { + if (c == ' ') { + continue; + } else if (c == '"') { + inquote = true; + continue; + } + } else { + if (c == '"') { + inquote = false; + continue; + } else if (c == '\\') { + c = txtRecord.charAt(++i); + assert c == '"' || c == '\\'; + } + } + sb.append(c); + } + return sb.toString(); + } + } + + @VisibleForTesting + @IgnoreJRERequirement + static final class JndiRecordFetcher implements RecordFetcher { + @Override + public List getAllRecords(String recordType, String name) throws NamingException { + checkAvailable(); String[] rrType = new String[]{recordType}; List records = new ArrayList<>(); @@ -237,7 +287,6 @@ final class JndiResourceResolverFactory implements DnsNameResolver.ResourceResol return records; } - @IgnoreJRERequirement private static void closeThenThrow(NamingEnumeration namingEnumeration, NamingException e) throws NamingException { try { @@ -248,7 +297,6 @@ final class JndiResourceResolverFactory implements DnsNameResolver.ResourceResol throw e; } - @IgnoreJRERequirement private static void closeThenThrow(DirContext ctx, NamingException e) throws NamingException { try { ctx.close(); @@ -258,36 +306,6 @@ final class JndiResourceResolverFactory implements DnsNameResolver.ResourceResol throw e; } - /** - * Undo the quoting done in {@link com.sun.jndi.dns.ResourceRecord#decodeTxt}. - */ - @VisibleForTesting - static String unquote(String txtRecord) { - StringBuilder sb = new StringBuilder(txtRecord.length()); - boolean inquote = false; - for (int i = 0; i < txtRecord.length(); i++) { - char c = txtRecord.charAt(i); - if (!inquote) { - if (c == ' ') { - continue; - } else if (c == '"') { - inquote = true; - continue; - } - } else { - if (c == '"') { - inquote = false; - continue; - } else if (c == '\\') { - c = txtRecord.charAt(++i); - assert c == '"' || c == '\\'; - } - } - sb.append(c); - } - return sb.toString(); - } - private static void checkAvailable() { if (JNDI_UNAVAILABILITY_CAUSE != null) { throw new UnsupportedOperationException( diff --git a/core/src/test/java/io/grpc/internal/JndiResourceResolverTest.java b/core/src/test/java/io/grpc/internal/JndiResourceResolverTest.java index 51ae7c5d2c..c2e9111a50 100644 --- a/core/src/test/java/io/grpc/internal/JndiResourceResolverTest.java +++ b/core/src/test/java/io/grpc/internal/JndiResourceResolverTest.java @@ -18,11 +18,21 @@ package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; import io.grpc.internal.DnsNameResolver.AddressResolver; +import io.grpc.internal.GrpcAttributes; +import io.grpc.internal.JndiResourceResolverFactory.JndiRecordFetcher; import io.grpc.internal.JndiResourceResolverFactory.JndiResourceResolver; -import io.grpc.internal.JndiResourceResolverFactory.JndiResourceResolver.SrvRecord; +import io.grpc.internal.JndiResourceResolverFactory.RecordFetcher; import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.UnknownHostException; +import java.util.Arrays; import java.util.List; import org.junit.Assume; import org.junit.Test; @@ -50,15 +60,9 @@ public class JndiResourceResolverTest { public void jndiResolverWorks() throws Exception { Assume.assumeNoException(new JndiResourceResolverFactory().unavailabilityCause()); - AddressResolver addressResolver = new AddressResolver() { - @Override - public List resolveAddress(String host) throws Exception { - return null; - } - }; - JndiResourceResolver resolver = new JndiResourceResolver(); + RecordFetcher recordFetcher = new JndiRecordFetcher(); try { - resolver.resolveSrv(addressResolver, "localhost"); + recordFetcher.getAllRecords("SRV", "dns:///localhost"); } catch (javax.naming.CommunicationException e) { Assume.assumeNoException(e); } catch (javax.naming.NameNotFoundException e) { @@ -67,9 +71,45 @@ public class JndiResourceResolverTest { } @Test - public void parseSrvRecord() { - SrvRecord record = JndiResourceResolver.parseSrvRecord("0 0 1234 foo.bar.com"); - assertThat(record.host).isEqualTo("foo.bar.com"); - assertThat(record.port).isEqualTo(1234); + public void txtRecordLookup() throws Exception { + RecordFetcher recordFetcher = mock(RecordFetcher.class); + when(recordFetcher.getAllRecords("TXT", "dns:///service.example.com")) + .thenReturn(Arrays.asList("foo", "\"bar\"")); + + List golden = Arrays.asList("foo", "bar"); + JndiResourceResolver resolver = new JndiResourceResolver(recordFetcher); + assertThat(resolver.resolveTxt("service.example.com")).isEqualTo(golden); + } + + @Test + public void srvRecordLookup() throws Exception { + AddressResolver addressResolver = mock(AddressResolver.class); + when(addressResolver.resolveAddress("foo.example.com.")) + .thenReturn(Arrays.asList(InetAddress.getByName("127.1.2.3"))); + when(addressResolver.resolveAddress("bar.example.com.")) + .thenReturn(Arrays.asList( + InetAddress.getByName("127.3.2.1"), InetAddress.getByName("::1"))); + when(addressResolver.resolveAddress("unknown.example.com.")) + .thenThrow(new UnknownHostException("unknown.example.com.")); + RecordFetcher recordFetcher = mock(RecordFetcher.class); + when(recordFetcher.getAllRecords("SRV", "dns:///service.example.com")) + .thenReturn(Arrays.asList( + "0 0 314 foo.example.com.", "0 0 42 bar.example.com.", "0 0 1 unknown.example.com.")); + + List golden = Arrays.asList( + new EquivalentAddressGroup( + Arrays.asList(new InetSocketAddress("127.1.2.3", 314)), + Attributes.newBuilder() + .set(GrpcAttributes.ATTR_LB_ADDR_AUTHORITY, "foo.example.com") + .build()), + new EquivalentAddressGroup( + Arrays.asList( + new InetSocketAddress("127.3.2.1", 42), + new InetSocketAddress("::1", 42)), + Attributes.newBuilder() + .set(GrpcAttributes.ATTR_LB_ADDR_AUTHORITY, "bar.example.com") + .build())); + JndiResourceResolver resolver = new JndiResourceResolver(recordFetcher); + assertThat(resolver.resolveSrv(addressResolver, "service.example.com")).isEqualTo(golden); } }