diff --git a/prober/utils.go b/prober/utils.go index 98c9152e..365cb162 100644 --- a/prober/utils.go +++ b/prober/utils.go @@ -15,6 +15,7 @@ package prober import ( "context" + "errors" "fmt" "hash/fnv" "net" @@ -26,6 +27,28 @@ import ( "github.com/prometheus/client_golang/prometheus" ) +var protocolToGauge = map[string]float64{ + "ip4": 4, + "ip6": 6, +} + +type resolver struct { + net.Resolver +} + +// A simple wrapper around resolver.LookupIP. +func (r *resolver) resolve(ctx context.Context, target string, protocol string) (*net.IPAddr, error) { + ips, err := r.LookupIP(ctx, protocol, target) + if err != nil { + return nil, err + } + for _, ip := range ips { + return &net.IPAddr{IP: ip}, nil + } + // Go doc did not specify when this could happen, better be defensive. + return nil, errors.New("calling LookupIP returned empty list of addresses") +} + // Returns the IP for the IPProtocol and lookup time. func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol bool, target string, registry *prometheus.Registry, logger log.Logger) (ip *net.IPAddr, lookupTime float64, err error) { var fallbackProtocol string @@ -55,7 +78,6 @@ func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol b fallbackProtocol = "ip6" } - level.Info(logger).Log("msg", "Resolving target address", "ip_protocol", IPProtocol) resolveStart := time.Now() defer func() { @@ -63,55 +85,33 @@ func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol b probeDNSLookupTimeSeconds.Add(lookupTime) }() - resolver := &net.Resolver{} - ips, err := resolver.LookupIPAddr(ctx, target) - if err != nil { - level.Error(logger).Log("msg", "Resolution with IP protocol failed", "err", err) - return nil, 0.0, err - } - - // Return the IP in the requested protocol. - var fallback *net.IPAddr - for _, ip := range ips { - switch IPProtocol { - case "ip4": - if ip.IP.To4() != nil { - level.Info(logger).Log("msg", "Resolved target address", "ip", ip.String()) - probeIPProtocolGauge.Set(4) - probeIPAddrHash.Set(ipHash(ip.IP)) - return &ip, lookupTime, nil - } - - // ip4 as fallback - fallback = &ip - - case "ip6": - if ip.IP.To4() == nil { - level.Info(logger).Log("msg", "Resolved target address", "ip", ip.String()) - probeIPProtocolGauge.Set(6) - probeIPAddrHash.Set(ipHash(ip.IP)) - return &ip, lookupTime, nil - } - - // ip6 as fallback - fallback = &ip - } + r := &resolver{ + Resolver: net.Resolver{}, } - // Unable to find ip and no fallback set. - if fallback == nil || !fallbackIPProtocol { - return nil, 0.0, fmt.Errorf("unable to find ip; no fallback") + level.Info(logger).Log("msg", "Resolving target address", "ip_protocol", IPProtocol) + if ip, err := r.resolve(ctx, target, IPProtocol); err == nil { + level.Info(logger).Log("msg", "Resolved target address", "ip", ip.String()) + probeIPProtocolGauge.Set(protocolToGauge[IPProtocol]) + probeIPAddrHash.Set(ipHash(ip.IP)) + return ip, lookupTime, nil + } else if !fallbackIPProtocol { + level.Error(logger).Log("msg", "Resolution with IP protocol failed", "err", err) + return nil, 0.0, fmt.Errorf("unable to find ip; no fallback: %s", err) } - // Use fallback ip protocol. - if fallbackProtocol == "ip4" { - probeIPProtocolGauge.Set(4) - } else { - probeIPProtocolGauge.Set(6) + level.Info(logger).Log("msg", "Resolving target address", "ip_protocol", fallbackProtocol) + ip, err = r.resolve(ctx, target, fallbackProtocol) + if err != nil { + // This could happen when the domain don't have A and AAAA record (e.g. + // only have MX record). + level.Error(logger).Log("msg", "Resolution with IP protocol failed", "err", err) + return nil, 0.0, fmt.Errorf("unable to find ip; exhausted fallback: %s", err) } - probeIPAddrHash.Set(ipHash(fallback.IP)) - level.Info(logger).Log("msg", "Resolved target address", "ip", fallback.String()) - return fallback, lookupTime, nil + level.Info(logger).Log("msg", "Resolved target address", "ip", ip.String()) + probeIPProtocolGauge.Set(protocolToGauge[fallbackProtocol]) + probeIPAddrHash.Set(ipHash(ip.IP)) + return ip, lookupTime, nil } func ipHash(ip net.IP) float64 { diff --git a/prober/utils_test.go b/prober/utils_test.go index a53ee490..b7532b35 100644 --- a/prober/utils_test.go +++ b/prober/utils_test.go @@ -24,6 +24,7 @@ import ( "math/big" "net" "os" + "strings" "testing" "time" @@ -162,7 +163,7 @@ func TestChooseProtocol(t *testing.T) { registry = prometheus.NewPedanticRegistry() ip, _, err = chooseProtocol(ctx, "ip4", false, "ipv6.google.com", registry, logger) - if err != nil && err.Error() != "unable to find ip; no fallback" { + if err != nil && !strings.HasPrefix(err.Error(), "unable to find ip; no fallback") { t.Error(err) } else if err == nil { t.Error("should set error") @@ -170,6 +171,17 @@ func TestChooseProtocol(t *testing.T) { if ip != nil { t.Error("without fallback it should not answer") } + + registry = prometheus.NewPedanticRegistry() + ip, _, err = chooseProtocol(ctx, "ip4", true, "does-not-exist.google.com", registry, logger) + if err != nil && !strings.HasPrefix(err.Error(), "unable to find ip; exhausted fallback") { + t.Error(err) + } else if err == nil { + t.Error("should set error") + } + if ip != nil { + t.Error("with exhausted fallback it should not answer") + } } func checkMetrics(expected map[string]map[string]map[string]struct{}, mfs []*dto.MetricFamily, t *testing.T) {