diff --git a/cmd/run.go b/cmd/run.go index 1aa82196..7cd9e291 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/linuxsuren/api-testing/pkg/limit" "github.com/linuxsuren/api-testing/pkg/render" "github.com/linuxsuren/api-testing/pkg/runner" "github.com/linuxsuren/api-testing/pkg/testing" @@ -23,6 +24,10 @@ type runOption struct { requestIgnoreError bool thread int64 context context.Context + qps int32 + burst int32 + limiter limit.RateLimiter + startTime time.Time } // CreateRunCommand returns the run command @@ -45,12 +50,20 @@ See also https://github.com/LinuxSuRen/api-testing/tree/master/sample`, flags.DurationVarP(&opt.requestTimeout, "request-timeout", "", time.Minute, "Timeout for per request") flags.BoolVarP(&opt.requestIgnoreError, "request-ignore-error", "", false, "Indicate if ignore the request error") flags.Int64VarP(&opt.thread, "thread", "", 1, "Threads of the execution") + flags.Int32VarP(&opt.qps, "qps", "", 5, "QPS") + flags.Int32VarP(&opt.burst, "burst", "", 5, "burst") return } func (o *runOption) runE(cmd *cobra.Command, args []string) (err error) { var files []string + o.startTime = time.Now() o.context = cmd.Context() + o.limiter = limit.NewDefaultRateLimiter(o.qps, o.burst) + defer func() { + cmd.Printf("consume: %s\n", time.Now().Sub(o.startTime).String()) + o.limiter.Stop() + }() if files, err = filepath.Glob(o.pattern); err == nil { for i := range files { @@ -74,12 +87,14 @@ func (o *runOption) runSuiteWithDuration(suite string) (err error) { timeout = time.NewTicker(time.Second) } errChannel := make(chan error, 10*o.thread) + stopSingal := make(chan struct{}, 1) var wait sync.WaitGroup for !stop { select { case <-timeout.C: stop = true + stopSingal <- struct{}{} case err = <-errChannel: if err != nil { stop = true @@ -89,9 +104,6 @@ func (o *runOption) runSuiteWithDuration(suite string) (err error) { continue } wait.Add(1) - if o.duration <= 0 { - stop = true - } go func(ch chan error, sem *semaphore.Weighted) { now := time.Now() @@ -102,16 +114,24 @@ func (o *runOption) runSuiteWithDuration(suite string) (err error) { }() dataContext := getDefaultContext() - ch <- o.runSuite(suite, dataContext, o.context) + ch <- o.runSuite(suite, dataContext, o.context, stopSingal) }(errChannel, sem) + if o.duration <= 0 { + stop = true + } } } - err = <-errChannel + + select { + case err = <-errChannel: + case <-stopSingal: + } + wait.Wait() return } -func (o *runOption) runSuite(suite string, dataContext map[string]interface{}, ctx context.Context) (err error) { +func (o *runOption) runSuite(suite string, dataContext map[string]interface{}, ctx context.Context, stopSingal chan struct{}) (err error) { var testSuite *testing.TestSuite if testSuite, err = testing.Parse(suite); err != nil { return @@ -131,11 +151,23 @@ func (o *runOption) runSuite(suite string, dataContext map[string]interface{}, c testCase.Request.API = fmt.Sprintf("%s%s", testSuite.API, testCase.Request.API) } - setRelativeDir(suite, &testCase) var output interface{} - ctxWithTimeout, _ := context.WithTimeout(ctx, o.requestTimeout) - if output, err = runner.RunTestCase(&testCase, dataContext, ctxWithTimeout); err != nil && !o.requestIgnoreError { + select { + case <-stopSingal: return + default: + // reuse the API prefix + if strings.HasPrefix(testCase.Request.API, "/") { + testCase.Request.API = fmt.Sprintf("%s%s", testSuite.API, testCase.Request.API) + } + + setRelativeDir(suite, &testCase) + o.limiter.Accept() + + ctxWithTimeout, _ := context.WithTimeout(ctx, o.requestTimeout) + if output, err = runner.RunTestCase(&testCase, dataContext, ctxWithTimeout); err != nil && !o.requestIgnoreError { + return + } } dataContext[testCase.Name] = output } diff --git a/cmd/run_test.go b/cmd/run_test.go index 2ae8ccbb..7e738955 100644 --- a/cmd/run_test.go +++ b/cmd/run_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/h2non/gock" + "github.com/linuxsuren/api-testing/pkg/limit" "github.com/spf13/cobra" "github.com/stretchr/testify/assert" ) @@ -48,9 +49,13 @@ func TestRunSuite(t *testing.T) { tt.prepare() ctx := getDefaultContext() - opt := &runOption{requestTimeout: 30 * time.Second} + opt := &runOption{ + requestTimeout: 30 * time.Second, + limiter: limit.NewDefaultRateLimiter(0, 0), + } + stopSingal := make(chan struct{}, 1) - err := opt.runSuite(tt.suiteFile, ctx, context.TODO()) + err := opt.runSuite(tt.suiteFile, ctx, context.TODO(), stopSingal) assert.Equal(t, tt.hasError, err != nil, err) }) } diff --git a/pkg/limit/limiter.go b/pkg/limit/limiter.go new file mode 100644 index 00000000..c71f9b1f --- /dev/null +++ b/pkg/limit/limiter.go @@ -0,0 +1,95 @@ +package limit + +import ( + "sync" + "time" +) + +type RateLimiter interface { + TryAccept() bool + Accept() + Stop() + Burst() int32 +} + +type defaultRateLimiter struct { + qps int32 + burst int32 + lastToken time.Time + singal chan struct{} + mu sync.Mutex +} + +func NewDefaultRateLimiter(qps, burst int32) RateLimiter { + if qps <= 0 { + qps = 5 + } + if burst <= 0 { + burst = 5 + } + limiter := &defaultRateLimiter{ + qps: qps, + burst: burst, + singal: make(chan struct{}, 1), + } + go limiter.updateBurst() + return limiter +} + +func (r *defaultRateLimiter) TryAccept() bool { + _, ok := r.resver() + return ok +} + +func (r *defaultRateLimiter) resver() (delay time.Duration, ok bool) { + delay = time.Now().Sub(r.lastToken) / time.Millisecond + r.lastToken = time.Now() + if delay > 0 { + ok = true + } else if r.Burst() > 0 { + r.Setburst(r.Burst() - 1) + ok = true + } else { + delay = time.Second / time.Duration(r.qps) + } + return +} + +func (r *defaultRateLimiter) Accept() { + delay, ok := r.resver() + if ok { + return + } + + if delay > 0 { + time.Sleep(delay) + } + return +} + +func (r *defaultRateLimiter) Setburst(burst int32) { + r.mu.Lock() + defer r.mu.Unlock() + r.burst = burst +} + +func (r *defaultRateLimiter) Burst() int32 { + r.mu.Lock() + defer r.mu.Unlock() + return r.burst +} + +func (r *defaultRateLimiter) Stop() { + r.singal <- struct{}{} +} + +func (r *defaultRateLimiter) updateBurst() { + for { + select { + case <-time.After(time.Second): + r.Setburst(r.Burst() + r.qps) + case <-r.singal: + return + } + } +} diff --git a/pkg/limit/limiter_test.go b/pkg/limit/limiter_test.go new file mode 100644 index 00000000..a25b17a2 --- /dev/null +++ b/pkg/limit/limiter_test.go @@ -0,0 +1,27 @@ +package limit + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestXxx(t *testing.T) { + limiter := NewDefaultRateLimiter(1, 1) + num := 0 + + loop := true + go func(l RateLimiter) { + for loop { + l.Accept() + num += 1 + } + }(limiter) + + select { + case <-time.After(time.Second): + loop = false + } + assert.True(t, num <= 10) +}