diff --git a/controller/destination/server.go b/controller/destination/server.go index 966e7be47..442bcebe7 100644 --- a/controller/destination/server.go +++ b/controller/destination/server.go @@ -109,6 +109,13 @@ func (s *server) Get(dest *common.Destination, stream pb.Destination_GetServer) } } + // If this is an IP address, echo it back + isIP, ip := isIPAddress(host) + if isIP { + echoIPDestination(ip, port, stream) + return nil + } + id, err := s.localKubernetesServiceIdFromDNSName(host) if err != nil { log.Error(err) @@ -126,6 +133,34 @@ func (s *server) Get(dest *common.Destination, stream pb.Destination_GetServer) return err } +func isIPAddress(host string) (bool, *common.IPAddress) { + ip, err := util.ParseIPV4(host) + return err == nil, ip +} + +func echoIPDestination(ip *common.IPAddress, port int, stream pb.Destination_GetServer) bool { + update := &pb.Update{ + Update: &pb.Update_Add{ + Add: &pb.WeightedAddrSet{ + Addrs: []*pb.WeightedAddr{ + &pb.WeightedAddr{ + Addr: &common.TcpAddress{ + Ip: ip, + Port: uint32(port), + }, + Weight: 1, + }, + }, + }, + }, + } + stream.Send(update) + + <-stream.Context().Done() + + return true +} + func (s *server) resolveKubernetesService(id string, port int, stream pb.Destination_GetServer) error { listener := endpointListener{stream: stream} diff --git a/controller/destination/server_test.go b/controller/destination/server_test.go index af7d8fb79..ffc6f376d 100644 --- a/controller/destination/server_test.go +++ b/controller/destination/server_test.go @@ -2,8 +2,9 @@ package destination import ( "fmt" - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" ) func TestLocalKubernetesServiceIdFromDNSName(t *testing.T) { @@ -12,7 +13,7 @@ func TestLocalKubernetesServiceIdFromDNSName(t *testing.T) { testCases := []struct { k8sDNSZone string host string - result *string + result *string resultErr bool }{ {"cluster.local", "", nil, true}, @@ -88,3 +89,22 @@ func TestSplitDNSName(t *testing.T) { }) } } + +func TestIsIPAddress(t *testing.T) { + testCases := []struct { + host string + result bool + }{ + {"8.8.8.8", true}, + {"example.com", false}, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d: %+v", i, tc.host), func(t *testing.T) { + isIP, _ := isIPAddress(tc.host) + if isIP != tc.result { + t.Fatalf("Unexpected result: %+v", isIP) + } + }) + } +}