From 93b94b06a1c82615a6afb945d8ee84e8a69960c3 Mon Sep 17 00:00:00 2001 From: nanjingfm <765390765@qq.com> Date: Wed, 20 Jul 2022 13:25:50 +0800 Subject: [PATCH] chore: fix result order (#227) Co-authored-by: fm --- parallel/page.go | 91 ++++++++++++++++++++ parallel/page_test.go | 187 ++++++++++++++++++++++++++++++++++++++++++ parallel/task.go | 36 +++++--- parallel/task_test.go | 15 +++- 4 files changed, 315 insertions(+), 14 deletions(-) create mode 100644 parallel/page.go create mode 100644 parallel/page_test.go diff --git a/parallel/page.go b/parallel/page.go new file mode 100644 index 00000000..c7fe3d59 --- /dev/null +++ b/parallel/page.go @@ -0,0 +1,91 @@ +/* +Copyright 2022 The Katanomi Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package parallel + +import ( + "context" + "fmt" + "time" + + "k8s.io/utils/trace" + "knative.dev/pkg/logging" +) + +// PageRequestFunc is a tool for concurrent processing of pagination +type PageRequestFunc struct { + // RequestPage for concurrent request paging + RequestPage func(ctx context.Context, pageSize int, page int) (interface{}, error) + // PageResult for get paging information + PageResult func(items interface{}) (total int, currentPageLen int, err error) +} + +// PageRequest is concurrent request paging +func PageRequest(ctx context.Context, logName string, concurrency int, pageSize int, f PageRequestFunc) ([]interface{}, error) { + log := trace.New("PageRequest", trace.Field{Key: "name", Value: logName}) + logger := logging.FromContext(ctx) + + defer func() { + log.LogIfLong(3 * time.Second) + }() + + items, err := f.RequestPage(ctx, pageSize, 1) + if err != nil { + return nil, err + } + log.Step("requested page 1") + total, firstPageLen, err := f.PageResult(items) + if err != nil { + return nil, err + } + if firstPageLen < pageSize { + return []interface{}{items}, nil + } + + if total == firstPageLen { + return []interface{}{items}, nil + } + + var request = func(i int) func() (interface{}, error) { + return func() (interface{}, error) { + items, err := f.RequestPage(ctx, pageSize, i) + log.Step(fmt.Sprintf("requested page %d", i)) + return items, err + } + } + + totalPage := total / pageSize + if total%pageSize != 0 { + totalPage = totalPage + 1 + } + + if totalPage-1 < concurrency { // first page we have requested, so skip first page + concurrency = totalPage - 1 + } + + p := P(logger, "PageRequest").FailFast().SetConcurrent(concurrency).Context(ctx) + for i := 2; i <= totalPage; i++ { + p.Add(request(i)) + } + + results, err := p.Do().Wait() + if err != nil { + return nil, err + } + + return append([]interface{}{items}, results...), nil + +} diff --git a/parallel/page_test.go b/parallel/page_test.go new file mode 100644 index 00000000..ca36ee38 --- /dev/null +++ b/parallel/page_test.go @@ -0,0 +1,187 @@ +/* +Copyright 2022 The Katanomi Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package parallel + +import ( + "context" + "fmt" + "testing" +) + +type mockResult struct { + Total int + Data []string +} + +type mockRequest func(page, pageSize int) mockResult + +func getMockRequestFunc(total int) mockRequest { + return func(page, pageSize int) mockResult { + result := mockResult{Data: make([]string, 0), Total: total} + if total == 0 { + return result + } + pageNum := total / pageSize + for pageNum*pageSize < total { + pageNum = pageNum + 1 + } + + if page > pageNum { + return result + } + + if page*pageSize <= total { + for i := 0; i < pageSize; i++ { + result.Data = append(result.Data, "result") + } + } else { + // processing last page requests + // for example: total 10, pageNum: 4, page: 4, pageSize 3. this page length is 1 + for i := 0; i < (total - (pageSize * (pageNum - 1))); i++ { + result.Data = append(result.Data, "result") + } + } + return result + } +} + +func TestPage(t *testing.T) { + t.Run("simple", func(t *testing.T) { + PageRequest(context.Background(), "TestPageResult", 2, 10, PageRequestFunc{ + RequestPage: func(ctx context.Context, pageSize int, page int) (interface{}, error) { + fmt.Printf("request -> page: %d, pagesize: %d\n", page, pageSize) + return nil, nil + }, + PageResult: func(items interface{}) (total int, currentPageLen int, err error) { + return 8, 6, nil + }, + }) + }) + t.Run("pre check mock fuck", func(t *testing.T) { + total := 100 + mockFunc := getMockRequestFunc(total) + + preCheck1 := mockFunc(1, 10) + // total 100, page 1, pageSize 10. + if len(preCheck1.Data) != 10 || preCheck1.Total != total { + t.Errorf("pre check failed, should be 10, but got %d, get total is %d", len(preCheck1.Data), preCheck1.Total) + } + + preCheck2 := mockFunc(2, 10) + // total 100, page 2, pageSize 10. + if len(preCheck2.Data) != 10 || preCheck2.Total != total { + t.Errorf("pre check failed, should be 10, but got %d, get total is %d", len(preCheck2.Data), preCheck2.Total) + } + + total = 10 + mockFunc = getMockRequestFunc(total) + preCheckEndPage := mockFunc(4, 3) + // total 10, page 4, pageSize 3. should return data num is 1 + if len(preCheckEndPage.Data) != 1 || preCheckEndPage.Total != total { + t.Errorf("pre check failed, should be 1, but got %d, get total is %d", len(preCheckEndPage.Data), preCheckEndPage.Total) + } + + total = 100000 + mockFunc = getMockRequestFunc(total) + preCheckBigNum := mockFunc(10000, 10) + // total 100000, page 10000, pageSize 10. should return data num is 10000 + if len(preCheckBigNum.Data) != 10 || preCheckBigNum.Total != total { + t.Errorf("pre check failed, should be 1, but got %d, get total is %d", len(preCheckBigNum.Data), preCheckBigNum.Total) + } + }) + t.Run("check result", func(t *testing.T) { + total := 999 + mockFunc := getMockRequestFunc(total) + overResult := make([]string, 0) + + items, _ := PageRequest(context.Background(), "CheckPageResult", 2, 10, PageRequestFunc{ + RequestPage: func(ctx context.Context, pageSize int, page int) (interface{}, error) { + tmpResult := mockFunc(page, pageSize) + return tmpResult, nil + }, + PageResult: func(items interface{}) (total int, currentPageLen int, err error) { + tmpResult := items.(mockResult) + return tmpResult.Total, len(tmpResult.Data), nil + }, + }) + + for _, _item := range items { + item := _item.(mockResult) + overResult = append(overResult, item.Data...) + } + + if len(overResult) != total { + t.Errorf("check result failed, should be %d, got %d", total, len(overResult)) + } + }) +} + +func TestPageZero(t *testing.T) { + t.Run("check result 0", func(t *testing.T) { + total := 0 + mockFunc := getMockRequestFunc(total) + overResult := make([]string, 0) + + items, _ := PageRequest(context.Background(), "CheckPageResult", 2, 10, PageRequestFunc{ + RequestPage: func(ctx context.Context, pageSize int, page int) (interface{}, error) { + tmpResult := mockFunc(page, pageSize) + return tmpResult, nil + }, + PageResult: func(items interface{}) (total int, currentPageLen int, err error) { + tmpResult := items.(mockResult) + return tmpResult.Total, len(tmpResult.Data), nil + }, + }) + + for _, _item := range items { + item := _item.(mockResult) + overResult = append(overResult, item.Data...) + } + + if len(overResult) != total { + t.Errorf("check result failed, should be %d, got %d", total, len(overResult)) + } + }) +} + +func TestPageBigNum(t *testing.T) { + t.Run("check result 0", func(t *testing.T) { + total := 10000 + mockFunc := getMockRequestFunc(total) + overResult := make([]string, 0) + + items, _ := PageRequest(context.Background(), "CheckPageResult", 2, 10, PageRequestFunc{ + RequestPage: func(ctx context.Context, pageSize int, page int) (interface{}, error) { + tmpResult := mockFunc(page, pageSize) + return tmpResult, nil + }, + PageResult: func(items interface{}) (total int, currentPageLen int, err error) { + tmpResult := items.(mockResult) + return tmpResult.Total, len(tmpResult.Data), nil + }, + }) + + for _, _item := range items { + item := _item.(mockResult) + overResult = append(overResult, item.Data...) + } + + if len(overResult) != total { + t.Errorf("check result failed, should be %d, got %d", total, len(overResult)) + } + }) +} diff --git a/parallel/task.go b/parallel/task.go index 3fb816d2..abb43c72 100644 --- a/parallel/task.go +++ b/parallel/task.go @@ -20,12 +20,11 @@ package parallel import ( "context" "fmt" + "reflect" "sync" "go.uber.org/zap" - "reflect" - "k8s.io/apimachinery/pkg/util/errors" ) @@ -189,6 +188,9 @@ func (p *ParallelTasks) Do() *ParallelTasks { } }() + p.results = make([]interface{}, len(p.tasks)) + p.errs = make([]error, len(p.tasks)) + for i, task := range p.tasks { if !p.waitThreshold() { return p @@ -213,7 +215,7 @@ func (p *ParallelTasks) Do() *ParallelTasks { if err != nil { p.errLock.Lock() defer p.errLock.Unlock() - p.errs = append(p.errs, err) + p.errs[index] = err if p.Options.FailFast { log.Debugw("fail fast, will cancel", "task-index", index, "result", result) p.Cancel(err) @@ -223,9 +225,7 @@ func (p *ParallelTasks) Do() *ParallelTasks { // error is nil, we should save result if !isNil(result) { - p.resultsLock.Lock() - p.results = append(p.results, result) - p.resultsLock.Unlock() + p.results[index] = result } }(i, task) @@ -266,15 +266,31 @@ func (p *ParallelTasks) Wait() ([]interface{}, error) { p.Log.Debugw("waiting done.") <-p.doneChan p.Log.Debugw("waited done.") + + var ( + results = make([]interface{}, 0) + errs = make([]error, 0) + ) + for _, result := range p.results { + if result != nil { + results = append(results, result) + } + } + for _, err := range p.errs { + if err != nil { + errs = append(errs, err) + } + } + if p.doneError != nil { - return p.results, p.doneError + return results, p.doneError } - if len(p.errs) > 0 { - return p.results, errors.NewAggregate(p.errs) + if len(errs) > 0 { + return results, errors.NewAggregate(errs) } - return p.results, nil + return results, nil } func isNil(i interface{}) bool { diff --git a/parallel/task_test.go b/parallel/task_test.go index 89a3fc60..1590749f 100644 --- a/parallel/task_test.go +++ b/parallel/task_test.go @@ -188,10 +188,10 @@ var _ = Describe("P().Do().Wait()", func() { Expect(errs).To(BeEquivalentTo(context.DeadlineExceeded)) fmt.Printf("%v,%v,%v,", t1Excuted.get(), t2Excuted.get(), t3Excuted.get()) - //up to now, task is not support cancel - //Expect(t1Excuted.executed).To(BeFalse()) - //Expect(t2Excuted.executed).To(BeFalse()) - //Expect(t3Excuted.executed).To(BeFalse()) + // up to now, task is not support cancel + // Expect(t1Excuted.executed).To(BeFalse()) + // Expect(t2Excuted.executed).To(BeFalse()) + // Expect(t3Excuted.executed).To(BeFalse()) Expect(elapsed < 1 && elapsed > 0.1).To(BeTrue()) Expect(len(res)).To(BeEquivalentTo(0)) @@ -239,6 +239,13 @@ var _ = Describe("P().Do().Wait()", func() { Expect(res).To(ContainElement(fmt.Sprintf("task-%d", i))) } }) + + It("should return with same sort", func() { + Expect(errs).To(BeNil()) + for i := 0; i <= 9; i++ { + Expect(res[i]).To(Equal(fmt.Sprintf("task-%d", i+1))) + } + }) }) Context("when set conccurrent", func() {