diff --git a/attack.go b/attack.go index 63c8d55a..47276e95 100644 --- a/attack.go +++ b/attack.go @@ -56,6 +56,7 @@ func attackCmd() command { fs.Var(&opts.laddr, "laddr", "Local IP address") fs.BoolVar(&opts.keepalive, "keepalive", true, "Use persistent connections") fs.StringVar(&opts.unixSocket, "unix-socket", "", "Connect over a unix socket. This overrides the host address in target URLs") + fs.Var(&dnsTTLFlag{&opts.dnsTTL}, "dns-ttl", "Cache DNS lookups for the given duration [-1 = disabled, 0 = forever]") systemSpecificFlags(fs, opts) return command{fs, func(args []string) error { @@ -99,6 +100,7 @@ type attackOpts struct { keepalive bool resolvers csl unixSocket string + dnsTTL time.Duration } // attack validates the attack arguments, sets up the @@ -116,6 +118,8 @@ func attack(opts *attackOpts) (err error) { net.DefaultResolver = res } + net.DefaultResolver.PreferGo = true + files := map[string]io.Reader{} for _, filename := range []string{opts.targetsf, opts.bodyf} { if filename == "" { @@ -188,6 +192,7 @@ func attack(opts *attackOpts) (err error) { vegeta.UnixSocket(opts.unixSocket), vegeta.ProxyHeader(proxyHdr), vegeta.ChunkedBody(opts.chunked), + vegeta.DNSCaching(opts.dnsTTL), ) res := atk.Attack(tr, opts.rate, opts.duration, opts.name) diff --git a/flags.go b/flags.go index 6a780d9d..7ac8fcb1 100644 --- a/flags.go +++ b/flags.go @@ -132,3 +132,24 @@ func (f *maxBodyFlag) String() string { } return datasize.ByteSize(*(f.n)).String() } + +type dnsTTLFlag struct{ ttl *time.Duration } + +func (f *dnsTTLFlag) Set(v string) (err error) { + if v == "-1" { + *(f.ttl) = -1 + return nil + } + + *(f.ttl), err = time.ParseDuration(v) + return err +} + +func (f *dnsTTLFlag) String() string { + if f.ttl == nil { + return "" + } else if *(f.ttl) == -1 { + return "-1" + } + return f.ttl.String() +} diff --git a/go.mod b/go.mod index a7d3e530..aa769a9b 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,11 @@ require ( require ( github.com/iancoleman/orderedmap v0.3.0 // indirect github.com/josharian/intern v1.0.0 // indirect + github.com/rs/dnscache v0.0.0-20211102005908-e0241e321417 // indirect + github.com/shurcooL/httpfs v0.0.0-20230704072500-f1e31cf0ba5c // indirect + github.com/shurcooL/vfsgen v0.0.0-20230704071429-0000e147ea92 // indirect golang.org/x/mod v0.8.0 // indirect + golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.10.0 // indirect golang.org/x/text v0.11.0 // indirect golang.org/x/tools v0.6.0 // indirect diff --git a/go.sum b/go.sum index 9967e4ae..7f9099f5 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,12 @@ github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo= github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/dnscache v0.0.0-20211102005908-e0241e321417 h1:Lt9DzQALzHoDwMBGJ6v8ObDPR0dzr2a6sXTB1Fq7IHs= +github.com/rs/dnscache v0.0.0-20211102005908-e0241e321417/go.mod h1:qe5TWALJ8/a1Lqznoc5BDHpYX/8HU60Hm2AwRmqzxqA= +github.com/shurcooL/httpfs v0.0.0-20230704072500-f1e31cf0ba5c h1:aqg5Vm5dwtvL+YgDpBcK1ITf3o96N/K7/wsRXQnUTEs= +github.com/shurcooL/httpfs v0.0.0-20230704072500-f1e31cf0ba5c/go.mod h1:owqhoLW1qZoYLZzLnBw+QkPP9WZnjlSWihhxAJC1+/M= +github.com/shurcooL/vfsgen v0.0.0-20230704071429-0000e147ea92 h1:OfRzdxCzDhp+rsKWXuOO2I/quKMJ/+TQwVbIP/gltZg= +github.com/shurcooL/vfsgen v0.0.0-20230704071429-0000e147ea92/go.mod h1:7/OT02F6S6I7v6WXb+IjhMuZEYfH/RJ5RwEWnEo5BMg= github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d h1:X4+kt6zM/OVO6gbJdAfJR60MGPsqCzbtXNnjoGqdfAs= github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d/go.mod h1:lbP8tGiBjZ5YWIc2fzuRpTaz0b/53vT6PEs3QuAWzuU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -39,7 +45,9 @@ golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= diff --git a/lib/attack.go b/lib/attack.go index 0bbf68dd..649679a8 100644 --- a/lib/attack.go +++ b/lib/attack.go @@ -7,6 +7,7 @@ import ( "io" "io/ioutil" "math" + "math/rand" "net" "net/http" "net/url" @@ -14,6 +15,7 @@ import ( "sync" "time" + "github.com/rs/dnscache" "golang.org/x/net/http2" ) @@ -27,9 +29,6 @@ type Attacker struct { maxWorkers uint64 maxBody int64 redirects int - seqmu sync.Mutex - seq uint64 - began time.Time chunked bool } @@ -73,7 +72,6 @@ func NewAttacker(opts ...func(*Attacker)) *Attacker { workers: DefaultWorkers, maxWorkers: DefaultMaxWorkers, maxBody: DefaultMaxBody, - began: time.Now(), } a.dialer = &net.Dialer{ @@ -85,7 +83,7 @@ func NewAttacker(opts ...func(*Attacker)) *Attacker { Timeout: DefaultTimeout, Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, - Dial: a.dialer.Dial, + DialContext: a.dialer.DialContext, TLSClientConfig: DefaultTLSConfig, MaxIdleConnsPerHost: DefaultConnections, MaxConnsPerHost: DefaultMaxConnections, @@ -177,7 +175,7 @@ func LocalAddr(addr net.IPAddr) func(*Attacker) { return func(a *Attacker) { tr := a.client.Transport.(*http.Transport) a.dialer.LocalAddr = &net.TCPAddr{IP: addr.IP, Zone: addr.Zone} - tr.Dial = a.dialer.Dial + tr.DialContext = a.dialer.DialContext } } @@ -189,7 +187,7 @@ func KeepAlive(keepalive bool) func(*Attacker) { tr.DisableKeepAlives = !keepalive if !keepalive { a.dialer.KeepAlive = 0 - tr.Dial = a.dialer.Dial + tr.DialContext = a.dialer.DialContext } } } @@ -223,8 +221,8 @@ func H2C(enabled bool) func(*Attacker) { if tr := a.client.Transport.(*http.Transport); enabled { a.client.Transport = &http2.Transport{ AllowHTTP: true, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - return tr.Dial(network, addr) + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + return tr.DialContext(ctx, network, addr) }, } } @@ -263,6 +261,119 @@ func ProxyHeader(h http.Header) func(*Attacker) { } } +// DNSCaching returns a functional option that enables DNS caching for +// the given ttl. When ttl is zero cached entries will never expire. +// When ttl is non-zero, this will start a refresh go-routine that updates +// the cache every ttl interval. This go-routine will be stopped when the +// attack is stopped. +// When the ttl is negative, no caching will be performed. +func DNSCaching(ttl time.Duration) func(*Attacker) { + return func(a *Attacker) { + if ttl < 0 { + return + } + + if tr, ok := a.client.Transport.(*http.Transport); ok { + dial := tr.DialContext + if dial == nil { + dial = a.dialer.DialContext + } + + resolver := &dnscache.Resolver{} + + if ttl != 0 { + go func() { + refresh := time.NewTicker(ttl) + defer refresh.Stop() + for { + select { + case <-refresh.C: + resolver.Refresh(true) + case <-a.stopch: + return + } + } + }() + } + + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + + tr.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, err error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + ips, err := resolver.LookupHost(ctx, host) + if err != nil { + return nil, err + } + + if len(ips) == 0 { + return nil, &net.DNSError{Err: "no such host", Name: addr} + } + + // Pick a random IP from each IP family and dial each concurrently. + // The first that succeeds wins, the other gets canceled. + + rng.Shuffle(len(ips), func(i, j int) { ips[i], ips[j] = ips[j], ips[i] }) + + // In place filtering of ips to only include the first IPv4 and IPv6. + j := 0 + for i := 0; i < len(ips) && j < 2; i++ { + ip := net.ParseIP(ips[i]) + switch { + case len(ip.To4()) == net.IPv4len && j == 0: + fallthrough + case len(ip) == net.IPv6len && j == 1: + ips[j] = ips[i] + j++ + } + } + ips = ips[:j] + + type result struct { + conn net.Conn + err error + } + + ch := make(chan result, len(ips)) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + for _, ip := range ips { + go func(ip string) { + conn, err := dial(ctx, network, net.JoinHostPort(ip, port)) + ch <- result{conn, err} + }(ip) + } + + for i := 0; i < cap(ch); i++ { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case r := <-ch: + if err = r.err; err != nil { + continue + } + return r.conn, nil + } + } + + return nil, err + } + } + } +} + +type attack struct { + name string + began time.Time + + seqmu sync.Mutex + seq uint64 +} + // Attack reads its Targets from the passed Targeter and attacks them at // the rate specified by the Pacer. When the duration is zero the attack // runs until Stop is called. Results are sent to the returned channel as soon @@ -275,21 +386,29 @@ func (a *Attacker) Attack(tr Targeter, p Pacer, du time.Duration, name string) < workers = a.maxWorkers } + atk := &attack{ + name: name, + began: time.Now(), + } + results := make(chan *Result) ticks := make(chan struct{}) for i := uint64(0); i < workers; i++ { wg.Add(1) - go a.attack(tr, name, &wg, ticks, results) + go a.attack(tr, atk, &wg, ticks, results) } go func() { - defer close(results) - defer wg.Wait() - defer close(ticks) - - began, count := time.Now(), uint64(0) + defer func() { + close(ticks) + wg.Wait() + close(results) + a.Stop() + }() + + count := uint64(0) for { - elapsed := time.Since(began) + elapsed := time.Since(atk.began) if du > 0 && elapsed > du { return } @@ -312,7 +431,7 @@ func (a *Attacker) Attack(tr Targeter, p Pacer, du time.Duration, name string) < // all workers are blocked. start one more and try again workers++ wg.Add(1) - go a.attack(tr, name, &wg, ticks, results) + go a.attack(tr, atk, &wg, ticks, results) } } @@ -342,25 +461,25 @@ func (a *Attacker) Stop() bool { } } -func (a *Attacker) attack(tr Targeter, name string, workers *sync.WaitGroup, ticks <-chan struct{}, results chan<- *Result) { +func (a *Attacker) attack(tr Targeter, atk *attack, workers *sync.WaitGroup, ticks <-chan struct{}, results chan<- *Result) { defer workers.Done() for range ticks { - results <- a.hit(tr, name) + results <- a.hit(tr, atk) } } -func (a *Attacker) hit(tr Targeter, name string) *Result { +func (a *Attacker) hit(tr Targeter, atk *attack) *Result { var ( - res = Result{Attack: name} + res = Result{Attack: atk.name} tgt Target err error ) - a.seqmu.Lock() - res.Timestamp = a.began.Add(time.Since(a.began)) - res.Seq = a.seq - a.seq++ - a.seqmu.Unlock() + atk.seqmu.Lock() + res.Timestamp = atk.began.Add(time.Since(atk.began)) + res.Seq = atk.seq + atk.seq++ + atk.seqmu.Unlock() defer func() { res.Latency = time.Since(res.Timestamp) @@ -382,8 +501,8 @@ func (a *Attacker) hit(tr Targeter, name string) *Result { return &res } - if name != "" { - req.Header.Set("X-Vegeta-Attack", name) + if atk.name != "" { + req.Header.Set("X-Vegeta-Attack", atk.name) } req.Header.Set("X-Vegeta-Seq", strconv.FormatUint(res.Seq, 10)) diff --git a/lib/attack_test.go b/lib/attack_test.go index 29f9a498..3495bc6b 100644 --- a/lib/attack_test.go +++ b/lib/attack_test.go @@ -82,7 +82,7 @@ func TestRedirects(t *testing.T) { redirects := 2 atk := NewAttacker(Redirects(redirects)) tr := NewStaticTargeter(Target{Method: "GET", URL: server.URL}) - res := atk.hit(tr, "") + res := atk.hit(tr, &attack{name: "", began: time.Now()}) want := fmt.Sprintf("stopped after %d redirects", redirects) if got := res.Error; !strings.HasSuffix(got, want) { t.Fatalf("want: '%v' in '%v'", want, got) @@ -99,7 +99,7 @@ func TestNoFollow(t *testing.T) { defer server.Close() atk := NewAttacker(Redirects(NoFollow)) tr := NewStaticTargeter(Target{Method: "GET", URL: server.URL}) - res := atk.hit(tr, "") + res := atk.hit(tr, &attack{name: "", began: time.Now()}) if res.Error != "" { t.Fatalf("got err: %v", res.Error) } @@ -118,7 +118,7 @@ func TestTimeout(t *testing.T) { defer server.Close() atk := NewAttacker(Timeout(10 * time.Millisecond)) tr := NewStaticTargeter(Target{Method: "GET", URL: server.URL}) - res := atk.hit(tr, "") + res := atk.hit(tr, &attack{name: "", began: time.Now()}) want := "Client.Timeout exceeded while awaiting headers" if got := res.Error; !strings.Contains(got, want) { @@ -148,7 +148,8 @@ func TestLocalAddr(t *testing.T) { defer server.Close() atk := NewAttacker(LocalAddr(*addr)) tr := NewStaticTargeter(Target{Method: "GET", URL: server.URL}) - atk.hit(tr, "") + atk.hit(tr, &attack{name: "", began: time.Now()}) + } func TestKeepAlive(t *testing.T) { @@ -182,7 +183,8 @@ func TestStatusCodeErrors(t *testing.T) { defer server.Close() atk := NewAttacker() tr := NewStaticTargeter(Target{Method: "GET", URL: server.URL}) - res := atk.hit(tr, "") + + res := atk.hit(tr, &attack{name: "", began: time.Now()}) if got, want := res.Error, "400 Bad Request"; got != want { t.Fatalf("got: %v, want: %v", got, want) } @@ -192,7 +194,7 @@ func TestBadTargeterError(t *testing.T) { t.Parallel() atk := NewAttacker() tr := func(*Target) error { return io.EOF } - res := atk.hit(tr, "") + res := atk.hit(tr, &attack{name: "", began: time.Now()}) if got, want := res.Error, io.EOF.Error(); got != want { t.Fatalf("got: %v, want: %v", got, want) } @@ -210,7 +212,8 @@ func TestResponseBodyCapture(t *testing.T) { defer server.Close() atk := NewAttacker() tr := NewStaticTargeter(Target{Method: "GET", URL: server.URL}) - res := atk.hit(tr, "") + + res := atk.hit(tr, &attack{name: "", began: time.Now()}) if got := res.Body; !bytes.Equal(got, want) { t.Fatalf("got: %v, want: %v", got, want) } @@ -237,7 +240,7 @@ func TestProxyOption(t *testing.T) { })) tr := NewStaticTargeter(Target{Method: "GET", URL: "http://127.0.0.2"}) - res := atk.hit(tr, "") + res := atk.hit(tr, &attack{name: "", began: time.Now()}) if got, want := res.Error, ""; got != want { t.Errorf("got error: %q, want %q", got, want) } @@ -263,7 +266,7 @@ func TestMaxBody(t *testing.T) { t.Run(fmt.Sprint(maxBody), func(t *testing.T) { atk := NewAttacker(MaxBody(maxBody)) tr := NewStaticTargeter(Target{Method: "GET", URL: server.URL}) - res := atk.hit(tr, "") + res := atk.hit(tr, &attack{name: "", began: time.Now()}) want := body if maxBody >= 0 { @@ -321,7 +324,7 @@ func TestUnixSocket(t *testing.T) { atk := NewAttacker(UnixSocket(socketFile)) tr := NewStaticTargeter(Target{Method: "GET", URL: "http://anyserver/"}) - res := atk.hit(tr, "") + res := atk.hit(tr, &attack{name: "", began: time.Now()}) if !bytes.Equal(res.Body, body) { t.Fatalf("got: %s, want: %s", string(res.Body), string(body)) } @@ -355,7 +358,7 @@ func TestClient(t *testing.T) { tr := NewStaticTargeter(Target{Method: "GET", URL: server.URL}) atk := NewAttacker(Client(client)) - resp := atk.hit(tr, "TEST") + resp := atk.hit(tr, &attack{name: "TEST", began: time.Now()}) if !strings.Contains(resp.Error, "Client.Timeout exceeded while awaiting headers") { t.Errorf("Expected timeout error") } @@ -373,18 +376,17 @@ func TestVegetaHeaders(t *testing.T) { defer server.Close() tr := NewStaticTargeter(Target{Method: "GET", URL: server.URL}) - atk := NewAttacker() - + a := NewAttacker() + atk := &attack{name: "ig-bang", began: time.Now()} for seq := 0; seq < 5; seq++ { - attack := "big-bang" - res := atk.hit(tr, attack) + res := a.hit(tr, atk) var hdr http.Header if err := json.Unmarshal(res.Body, &hdr); err != nil { t.Fatal(err) } - if have, want := hdr.Get("X-Vegeta-Attack"), attack; have != want { + if have, want := hdr.Get("X-Vegeta-Attack"), atk.name; have != want { t.Errorf("X-Vegeta-Attack: have %q, want %q", have, want) }