From f7a473c5cd33deae8245d06be637489cc5b24896 Mon Sep 17 00:00:00 2001 From: Thor Date: Wed, 10 Apr 2019 13:36:35 -0500 Subject: [PATCH] added context timeout to chooseProtocol util Signed-off-by: Thor --- prober/dns.go | 2 +- prober/http.go | 2 +- prober/icmp.go | 2 +- prober/tcp.go | 2 +- prober/utils.go | 59 +++++++++++++++++++++++++++++++------------------ 5 files changed, 41 insertions(+), 26 deletions(-) diff --git a/prober/dns.go b/prober/dns.go index 471331c2..00e64673 100644 --- a/prober/dns.go +++ b/prober/dns.go @@ -129,7 +129,7 @@ func ProbeDNS(ctx context.Context, target string, module config.Module, registry port = "53" targetAddr = target } - ip, _, err = chooseProtocol(module.DNS.IPProtocol, module.DNS.IPProtocolFallback, targetAddr, registry, logger) + ip, _, err = chooseProtocol(ctx, module.DNS.IPProtocol, module.DNS.IPProtocolFallback, targetAddr, registry, logger) if err != nil { level.Error(logger).Log("msg", "Error resolving address", "err", err) return false diff --git a/prober/http.go b/prober/http.go index 1c5ed236..3d649b5a 100644 --- a/prober/http.go +++ b/prober/http.go @@ -271,7 +271,7 @@ func ProbeHTTP(ctx context.Context, target string, module config.Module, registr targetHost = targetURL.Host } - ip, lookupTime, err := chooseProtocol(module.HTTP.IPProtocol, module.HTTP.IPProtocolFallback, targetHost, registry, logger) + ip, lookupTime, err := chooseProtocol(ctx, module.HTTP.IPProtocol, module.HTTP.IPProtocolFallback, targetHost, registry, logger) if err != nil { level.Error(logger).Log("msg", "Error resolving address", "err", err) return false diff --git a/prober/icmp.go b/prober/icmp.go index 0f1c8198..83c8736b 100644 --- a/prober/icmp.go +++ b/prober/icmp.go @@ -79,7 +79,7 @@ func ProbeICMP(ctx context.Context, target string, module config.Module, registr registry.MustRegister(durationGaugeVec) - ip, lookupTime, err := chooseProtocol(module.ICMP.IPProtocol, module.ICMP.IPProtocolFallback, target, registry, logger) + ip, lookupTime, err := chooseProtocol(ctx, module.ICMP.IPProtocol, module.ICMP.IPProtocolFallback, target, registry, logger) if err != nil { level.Warn(logger).Log("msg", "Error resolving address", "err", err) return false diff --git a/prober/tcp.go b/prober/tcp.go index f0ad879e..e4976256 100644 --- a/prober/tcp.go +++ b/prober/tcp.go @@ -38,7 +38,7 @@ func dialTCP(ctx context.Context, target string, module config.Module, registry return nil, err } - ip, _, err := chooseProtocol(module.TCP.IPProtocol, module.TCP.IPProtocolFallback, targetAddress, registry, logger) + ip, _, err := chooseProtocol(ctx, module.TCP.IPProtocol, module.TCP.IPProtocolFallback, targetAddress, registry, logger) if err != nil { level.Error(logger).Log("msg", "Error resolving address", "err", err) return nil, err diff --git a/prober/utils.go b/prober/utils.go index 848f5ee5..8c9a72b5 100644 --- a/prober/utils.go +++ b/prober/utils.go @@ -14,6 +14,7 @@ package prober import ( + "context" "net" "time" @@ -24,7 +25,7 @@ import ( ) // Returns the IP for the IPProtocol and lookup time. -func chooseProtocol(IPProtocol string, fallbackIPProtocol bool, target string, registry *prometheus.Registry, logger log.Logger) (ip *net.IPAddr, lookupTime float64, err error) { +func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol bool, target string, registry *prometheus.Registry, logger log.Logger) (ip *net.IPAddr, lookupTime float64, err error) { var fallbackProtocol string probeDNSLookupTimeSeconds := prometheus.NewGauge(prometheus.GaugeOpts{ Name: "probe_dns_lookup_time_seconds", @@ -60,31 +61,45 @@ func chooseProtocol(IPProtocol string, fallbackIPProtocol bool, target string, r probeDNSLookupTimeSeconds.Add(lookupTime) }() - ip, err = net.ResolveIPAddr(IPProtocol, target) - if err != nil { - if !fallbackIPProtocol { - level.Error(logger).Log("msg", "Resolution with IP protocol failed (fallback_ip_protocol is false):", "err", err) - } else { - level.Warn(logger).Log("msg", "Resolution with IP protocol failed, attempting fallback protocol", "fallback_protocol", fallbackProtocol, "err", err) - ip, err = net.ResolveIPAddr(fallbackProtocol, target) - } - + ipC := make(chan *net.IPAddr, 1) + errors := make(chan error, 1) + go func() { + defer close(ipC) + defer close(errors) + ip, err := net.ResolveIPAddr(IPProtocol, target) if err != nil { - if IPProtocol == "ip6" { - probeIPProtocolGauge.Set(6) + if !fallbackIPProtocol { + level.Error(logger).Log("msg", "Resolution with IP protocol failed (fallback_ip_protocol is false):", "err", err) } else { - probeIPProtocolGauge.Set(4) + level.Warn(logger).Log("msg", "Resolution with IP protocol failed, attempting fallback protocol", "fallback_protocol", fallbackProtocol, "err", err) + ip, err = net.ResolveIPAddr(fallbackProtocol, target) + } + + if err != nil { + if IPProtocol == "ip6" { + probeIPProtocolGauge.Set(6) + } else { + probeIPProtocolGauge.Set(4) + } + errors <- err + return } - return ip, 0.0, err } - } - if ip.IP.To4() == nil { - probeIPProtocolGauge.Set(6) - } else { - probeIPProtocolGauge.Set(4) - } + if ip.IP.To4() == nil { + probeIPProtocolGauge.Set(6) + } else { + probeIPProtocolGauge.Set(4) + } - level.Info(logger).Log("msg", "Resolved target address", "ip", ip) - return ip, lookupTime, nil + ipC <- ip + }() + + select { + case <-ctx.Done(): + return nil, lookupTime, ctx.Err() + case ip := <-ipC: + level.Info(logger).Log("msg", "Resolved target address", "ip", ip) + return ip, lookupTime, nil + } }