diff --git a/plugins/inputs/ping/ping.go b/plugins/inputs/ping/ping.go index 17767bac397b9..008cfceacc5b9 100644 --- a/plugins/inputs/ping/ping.go +++ b/plugins/inputs/ping/ping.go @@ -23,6 +23,10 @@ import ( // for unit test purposes (see ping_test.go) type HostPinger func(binary string, timeout float64, args ...string) (string, error) +type HostResolver func(ctx context.Context, ipv6 bool, host string) (*net.IPAddr, error) + +type IsCorrectNetwork func(ip net.IPAddr) bool + type Ping struct { wg sync.WaitGroup @@ -60,6 +64,9 @@ type Ping struct { // host ping function pingHost HostPinger + // resolve host function + resolveHost HostResolver + // listenAddr is the address associated with the interface defined. listenAddr string } @@ -123,13 +130,6 @@ func (p *Ping) Gather(acc telegraf.Accumulator) error { } for _, host := range p.Urls { - _, err := net.LookupHost(host) - if err != nil { - acc.AddFields("ping", map[string]interface{}{"result_code": 1}, map[string]string{"url": host}) - acc.AddError(err) - continue - } - p.wg.Add(1) go func(host string) { defer p.wg.Done() @@ -194,25 +194,47 @@ func hostPinger(binary string, timeout float64, args ...string) (string, error) return string(out), err } -func (p *Ping) pingToURLNative(destination string, acc telegraf.Accumulator) { - ctx := context.Background() - - network := "ip4" - if p.IPv6 { - network = "ip6" +func filterIPs(addrs []net.IPAddr, filterFunc IsCorrectNetwork) []net.IPAddr { + n := 0 + for _, x := range addrs { + if filterFunc(x) { + addrs[n] = x + n++ + } } + return addrs[:n] +} + +func hostResolver(ctx context.Context, ipv6 bool, destination string) (*net.IPAddr, error) { + resolver := &net.Resolver{} + ips, err := resolver.LookupIPAddr(ctx, destination) - host, err := net.ResolveIPAddr(network, destination) if err != nil { - acc.AddFields( - "ping", - map[string]interface{}{"result_code": 1}, - map[string]string{"url": destination}, - ) - acc.AddError(err) - return + return nil, err + } + + if ipv6 { + ips = filterIPs(ips, isV6) + } else { + ips = filterIPs(ips, isV4) + } + + if len(ips) == 0 { + return nil, errors.New("Cannot resolve ip address") } + return &ips[0], err +} + +func isV4(ip net.IPAddr) bool { + return ip.IP.To4() != nil +} + +func isV6(ip net.IPAddr) bool { + return !isV4(ip) +} +func (p *Ping) pingToURLNative(destination string, acc telegraf.Accumulator) { + ctx := context.Background() interval := p.PingInterval if interval < 0.2 { interval = 0.2 @@ -232,6 +254,17 @@ func (p *Ping) pingToURLNative(destination string, acc telegraf.Accumulator) { defer cancel() } + host, err := p.resolveHost(ctx, p.IPv6, destination) + if err != nil { + acc.AddFields( + "ping", + map[string]interface{}{"result_code": 1}, + map[string]string{"url": destination}, + ) + acc.AddError(err) + return + } + resps := make(chan *ping.Response) rsps := []*ping.Response{} @@ -392,6 +425,7 @@ func init() { inputs.Add("ping", func() telegraf.Input { return &Ping{ pingHost: hostPinger, + resolveHost: hostResolver, PingInterval: 1.0, Count: 1, Timeout: 1.0, diff --git a/plugins/inputs/ping/ping_test.go b/plugins/inputs/ping/ping_test.go index 8a1a0a9e1d7ca..d6f78bb793220 100644 --- a/plugins/inputs/ping/ping_test.go +++ b/plugins/inputs/ping/ping_test.go @@ -3,7 +3,9 @@ package ping import ( + "context" "errors" + "net" "reflect" "sort" "testing" @@ -340,6 +342,12 @@ func TestPingBinary(t *testing.T) { acc.GatherError(p.Gather) } +func mockHostResolver(ctx context.Context, ipv6 bool, host string) (*net.IPAddr, error) { + ipaddr := net.IPAddr{} + ipaddr.IP = net.IPv4(127, 0, 0, 1) + return &ipaddr, nil +} + // Test that Gather function works using native ping func TestPingGatherNative(t *testing.T) { if testing.Short() { @@ -348,12 +356,35 @@ func TestPingGatherNative(t *testing.T) { var acc testutil.Accumulator p := Ping{ - Urls: []string{"localhost", "127.0.0.2"}, - Method: "native", - Count: 5, + Urls: []string{"localhost", "127.0.0.2"}, + Method: "native", + Count: 5, + resolveHost: mockHostResolver, } assert.NoError(t, acc.GatherError(p.Gather)) assert.True(t, acc.HasPoint("ping", map[string]string{"url": "localhost"}, "packets_transmitted", 5)) assert.True(t, acc.HasPoint("ping", map[string]string{"url": "localhost"}, "packets_received", 5)) } + +func mockHostResolverError(ctx context.Context, ipv6 bool, host string) (*net.IPAddr, error) { + return nil, errors.New("myMock error") +} + +// Test failed DNS resolutions +func TestDNSLookupError(t *testing.T) { + if testing.Short() { + t.Skip("Skipping test due to permission requirements.") + } + + var acc testutil.Accumulator + p := Ping{ + Urls: []string{"localhost"}, + Method: "native", + IPv6: false, + resolveHost: mockHostResolverError, + } + + acc.GatherError(p.Gather) + assert.True(t, len(acc.Errors) > 0) +}