diff --git a/prober/utils.go b/prober/utils.go index 98c9152e..39aef270 100644 --- a/prober/utils.go +++ b/prober/utils.go @@ -26,8 +26,13 @@ import ( "github.com/prometheus/client_golang/prometheus" ) +var protocolToGauge = map[string]float64{ + "ip4": 4, + "ip6": 6, +} + // Returns the IP for the IPProtocol and lookup time. -func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol bool, target string, registry *prometheus.Registry, logger log.Logger) (ip *net.IPAddr, lookupTime float64, err error) { +func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol bool, target string, registry *prometheus.Registry, logger log.Logger) (ip *net.IPAddr, lookupTime float64, returnerr error) { var fallbackProtocol string probeDNSLookupTimeSeconds := prometheus.NewGauge(prometheus.GaugeOpts{ Name: "probe_dns_lookup_time_seconds", @@ -54,64 +59,48 @@ func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol b IPProtocol = "ip4" fallbackProtocol = "ip6" } + var usedProtocol string - level.Info(logger).Log("msg", "Resolving target address", "ip_protocol", IPProtocol) resolveStart := time.Now() defer func() { lookupTime = time.Since(resolveStart).Seconds() probeDNSLookupTimeSeconds.Add(lookupTime) + if usedProtocol != "" { + probeIPProtocolGauge.Set(protocolToGauge[usedProtocol]) + } + if ip != nil { + probeIPAddrHash.Set(ipHash(ip.IP)) + } }() resolver := &net.Resolver{} - ips, err := resolver.LookupIPAddr(ctx, target) - if err != nil { - level.Error(logger).Log("msg", "Resolution with IP protocol failed", "err", err) - return nil, 0.0, err - } - - // Return the IP in the requested protocol. - var fallback *net.IPAddr - for _, ip := range ips { - switch IPProtocol { - case "ip4": - if ip.IP.To4() != nil { - level.Info(logger).Log("msg", "Resolved target address", "ip", ip.String()) - probeIPProtocolGauge.Set(4) - probeIPAddrHash.Set(ipHash(ip.IP)) - return &ip, lookupTime, nil - } - - // ip4 as fallback - fallback = &ip - case "ip6": - if ip.IP.To4() == nil { - level.Info(logger).Log("msg", "Resolved target address", "ip", ip.String()) - probeIPProtocolGauge.Set(6) - probeIPAddrHash.Set(ipHash(ip.IP)) - return &ip, lookupTime, nil - } - - // ip6 as fallback - fallback = &ip - } - } - - // Unable to find ip and no fallback set. - if fallback == nil || !fallbackIPProtocol { - return nil, 0.0, fmt.Errorf("unable to find ip; no fallback") + level.Info(logger).Log("msg", "Resolving target address", "ip_protocol", IPProtocol) + if ips, err := resolver.LookupIP(ctx, IPProtocol, target); err == nil { + level.Info(logger).Log("msg", "Resolved target address", "ip", ips[0].String()) + usedProtocol = IPProtocol + ip = &net.IPAddr{IP: ips[0]} + return + } else if !fallbackIPProtocol { + level.Error(logger).Log("msg", "Resolution with IP protocol failed", "err", err) + returnerr = fmt.Errorf("unable to find ip; no fallback: %s", err) + return } - // Use fallback ip protocol. - if fallbackProtocol == "ip4" { - probeIPProtocolGauge.Set(4) - } else { - probeIPProtocolGauge.Set(6) + level.Info(logger).Log("msg", "Resolving target address", "ip_protocol", fallbackProtocol) + ips, err := resolver.LookupIP(ctx, fallbackProtocol, target) + if err != nil { + // This could happen when the domain don't have A and AAAA record (e.g. + // only have MX record). + level.Error(logger).Log("msg", "Resolution with IP protocol failed", "err", err) + returnerr = fmt.Errorf("unable to find ip; exhausted fallback: %s", err) + return } - probeIPAddrHash.Set(ipHash(fallback.IP)) - level.Info(logger).Log("msg", "Resolved target address", "ip", fallback.String()) - return fallback, lookupTime, nil + level.Info(logger).Log("msg", "Resolved target address", "ip", ips[0].String()) + usedProtocol = fallbackProtocol + ip = &net.IPAddr{IP: ips[0]} + return } func ipHash(ip net.IP) float64 { diff --git a/prober/utils_test.go b/prober/utils_test.go index a53ee490..bf21dc79 100644 --- a/prober/utils_test.go +++ b/prober/utils_test.go @@ -24,6 +24,7 @@ import ( "math/big" "net" "os" + "strings" "testing" "time" @@ -162,7 +163,7 @@ func TestChooseProtocol(t *testing.T) { registry = prometheus.NewPedanticRegistry() ip, _, err = chooseProtocol(ctx, "ip4", false, "ipv6.google.com", registry, logger) - if err != nil && err.Error() != "unable to find ip; no fallback" { + if err != nil && !strings.HasPrefix(err.Error(), "unable to find ip; no fallback") { t.Error(err) } else if err == nil { t.Error("should set error") @@ -170,6 +171,17 @@ func TestChooseProtocol(t *testing.T) { if ip != nil { t.Error("without fallback it should not answer") } + + registry = prometheus.NewPedanticRegistry() + ip, _, err = chooseProtocol(ctx, "ip4", true, "does-not-exist.example.com", registry, logger) + if err != nil && !strings.HasPrefix(err.Error(), "unable to find ip; exhausted fallback") { + t.Error(err) + } else if err == nil { + t.Error("should set error") + } + if ip != nil { + t.Error("with exhausted fallback it should not answer") + } } func checkMetrics(expected map[string]map[string]map[string]struct{}, mfs []*dto.MetricFamily, t *testing.T) {