From d13e09a7f6004670a9e9d852b9201ecf4af795fc Mon Sep 17 00:00:00 2001 From: tk_sky <63036400+tksky1@users.noreply.github.com> Date: Thu, 7 Sep 2023 17:52:32 +0800 Subject: [PATCH 01/19] feat: add @preserve comments to preserve the struct from being trimmed (#124) --- tool/trimmer/main.go | 5 +- tool/trimmer/test_cases/sample1.thrift | 1 + tool/trimmer/test_cases/tests/example2.thrift | 2 +- tool/trimmer/trim/mark.go | 50 ++++++++++++++++++- tool/trimmer/trim/traversal.go | 15 +++++- 5 files changed, 66 insertions(+), 7 deletions(-) diff --git a/tool/trimmer/main.go b/tool/trimmer/main.go index 61fd96cd..67f53430 100644 --- a/tool/trimmer/main.go +++ b/tool/trimmer/main.go @@ -20,13 +20,12 @@ import ( "path/filepath" "strings" + "github.com/cloudwego/thriftgo/generator" "github.com/cloudwego/thriftgo/parser" "github.com/cloudwego/thriftgo/semantic" "github.com/cloudwego/thriftgo/tool/trimmer/dump" "github.com/cloudwego/thriftgo/tool/trimmer/trim" "github.com/cloudwego/thriftgo/version" - - "github.com/cloudwego/thriftgo/generator" ) var ( @@ -99,7 +98,7 @@ func main() { os.Exit(2) } } else { - println("-o should be set as a valid dir to enable -r", err.Error()) + println("-o should be set as a valid dir to enable -r") os.Exit(2) } } diff --git a/tool/trimmer/test_cases/sample1.thrift b/tool/trimmer/test_cases/sample1.thrift index 87f55f1e..8ff738a5 100644 --- a/tool/trimmer/test_cases/sample1.thrift +++ b/tool/trimmer/test_cases/sample1.thrift @@ -53,6 +53,7 @@ enum Gender { FEMALE (key = "1", key = "2", key2 = "v2") } (a = "b") +#@PRESERvE struct Address { 1: required string(key = "v") street 2: required string city diff --git a/tool/trimmer/test_cases/tests/example2.thrift b/tool/trimmer/test_cases/tests/example2.thrift index 68851e7d..c8fae9df 100644 --- a/tool/trimmer/test_cases/tests/example2.thrift +++ b/tool/trimmer/test_cases/tests/example2.thrift @@ -22,6 +22,6 @@ struct E{ } -service MyService{ +service MyService extends sample1.EmployeeService{ string A(1:required E req,2:required test.TestStruct t) } \ No newline at end of file diff --git a/tool/trimmer/trim/mark.go b/tool/trimmer/trim/mark.go index bafa990c..98d46346 100644 --- a/tool/trimmer/trim/mark.go +++ b/tool/trimmer/trim/mark.go @@ -49,7 +49,6 @@ func (t *Trimmer) markService(svc *parser.Service, ast *parser.Thrift, filename t.marks[filename][svc] = true t.markFunction(function, ast, filename) t.trimMethodValid[i] = true - continue } } continue @@ -57,6 +56,10 @@ func (t *Trimmer) markService(svc *parser.Service, ast *parser.Thrift, filename t.markFunction(function, ast, filename) } + if t.trimMethods != nil && (svc.Extends != "" || svc.Reference != nil) { + t.traceExtendMethod(svc, svc, ast, filename) + } + if svc.Extends != "" && t.marks[filename][svc] { // handle extension if svc.Reference != nil { @@ -169,3 +172,48 @@ func (t *Trimmer) markTypeDef(theType *parser.Type, ast *parser.Thrift, filename } } } + +// for -m, trace the extends and find specified method to base on +func (t *Trimmer) traceExtendMethod(father, svc *parser.Service, ast *parser.Thrift, filename string) (ret bool) { + for _, function := range svc.Functions { + funcName := father.Name + "." + function.Name + for i, method := range t.trimMethods { + if funcName == method { + t.marks[filename][svc] = true + t.markFunction(function, ast, filename) + t.trimMethodValid[i] = true + ret = true + } + } + } + if svc.Extends != "" { + var nextSvc *parser.Service + var nextAst *parser.Thrift + if svc.Reference == nil { + for i, extend := range ast.Services { + if extend.Name == svc.Extends { + nextSvc = ast.Services[i] + nextAst = ast + break + } + } + } else { + for i, extend := range ast.Includes[svc.Reference.Index].Reference.Services { + if extend.Name == svc.Reference.Name { + nextSvc = ast.Includes[svc.Reference.Index].Reference.Services[i] + nextAst = ast.Includes[svc.Reference.Index].Reference + break + } + } + } + back := t.traceExtendMethod(father, nextSvc, nextAst, filename) + ret = back || ret + } + if ret { + t.marks[filename][svc] = true + if svc.Reference != nil { + t.marks[filename][ast.Includes[svc.Reference.Index]] = true + } + } + return ret +} diff --git a/tool/trimmer/trim/traversal.go b/tool/trimmer/trim/traversal.go index 3a7fbadd..8a8a68c2 100644 --- a/tool/trimmer/trim/traversal.go +++ b/tool/trimmer/trim/traversal.go @@ -14,7 +14,12 @@ package trim -import "github.com/cloudwego/thriftgo/parser" +import ( + "regexp" + "strings" + + "github.com/cloudwego/thriftgo/parser" +) // traverse and remove the unmarked part of ast func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { @@ -30,7 +35,7 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { var listStruct []*parser.StructLike for i := range ast.Structs { - if t.marks[filename][ast.Structs[i]] { + if t.marks[filename][ast.Structs[i]] || checkPreserve(ast.Structs[i]) { listStruct = append(listStruct, ast.Structs[i]) } } @@ -69,3 +74,9 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { } ast.Services = listService } + +func checkPreserve(theStruct *parser.StructLike) bool { + pattern := `^[\s]*(\/\/|#)[\s]*@preserve[\s]*$` + regex := regexp.MustCompile(pattern) + return regex.MatchString(strings.ToLower(theStruct.ReservedComments)) +} From 086fb65a75960057d9f1568dbf8e31ffb637f44e Mon Sep 17 00:00:00 2001 From: tk_sky <63036400+tksky1@users.noreply.github.com> Date: Wed, 20 Sep 2023 11:15:00 +0800 Subject: [PATCH 02/19] fix: fix @preserve comment bug (#127) --- generator/golang/backend.go | 2 +- tool/trimmer/args.go | 5 ++++ tool/trimmer/dirTree.go | 7 +++--- tool/trimmer/main.go | 12 ++++++++- tool/trimmer/test_cases/sample1.thrift | 2 +- tool/trimmer/trim/mark.go | 34 +++++++++++++++++++++++++- tool/trimmer/trim/traversal.go | 11 +-------- tool/trimmer/trim/trimmer.go | 9 +++++-- 8 files changed, 63 insertions(+), 19 deletions(-) diff --git a/generator/golang/backend.go b/generator/golang/backend.go index 7858579e..d96e843d 100644 --- a/generator/golang/backend.go +++ b/generator/golang/backend.go @@ -86,7 +86,7 @@ func (g *GoBackend) Generate(req *plugin.Request, log backend.LogFunc) *plugin.R g.log = log g.prepareUtilities() if g.utils.Features().TrimIDL { - err := trim.TrimAST(req.AST, nil) + err := trim.TrimAST(req.AST, nil, false) if err != nil { g.log.Warn("trim error:", err.Error()) } diff --git a/tool/trimmer/args.go b/tool/trimmer/args.go index 64338ab0..eee0f637 100644 --- a/tool/trimmer/args.go +++ b/tool/trimmer/args.go @@ -45,6 +45,7 @@ type Arguments struct { IDL string Recurse string Methods StringSlice + Preserve string } // BuildFlags initializes command line flags. @@ -62,6 +63,9 @@ func (a *Arguments) BuildFlags() *flag.FlagSet { f.Var(&a.Methods, "m", "") f.Var(&a.Methods, "method", "") + f.StringVar(&a.Preserve, "p", "true", "") + f.StringVar(&a.Preserve, "preserve", "true", "") + f.Usage = help return f } @@ -95,6 +99,7 @@ Options: -o, --out [file/dir] Specify the output IDL file/dir. -r, --recurse [dir] Specify a root dir and dump the included IDL recursively beneath the given root. -o should be set as a directory. -m, --method [service.method] Only keep the specified methods and their dependents. Accept multiple -m. + -p, --preserve [true/false] Set to false to ignore @preserve comments `) // print backend options for _, b := range g.AllBackend() { diff --git a/tool/trimmer/dirTree.go b/tool/trimmer/dirTree.go index 3844a83f..97d3e0f2 100644 --- a/tool/trimmer/dirTree.go +++ b/tool/trimmer/dirTree.go @@ -15,6 +15,7 @@ package main import ( + "errors" "fmt" "os" "path/filepath" @@ -87,8 +88,8 @@ func isDirectoryEmpty(path string) (bool, error) { return false, nil } - if len(err.Error()) > len("EOF") { - return false, err + if errors.Is(err, os.ErrNotExist) { + return true, nil } - return true, nil + return false, err } diff --git a/tool/trimmer/main.go b/tool/trimmer/main.go index 67f53430..63c65741 100644 --- a/tool/trimmer/main.go +++ b/tool/trimmer/main.go @@ -18,6 +18,7 @@ import ( "fmt" "os" "path/filepath" + "strconv" "strings" "github.com/cloudwego/thriftgo/generator" @@ -52,6 +53,15 @@ func main() { os.Exit(0) } + preserve := true + if a.Preserve != "" { + preserve, err = strconv.ParseBool(a.Preserve) + if err != nil { + help() + os.Exit(2) + } + } + // parse file to ast ast, err := parser.ParseFile(a.IDL, nil, true) check(err) @@ -64,7 +74,7 @@ func main() { check(semantic.ResolveSymbols(ast)) // trim ast - check(trim.TrimAST(ast, a.Methods)) + check(trim.TrimAST(ast, a.Methods, !preserve)) // dump the trimmed ast to idl idl, err := dump.DumpIDL(ast) diff --git a/tool/trimmer/test_cases/sample1.thrift b/tool/trimmer/test_cases/sample1.thrift index 8ff738a5..a80ee282 100644 --- a/tool/trimmer/test_cases/sample1.thrift +++ b/tool/trimmer/test_cases/sample1.thrift @@ -53,7 +53,6 @@ enum Gender { FEMALE (key = "1", key = "2", key2 = "v2") } (a = "b") -#@PRESERvE struct Address { 1: required string(key = "v") street 2: required string city @@ -61,6 +60,7 @@ struct Address { 4: required a country } +// @pResErve struct Company { 1: required string name 2: optional Address address diff --git a/tool/trimmer/trim/mark.go b/tool/trimmer/trim/mark.go index 98d46346..5c7bc48e 100644 --- a/tool/trimmer/trim/mark.go +++ b/tool/trimmer/trim/mark.go @@ -14,7 +14,11 @@ package trim -import "github.com/cloudwego/thriftgo/parser" +import ( + "strings" + + "github.com/cloudwego/thriftgo/parser" +) // mark the used part of ast func (t *Trimmer) markAST(ast *parser.Thrift) { @@ -30,6 +34,26 @@ func (t *Trimmer) markAST(ast *parser.Thrift) { for _, typedef := range ast.Typedefs { t.markTypeDef(typedef.Type, ast, ast.Filename) } + + if !t.forceTrimming { + for _, str := range ast.Structs { + if !t.marks[ast.Filename][str] && t.checkPreserve(str) { + t.markStructLike(str, ast, ast.Filename) + } + } + + for _, str := range ast.Unions { + if !t.marks[ast.Filename][str] && t.checkPreserve(str) { + t.markStructLike(str, ast, ast.Filename) + } + } + + for _, str := range ast.Exceptions { + if !t.marks[ast.Filename][str] && t.checkPreserve(str) { + t.markStructLike(str, ast, ast.Filename) + } + } + } } func (t *Trimmer) markService(svc *parser.Service, ast *parser.Thrift, filename string) { @@ -217,3 +241,11 @@ func (t *Trimmer) traceExtendMethod(father, svc *parser.Service, ast *parser.Thr } return ret } + +// check for @Preserve comments +func (t *Trimmer) checkPreserve(theStruct *parser.StructLike) bool { + if t.forceTrimming { + return false + } + return t.preserveRegex.MatchString(strings.ToLower(theStruct.ReservedComments)) +} diff --git a/tool/trimmer/trim/traversal.go b/tool/trimmer/trim/traversal.go index 8a8a68c2..e945e8f7 100644 --- a/tool/trimmer/trim/traversal.go +++ b/tool/trimmer/trim/traversal.go @@ -15,9 +15,6 @@ package trim import ( - "regexp" - "strings" - "github.com/cloudwego/thriftgo/parser" ) @@ -35,7 +32,7 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { var listStruct []*parser.StructLike for i := range ast.Structs { - if t.marks[filename][ast.Structs[i]] || checkPreserve(ast.Structs[i]) { + if t.marks[filename][ast.Structs[i]] || t.checkPreserve(ast.Structs[i]) { listStruct = append(listStruct, ast.Structs[i]) } } @@ -74,9 +71,3 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { } ast.Services = listService } - -func checkPreserve(theStruct *parser.StructLike) bool { - pattern := `^[\s]*(\/\/|#)[\s]*@preserve[\s]*$` - regex := regexp.MustCompile(pattern) - return regex.MatchString(strings.ToLower(theStruct.ReservedComments)) -} diff --git a/tool/trimmer/trim/trimmer.go b/tool/trimmer/trim/trimmer.go index c6df09b8..bd4ee050 100644 --- a/tool/trimmer/trim/trimmer.go +++ b/tool/trimmer/trim/trimmer.go @@ -17,6 +17,7 @@ package trim import ( "fmt" "os" + "regexp" "strings" "github.com/cloudwego/thriftgo/parser" @@ -33,10 +34,12 @@ type Trimmer struct { // use -m trimMethods []string trimMethodValid []bool + preserveRegex *regexp.Regexp + forceTrimming bool } // TrimAST trim the single AST, pass method names if -m specified -func TrimAST(ast *parser.Thrift, trimMethods []string) error { +func TrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool) error { trimmer, err := newTrimmer(nil, "") if err != nil { return err @@ -44,6 +47,7 @@ func TrimAST(ast *parser.Thrift, trimMethods []string) error { trimmer.asts[ast.Filename] = ast trimmer.trimMethods = trimMethods trimmer.trimMethodValid = make([]bool, len(trimMethods)) + trimmer.forceTrimming = forceTrimming for i, method := range trimMethods { parts := strings.Split(method, ".") if len(parts) < 2 { @@ -110,7 +114,8 @@ func newTrimmer(files []string, outDir string) (*Trimmer, error) { } trimmer.asts = make(map[string]*parser.Thrift) trimmer.marks = make(map[string]map[interface{}]bool) - + pattern := `^[\s]*(\/\/|#)[\s]*@preserve[\s]*$` + trimmer.preserveRegex = regexp.MustCompile(pattern) return trimmer, nil } From a959e9fb9fb5b39f6f91f281f57477d00a3f53a6 Mon Sep 17 00:00:00 2001 From: tk_sky <63036400+tksky1@users.noreply.github.com> Date: Mon, 25 Sep 2023 16:26:29 +0800 Subject: [PATCH 03/19] Fix: fix dir tree traversal bug (#128) --- tool/trimmer/dirTree.go | 40 +++++-------- tool/trimmer/dirTree_test.go | 58 +++++++++++++++++++ tool/trimmer/main.go | 5 ++ .../test_cases/tests/dir/dir2/test.thrift | 20 +++++++ tool/trimmer/trim/trimmer_test.go | 4 +- 5 files changed, 98 insertions(+), 29 deletions(-) create mode 100644 tool/trimmer/dirTree_test.go create mode 100644 tool/trimmer/test_cases/tests/dir/dir2/test.thrift diff --git a/tool/trimmer/dirTree.go b/tool/trimmer/dirTree.go index 97d3e0f2..c7c39683 100644 --- a/tool/trimmer/dirTree.go +++ b/tool/trimmer/dirTree.go @@ -28,13 +28,13 @@ func createDirTree(sourceDir, destinationDir string) { return err } if info.IsDir() { - newDir := filepath.Join(destinationDir, path[len(sourceDir):]) if path[len(sourceDir)-1] != filepath.Separator { - newDir = filepath.Join(destinationDir, path[len(sourceDir)-1:]) + path = path + string(filepath.Separator) } + newDir := filepath.Join(destinationDir, path[len(sourceDir):]) err := os.MkdirAll(newDir, os.ModePerm) if err != nil { - return err + return errors.New("create dir tree error:" + err.Error()) } } return nil @@ -47,32 +47,18 @@ func createDirTree(sourceDir, destinationDir string) { // remove empty directory of output dir-tree func removeEmptyDir(source string) { - err := filepath.Walk(source, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - if info.IsDir() { - empty, err := isDirectoryEmpty(path) - if err != nil { - return err - } - if empty { - err := os.Remove(path) - if err != nil { - return err - } - } + files, err := os.ReadDir(source) + if err != nil { + return + } + for _, file := range files { + if file.IsDir() { + removeEmptyDir(source + string(filepath.Separator) + file.Name()) } - return nil - }) - - parent := filepath.Dir(source) - if parent != source { - removeEmptyDir(parent) } - - if err != nil { - fmt.Printf("Error: %v\n", err) + empty, err := isDirectoryEmpty(source) + if empty || err != nil { + _ = os.RemoveAll(source) } } diff --git a/tool/trimmer/dirTree_test.go b/tool/trimmer/dirTree_test.go new file mode 100644 index 00000000..d856a7a4 --- /dev/null +++ b/tool/trimmer/dirTree_test.go @@ -0,0 +1,58 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "os" + "path/filepath" + "testing" + + "github.com/cloudwego/thriftgo/pkg/test" +) + +func TestDirTree(t *testing.T) { + _ = os.RemoveAll("trimmer_test") + createDirTree("test_cases", "trimmer_test") + fileCount, dirCount, err := countFilesAndSubdirectories("trimmer_test") + test.Assert(t, err == nil) + test.Assert(t, fileCount == 0) + test.Assert(t, dirCount == 3) + removeEmptyDir("trimmer_test") + _, err = os.ReadDir("trimmer_test") + test.Assert(t, err != nil) +} + +func countFilesAndSubdirectories(dirPath string) (int, int, error) { + var fileCount, dirCount int + files, err := os.ReadDir(dirPath) + if err != nil { + return 0, 0, err + } + for _, file := range files { + if file.IsDir() { + dirCount++ + subDirPath := filepath.Join(dirPath, file.Name()) + subFileCount, subDirCount, err := countFilesAndSubdirectories(subDirPath) + if err != nil { + return 0, 0, err + } + fileCount += subFileCount + dirCount += subDirCount + } else { + fileCount++ + } + } + return fileCount, dirCount, nil +} diff --git a/tool/trimmer/main.go b/tool/trimmer/main.go index 63c65741..b955fda0 100644 --- a/tool/trimmer/main.go +++ b/tool/trimmer/main.go @@ -112,6 +112,11 @@ func main() { os.Exit(2) } } + relPath, err := filepath.Rel(a.Recurse, a.OutputFile) + if err == nil && (len(relPath) < 2 || relPath[:2] != "..") { + println("output-dir should be set outside of -r base-dir to avoid overlay") + os.Exit(2) + } createDirTree(a.Recurse, a.OutputFile) recurseDump(ast, a.Recurse, a.OutputFile) relativePath, err := filepath.Rel(a.Recurse, a.IDL) diff --git a/tool/trimmer/test_cases/tests/dir/dir2/test.thrift b/tool/trimmer/test_cases/tests/dir/dir2/test.thrift new file mode 100644 index 00000000..a7359c51 --- /dev/null +++ b/tool/trimmer/test_cases/tests/dir/dir2/test.thrift @@ -0,0 +1,20 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +include "../../../sample1.thrift" + +// @preserve +struct TestStruct{ + 1: sample1.Person person +} \ No newline at end of file diff --git a/tool/trimmer/trim/trimmer_test.go b/tool/trimmer/trim/trimmer_test.go index 007dc4d5..d7f6187f 100644 --- a/tool/trimmer/trim/trimmer_test.go +++ b/tool/trimmer/trim/trimmer_test.go @@ -25,12 +25,12 @@ import ( ) func TestTrimmer(t *testing.T) { - t.Run("trim AST-case 1", testCase1) + t.Run("trim AST", testSingleFile) // t.Run("trim AST - test many", testMany) } // test single file ast trimming -func testCase1(t *testing.T) { +func testSingleFile(t *testing.T) { trimmer, err := newTrimmer(nil, "") test.Assert(t, err == nil, err) filename := filepath.Join("..", "test_cases", "sample1.thrift") From 0a5410212ed1cfb928d4ee7a8ebc5bc3147e6ea6 Mon Sep 17 00:00:00 2001 From: Li2CO3 <45219850+HeyJavaBean@users.noreply.github.com> Date: Tue, 26 Sep 2023 19:12:38 +0800 Subject: [PATCH 04/19] test: add codegen test (#129) --- .github/workflows/push-check.yml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/.github/workflows/push-check.yml b/.github/workflows/push-check.yml index 9033534b..b34c5e61 100644 --- a/.github/workflows/push-check.yml +++ b/.github/workflows/push-check.yml @@ -42,3 +42,31 @@ jobs: echo "test done!" go vet -stdmethods=false $(go list ./...) echo "go vet done!" + + codegen-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: '1.17' + - name: Prepare + run: | + go install + go install github.com/cloudwego/kitex/tool/cmd/kitex@develop + LOCAL_REPO=$(pwd) + cd .. + git clone https://github.com/cloudwego/kitex-tests.git + cd kitex-tests/codegen + go mod init codegen-test + go mod edit -replace=github.com/apache/thrift=github.com/apache/thrift@v0.13.0 + go mod edit -replace github.com/cloudwego/thriftgo=${LOCAL_REPO} + go mod tidy + bash -version + bash ./codegen_install_check.sh + - name: CodeGen + run: | + cd ../kitex-tests/codegen + tree + bash ./codegen_run.sh From 4dc2032952dfc9cc57cf970d742533f02adeb2a7 Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Sun, 8 Oct 2023 11:33:03 +0800 Subject: [PATCH 05/19] optimize: support enum tail comments (#130) --- parser/parser.go | 15 +++++++++++++-- parser/parser_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ parser/thrift.peg | 2 +- parser/thrift.peg.go | 5 ++++- 4 files changed, 61 insertions(+), 4 deletions(-) diff --git a/parser/parser.go b/parser/parser.go index 2ba21f58..00bf1317 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -510,7 +510,7 @@ func (p *parser) parseEnum(node *node32) (err error) { if err != nil { return err } - // ENUM Identifier LWING ( ReservedComments Identifier (EQUAL IntConstant)? Annotations? ListSeparator?)* RWING + // ENUM Identifier LWING ( ReservedComments Identifier (EQUAL IntConstant)? Annotations? ListSeparator? ReservedEndLineComments SkipLine)* RWING node = node.next // ignore ENUM name := p.pegText(node) var values []*EnumValue @@ -539,7 +539,18 @@ func (p *parser) parseEnum(node *node32) (err error) { } if n.next.pegRule == ruleAnnotations { - v.Annotations, err = p.parseAnnotations(n.next) + n = n.next + v.Annotations, err = p.parseAnnotations(n) + if err != nil { + return err + } + } + if n.next.pegRule == ruleListSeparator { + n = n.next + } + if n.next.pegRule == ruleReservedEndLineComments && v.ReservedComments == "" { + n = n.next + v.ReservedComments, err = p.parseReservedEndLineComments(n) if err != nil { return err } diff --git a/parser/parser_test.go b/parser/parser_test.go index 17393eba..eb882f51 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -206,6 +206,24 @@ struct my_struct { and for another line */ } + +enum my_enum { + // header comment for 1 + e1 + // header comment for 2 + e2 // tail-reserved comment for 2 + e3 // tail-reserved comment for 3 + e4, // tail-reserved comment for 4 + e5(at="annotation") // tail-reserved comment for 5 + // header comment for 6 + e6 + // header comment for 7 + e7 // tail-reserved comment for 7 + e8 # tail-reserved comment for 8 + e9 /* tail-reserved comment for 9 + and for another line + */ +} ` func TestFieldReservedEndLineComment(t *testing.T) { @@ -237,6 +255,31 @@ func TestFieldReservedEndLineComment(t *testing.T) { */`) } } + + for _, f := range ast.Enums[0].Values { + switch f.Name { + case "e1": + test.Assert(t, f.ReservedComments == `// header comment for 1`) + case "e2": + test.Assert(t, f.ReservedComments == `// header comment for 2`) + case "e3": + test.Assert(t, f.ReservedComments == `// tail-reserved comment for 3`) + case "e4": + test.Assert(t, f.ReservedComments == `// tail-reserved comment for 4`) + case "e5": + test.Assert(t, f.ReservedComments == `// tail-reserved comment for 5`) + case "e6": + test.Assert(t, f.ReservedComments == `// header comment for 6`) + case "e7": + test.Assert(t, f.ReservedComments == `// header comment for 7`) + case "e8": + test.Assert(t, f.ReservedComments == `// tail-reserved comment for 8`) + case "e9": + test.Assert(t, f.ReservedComments == `/* tail-reserved comment for 9 + and for another line + */`) + } + } } const testSpaceSkip = ` diff --git a/parser/thrift.peg b/parser/thrift.peg index 45963d9f..d8c685ee 100644 --- a/parser/thrift.peg +++ b/parser/thrift.peg @@ -25,7 +25,7 @@ Const <- CONST FieldType Identifier EQUAL ConstValue ListSeparator? Typedef <- TYPEDEF FieldType Identifier -Enum <- ENUM Identifier LWING (ReservedComments Identifier (EQUAL IntConstant)? Annotations? ListSeparator? SkipLine)* RWING +Enum <- ENUM Identifier LWING (ReservedComments Identifier (EQUAL IntConstant)? Annotations? ListSeparator? ReservedEndLineComments SkipLine)* RWING Service <- SERVICE Identifier ( EXTENDS Identifier )? LWING Function* RWING diff --git a/parser/thrift.peg.go b/parser/thrift.peg.go index dd18089b..cc943377 100644 --- a/parser/thrift.peg.go +++ b/parser/thrift.peg.go @@ -824,7 +824,7 @@ func (p *ThriftIDL) Init(options ...func(*ThriftIDL) error) error { position, tokenIndex = position42, tokenIndex42 return false }, - /* 9 Enum <- <(ENUM Identifier LWING (ReservedComments Identifier (EQUAL IntConstant)? Annotations? ListSeparator? SkipLine)* RWING)> */ + /* 9 Enum <- <(ENUM Identifier LWING (ReservedComments Identifier (EQUAL IntConstant)? Annotations? ListSeparator? ReservedEndLineComments SkipLine)* RWING)> */ func() bool { position44, tokenIndex44 := position, tokenIndex { @@ -880,6 +880,9 @@ func (p *ThriftIDL) Init(options ...func(*ThriftIDL) error) error { position, tokenIndex = position52, tokenIndex52 } l53: + if !_rules[ruleReservedEndLineComments]() { + goto l47 + } if !_rules[ruleSkipLine]() { goto l47 } From 10dfa27e220c07c2cefd0e9c981508ec4b3d83e2 Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Tue, 10 Oct 2023 14:25:56 +0800 Subject: [PATCH 06/19] fix: fix no throw generated when function resp is void (#132) --- generator/golang/templates/client.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/generator/golang/templates/client.go b/generator/golang/templates/client.go index 92efed00..3170c219 100644 --- a/generator/golang/templates/client.go +++ b/generator/golang/templates/client.go @@ -86,6 +86,15 @@ func (p *{{$ClientName}}) {{- template "FunctionSignature" . -}} { if err = p.Client_().Call(ctx, "{{.Name}}", &_args, &_result); err != nil { return } + {{- if .Throws}} + switch { + {{- range .Throws}} + case _result.{{($ResType.Field .Name).GoName}} != nil: + return _result.{{($ResType.Field .Name).GoName}} + {{- end}} + } + {{- end}} + {{- end}} return nil {{- else}}{{/* If .Void */}} From 7b4aab3d34494078b1e4284e430beb493f11e424 Mon Sep 17 00:00:00 2001 From: tk_sky <63036400+tksky1@users.noreply.github.com> Date: Tue, 17 Oct 2023 17:33:02 +0800 Subject: [PATCH 07/19] fix: fix typedef with reference get-type-fail issue for trimmer tool (#133) --- tool/trimmer/test_cases/sample1.thrift | 4 +- tool/trimmer/test_cases/sample1b.thrift | 4 ++ tool/trimmer/trim/mark.go | 68 ++++++++++++++----------- tool/trimmer/trim/trimmer_test.go | 2 +- 4 files changed, 46 insertions(+), 32 deletions(-) diff --git a/tool/trimmer/test_cases/sample1.thrift b/tool/trimmer/test_cases/sample1.thrift index a80ee282..20ab422f 100644 --- a/tool/trimmer/test_cases/sample1.thrift +++ b/tool/trimmer/test_cases/sample1.thrift @@ -44,6 +44,8 @@ test7 typedef Gender(key="v") MyGender (key = "1", key = "2", key2 = "v2") typedef MyGender MyAnotherGender +typedef sample1b.AnotherException samlpe1bAnotherException +typedef sample1b.NotDirectInclude notDirectInclude typedef i32 a // out enum"ZZ" @@ -113,7 +115,7 @@ service ProjectService { service CompanyService { Company getCompany(1: string id) - void addCompany(1: Company company) throws(1: sample1b.AnotherException exc) + void addCompany(1: Company company) throws(1: samlpe1bAnotherException exc) void updateCompany(1: string id, 2: Company company) list getDepartments(1: string company_id) void anotherUselessMethod(1: MaybeUseless useless) diff --git a/tool/trimmer/test_cases/sample1b.thrift b/tool/trimmer/test_cases/sample1b.thrift index c032ff18..4678cffa 100644 --- a/tool/trimmer/test_cases/sample1b.thrift +++ b/tool/trimmer/test_cases/sample1b.thrift @@ -47,5 +47,9 @@ exception AnotherException{ 1: i32 abc } +exception NotDirectInclude{ + +} + const i32 DEFAULT_CODE = 3000; const string trash_string = "trash!" diff --git a/tool/trimmer/trim/mark.go b/tool/trimmer/trim/mark.go index 5c7bc48e..a660850f 100644 --- a/tool/trimmer/trim/mark.go +++ b/tool/trimmer/trim/mark.go @@ -27,33 +27,7 @@ func (t *Trimmer) markAST(ast *parser.Thrift) { t.markService(service, ast, ast.Filename) } - for _, constant := range ast.Constants { - t.markType(constant.Type, ast, ast.Filename) - } - - for _, typedef := range ast.Typedefs { - t.markTypeDef(typedef.Type, ast, ast.Filename) - } - - if !t.forceTrimming { - for _, str := range ast.Structs { - if !t.marks[ast.Filename][str] && t.checkPreserve(str) { - t.markStructLike(str, ast, ast.Filename) - } - } - - for _, str := range ast.Unions { - if !t.marks[ast.Filename][str] && t.checkPreserve(str) { - t.markStructLike(str, ast, ast.Filename) - } - } - - for _, str := range ast.Exceptions { - if !t.marks[ast.Filename][str] && t.checkPreserve(str) { - t.markStructLike(str, ast, ast.Filename) - } - } - } + t.markKeptParts(ast, ast.Filename) } func (t *Trimmer) markService(svc *parser.Service, ast *parser.Thrift, filename string) { @@ -89,6 +63,7 @@ func (t *Trimmer) markService(svc *parser.Service, ast *parser.Thrift, filename if svc.Reference != nil { theInclude := ast.Includes[svc.Reference.Index] t.marks[filename][theInclude] = true + t.markKeptParts(theInclude.Reference, filename) for _, service := range theInclude.Reference.Services { if service.Name == svc.Reference.Name { t.markService(service, theInclude.Reference, filename) @@ -129,6 +104,7 @@ func (t *Trimmer) markType(theType *parser.Type, ast *parser.Thrift, filename st // if referenced, redirect to included ast baseAST = ast.Includes[theType.Reference.Index].Reference t.marks[filename][ast.Includes[theType.Reference.Index]] = true + t.markKeptParts(ast.Includes[theType.Reference.Index].Reference, filename) } if theType.IsTypedef != nil { @@ -186,10 +162,10 @@ func (t *Trimmer) markTypeDef(theType *parser.Type, ast *parser.Thrift, filename return } - for _, typedef := range ast.Typedefs { + for i, typedef := range ast.Typedefs { if typedef.Alias == theType.Name { - if !t.marks[filename][typedef] { - t.marks[filename][typedef] = true + if !t.marks[filename][ast.Typedefs[i]] { + t.marks[filename][ast.Typedefs[i]] = true t.markType(typedef.Type, ast, filename) } return @@ -197,6 +173,37 @@ func (t *Trimmer) markTypeDef(theType *parser.Type, ast *parser.Thrift, filename } } +func (t *Trimmer) markKeptParts(ast *parser.Thrift, filename string) { + for _, constant := range ast.Constants { + t.markType(constant.Type, ast, filename) + } + + for _, typedef := range ast.Typedefs { + t.marks[filename][typedef] = true + t.markType(typedef.Type, ast, filename) + } + + if !t.forceTrimming { + for _, str := range ast.Structs { + if !t.marks[filename][str] && t.checkPreserve(str) { + t.markStructLike(str, ast, filename) + } + } + + for _, str := range ast.Unions { + if !t.marks[filename][str] && t.checkPreserve(str) { + t.markStructLike(str, ast, filename) + } + } + + for _, str := range ast.Exceptions { + if !t.marks[filename][str] && t.checkPreserve(str) { + t.markStructLike(str, ast, filename) + } + } + } +} + // for -m, trace the extends and find specified method to base on func (t *Trimmer) traceExtendMethod(father, svc *parser.Service, ast *parser.Thrift, filename string) (ret bool) { for _, function := range svc.Functions { @@ -237,6 +244,7 @@ func (t *Trimmer) traceExtendMethod(father, svc *parser.Service, ast *parser.Thr t.marks[filename][svc] = true if svc.Reference != nil { t.marks[filename][ast.Includes[svc.Reference.Index]] = true + t.markKeptParts(ast.Includes[svc.Reference.Index].Reference, filename) } } return ret diff --git a/tool/trimmer/trim/trimmer_test.go b/tool/trimmer/trim/trimmer_test.go index d7f6187f..53e7b2fb 100644 --- a/tool/trimmer/trim/trimmer_test.go +++ b/tool/trimmer/trim/trimmer_test.go @@ -49,7 +49,7 @@ func testSingleFile(t *testing.T) { test.Assert(t, len(ast.Structs) == 6) test.Assert(t, len(ast.Includes) == 1) - test.Assert(t, len(ast.Typedefs) == 3) + test.Assert(t, len(ast.Typedefs) == 5) test.Assert(t, len(ast.Namespaces) == 1) test.Assert(t, len(ast.Includes[0].Reference.Structs) == 2) test.Assert(t, len(ast.Includes[0].Reference.Constants) == 2) From 2864cce5508b944f261d2c2786980e96609f28f3 Mon Sep 17 00:00:00 2001 From: tk_sky <63036400+tksky1@users.noreply.github.com> Date: Wed, 18 Oct 2023 16:33:42 +0800 Subject: [PATCH 08/19] fix: hot-fix #133 trimmer typedef reference (#134) --- tool/trimmer/trim/mark.go | 24 +++++++++++++++--------- tool/trimmer/trim/pre-process.go | 24 ++++++++++++++++++++++++ tool/trimmer/trim/traversal.go | 1 + 3 files changed, 40 insertions(+), 9 deletions(-) create mode 100644 tool/trimmer/trim/pre-process.go diff --git a/tool/trimmer/trim/mark.go b/tool/trimmer/trim/mark.go index a660850f..55adbb1a 100644 --- a/tool/trimmer/trim/mark.go +++ b/tool/trimmer/trim/mark.go @@ -23,11 +23,12 @@ import ( // mark the used part of ast func (t *Trimmer) markAST(ast *parser.Thrift) { t.marks[ast.Filename] = make(map[interface{}]bool) + t.preProcess(ast, ast.Filename) for _, service := range ast.Services { t.markService(service, ast, ast.Filename) } - t.markKeptParts(ast, ast.Filename) + t.markKeptPart(ast, ast.Filename) } func (t *Trimmer) markService(svc *parser.Service, ast *parser.Thrift, filename string) { @@ -62,8 +63,7 @@ func (t *Trimmer) markService(svc *parser.Service, ast *parser.Thrift, filename // handle extension if svc.Reference != nil { theInclude := ast.Includes[svc.Reference.Index] - t.marks[filename][theInclude] = true - t.markKeptParts(theInclude.Reference, filename) + t.markInclude(ast.Includes[svc.Reference.Index], filename) for _, service := range theInclude.Reference.Services { if service.Name == svc.Reference.Name { t.markService(service, theInclude.Reference, filename) @@ -103,8 +103,7 @@ func (t *Trimmer) markType(theType *parser.Type, ast *parser.Thrift, filename st if theType.Reference != nil { // if referenced, redirect to included ast baseAST = ast.Includes[theType.Reference.Index].Reference - t.marks[filename][ast.Includes[theType.Reference.Index]] = true - t.markKeptParts(ast.Includes[theType.Reference.Index].Reference, filename) + t.markInclude(ast.Includes[theType.Reference.Index], filename) } if theType.IsTypedef != nil { @@ -173,13 +172,21 @@ func (t *Trimmer) markTypeDef(theType *parser.Type, ast *parser.Thrift, filename } } -func (t *Trimmer) markKeptParts(ast *parser.Thrift, filename string) { +func (t *Trimmer) markInclude(include *parser.Include, filename string) { + include.Reference.Name2Category = nil + if t.marks[filename][include] { + return + } + t.marks[filename][include] = true + // t.markKeptPart(include.Reference, filename) +} + +func (t *Trimmer) markKeptPart(ast *parser.Thrift, filename string) { for _, constant := range ast.Constants { t.markType(constant.Type, ast, filename) } for _, typedef := range ast.Typedefs { - t.marks[filename][typedef] = true t.markType(typedef.Type, ast, filename) } @@ -243,8 +250,7 @@ func (t *Trimmer) traceExtendMethod(father, svc *parser.Service, ast *parser.Thr if ret { t.marks[filename][svc] = true if svc.Reference != nil { - t.marks[filename][ast.Includes[svc.Reference.Index]] = true - t.markKeptParts(ast.Includes[svc.Reference.Index].Reference, filename) + t.markInclude(ast.Includes[svc.Reference.Index], filename) } } return ret diff --git a/tool/trimmer/trim/pre-process.go b/tool/trimmer/trim/pre-process.go new file mode 100644 index 00000000..0a9ceab1 --- /dev/null +++ b/tool/trimmer/trim/pre-process.go @@ -0,0 +1,24 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trim + +import "github.com/cloudwego/thriftgo/parser" + +func (t *Trimmer) preProcess(ast *parser.Thrift, filename string) { + t.markKeptPart(ast, filename) + for _, include := range ast.Includes { + t.preProcess(include.Reference, filename) + } +} diff --git a/tool/trimmer/trim/traversal.go b/tool/trimmer/trim/traversal.go index e945e8f7..dcd0518b 100644 --- a/tool/trimmer/trim/traversal.go +++ b/tool/trimmer/trim/traversal.go @@ -70,4 +70,5 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { } } ast.Services = listService + ast.Name2Category = nil } From 6a75aa143e83baceb835095654e3723effac892c Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Wed, 18 Oct 2023 20:04:58 +0800 Subject: [PATCH 09/19] fix: fix include semantic check for trimmer (#135) --- tool/trimmer/dirTree_test.go | 2 +- .../test_cases/test_include/example.thrift | 24 +++++++++++++++ .../test_cases/test_include/test.thrift | 26 +++++++++++++++++ tool/trimmer/trim/traversal.go | 3 ++ tool/trimmer/trim/trimmer.go | 1 - tool/trimmer/trim/trimmer_test.go | 29 +++++++++++++++++++ 6 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 tool/trimmer/test_cases/test_include/example.thrift create mode 100644 tool/trimmer/test_cases/test_include/test.thrift diff --git a/tool/trimmer/dirTree_test.go b/tool/trimmer/dirTree_test.go index d856a7a4..7ae23697 100644 --- a/tool/trimmer/dirTree_test.go +++ b/tool/trimmer/dirTree_test.go @@ -28,7 +28,7 @@ func TestDirTree(t *testing.T) { fileCount, dirCount, err := countFilesAndSubdirectories("trimmer_test") test.Assert(t, err == nil) test.Assert(t, fileCount == 0) - test.Assert(t, dirCount == 3) + test.Assert(t, dirCount == 4) removeEmptyDir("trimmer_test") _, err = os.ReadDir("trimmer_test") test.Assert(t, err != nil) diff --git a/tool/trimmer/test_cases/test_include/example.thrift b/tool/trimmer/test_cases/test_include/example.thrift new file mode 100644 index 00000000..e10d837f --- /dev/null +++ b/tool/trimmer/test_cases/test_include/example.thrift @@ -0,0 +1,24 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +namespace go a + +include "test.thrift" + +struct E{ + 1: required test.TestStruct a +} + + diff --git a/tool/trimmer/test_cases/test_include/test.thrift b/tool/trimmer/test_cases/test_include/test.thrift new file mode 100644 index 00000000..6013bbe0 --- /dev/null +++ b/tool/trimmer/test_cases/test_include/test.thrift @@ -0,0 +1,26 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace go c + +struct TestStruct{ + // hello + 1:required string hello + 2:required string id +} + +enum GenderEnum{ + MALE + FEMALE +} diff --git a/tool/trimmer/trim/traversal.go b/tool/trimmer/trim/traversal.go index dcd0518b..ddba0cf8 100644 --- a/tool/trimmer/trim/traversal.go +++ b/tool/trimmer/trim/traversal.go @@ -71,4 +71,7 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { } ast.Services = listService ast.Name2Category = nil + for _, inc := range ast.Includes { + inc.Used = nil + } } diff --git a/tool/trimmer/trim/trimmer.go b/tool/trimmer/trim/trimmer.go index bd4ee050..54087466 100644 --- a/tool/trimmer/trim/trimmer.go +++ b/tool/trimmer/trim/trimmer.go @@ -61,7 +61,6 @@ func TrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool) error } trimmer.markAST(ast) trimmer.traversal(ast, ast.Filename) - ast.Name2Category = nil if path := parser.CircleDetect(ast); len(path) > 0 { check(fmt.Errorf("found include circle:\n\t%s", path)) } diff --git a/tool/trimmer/trim/trimmer_test.go b/tool/trimmer/trim/trimmer_test.go index 53e7b2fb..e6a17ea6 100644 --- a/tool/trimmer/trim/trimmer_test.go +++ b/tool/trimmer/trim/trimmer_test.go @@ -56,3 +56,32 @@ func testSingleFile(t *testing.T) { test.Assert(t, len(ast.Includes[0].Reference.Services) == 1) test.Assert(t, len(ast.Includes[0].Reference.Namespaces) == 1) } + +func TestInclude(t *testing.T) { + trimmer, err := newTrimmer(nil, "") + test.Assert(t, err == nil, err) + filename := filepath.Join("..", "test_cases/test_include", "example.thrift") + ast, err := parser.ParseFile(filename, []string{"test_cases/test_include"}, true) + check(err) + if path := parser.CircleDetect(ast); len(path) > 0 { + check(fmt.Errorf("found include circle:\n\t%s", path)) + } + checker := semantic.NewChecker(semantic.Options{FixWarnings: true}) + _, err = checker.CheckAll(ast) + check(err) + check(semantic.ResolveSymbols(ast)) + trimmer.asts[filename] = ast + trimmer.markAST(ast) + trimmer.traversal(ast, ast.Filename) + if path := parser.CircleDetect(ast); len(path) > 0 { + check(fmt.Errorf("found include circle:\n\t%s", path)) + } + checker = semantic.NewChecker(semantic.Options{FixWarnings: true}) + _, err = checker.CheckAll(ast) + check(err) + check(semantic.ResolveSymbols(ast)) + + test.Assert(t, len(ast.Structs) == 0) + test.Assert(t, len(ast.Includes) == 1) + test.Assert(t, ast.Includes[0].Used == nil) +} From 4f2620a5ffe553ef720030f462a9de5b0af51c00 Mon Sep 17 00:00:00 2001 From: tk_sky <63036400+tksky1@users.noreply.github.com> Date: Fri, 20 Oct 2023 13:21:06 +0800 Subject: [PATCH 10/19] chore: expand @preserve regexp for trimmer (#136) --- tool/trimmer/test_cases/sample1.thrift | 7 +++++++ tool/trimmer/trim/trimmer.go | 2 +- tool/trimmer/trim/trimmer_test.go | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tool/trimmer/test_cases/sample1.thrift b/tool/trimmer/test_cases/sample1.thrift index 20ab422f..8c790839 100644 --- a/tool/trimmer/test_cases/sample1.thrift +++ b/tool/trimmer/test_cases/sample1.thrift @@ -100,6 +100,13 @@ struct Simple { // should not appear struct MaybeUseless{ } +// some comments +# @preServe +// some others +struct preserved{ + +} + service EmployeeService extends sample1b.GetPerson { Employee getEmployee(1: string id) void addEmployee(1: Employee employee) diff --git a/tool/trimmer/trim/trimmer.go b/tool/trimmer/trim/trimmer.go index 54087466..39a143f2 100644 --- a/tool/trimmer/trim/trimmer.go +++ b/tool/trimmer/trim/trimmer.go @@ -113,7 +113,7 @@ func newTrimmer(files []string, outDir string) (*Trimmer, error) { } trimmer.asts = make(map[string]*parser.Thrift) trimmer.marks = make(map[string]map[interface{}]bool) - pattern := `^[\s]*(\/\/|#)[\s]*@preserve[\s]*$` + pattern := `(?m)^[\s]*(\/\/|#)[\s]*@preserve[\s]*$` trimmer.preserveRegex = regexp.MustCompile(pattern) return trimmer, nil } diff --git a/tool/trimmer/trim/trimmer_test.go b/tool/trimmer/trim/trimmer_test.go index e6a17ea6..5687a2d1 100644 --- a/tool/trimmer/trim/trimmer_test.go +++ b/tool/trimmer/trim/trimmer_test.go @@ -47,7 +47,7 @@ func testSingleFile(t *testing.T) { trimmer.markAST(ast) trimmer.traversal(ast, ast.Filename) - test.Assert(t, len(ast.Structs) == 6) + test.Assert(t, len(ast.Structs) == 7) test.Assert(t, len(ast.Includes) == 1) test.Assert(t, len(ast.Typedefs) == 5) test.Assert(t, len(ast.Namespaces) == 1) From 7676d6208e0f9863229e1e549369d8caa2b2136c Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Tue, 24 Oct 2023 15:31:16 +0800 Subject: [PATCH 11/19] feat: generate enum marshal and unmarshal separately (#137) --- generator/golang/option.go | 6 +++++- generator/golang/templates/enum.go | 8 ++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/generator/golang/option.go b/generator/golang/option.go index 0e42aa6b..2e32e801 100644 --- a/generator/golang/option.go +++ b/generator/golang/option.go @@ -24,7 +24,9 @@ import ( // Features controls the behavior of CodeUtils. type Features struct { - MarshalEnumToText bool `json_enum_as_text:"Generate MarshalText for enum values"` + MarshalEnumToText bool `json_enum_as_text:"Generate MarshalText and UnmarshalText for enum values"` + MarshalEnum bool `enum_marshal:"Generate MarshalText for enum values"` + UnmarshalEnum bool `enum_unmarshal:"Generate UnmarshalText for enum values"` GenerateSetter bool `gen_setter:"Generate Set* methods for fields"` GenDatabaseTag bool `gen_db_tag:"Generate 'db:$field' tag"` GenOmitEmptyTag bool `omitempty_for_optional:"Generate 'omitempty' tags for optional fields."` @@ -56,6 +58,8 @@ type Features struct { var defaultFeatures = Features{ MarshalEnumToText: false, + MarshalEnum: false, + UnmarshalEnum: false, GenerateSetter: false, GenDatabaseTag: false, GenOmitEmptyTag: true, diff --git a/generator/golang/templates/enum.go b/generator/golang/templates/enum.go index 895f4a5d..37b7047e 100644 --- a/generator/golang/templates/enum.go +++ b/generator/golang/templates/enum.go @@ -53,12 +53,16 @@ func {{$EnumType}}FromString(s string) ({{$EnumType}}, error) { func {{$EnumType}}Ptr(v {{$EnumType}} ) *{{$EnumType}} { return &v } -{{- if Features.MarshalEnumToText}} +{{- if or Features.MarshalEnumToText Features.MarshalEnum}} func (p {{$EnumType}}) MarshalText() ([]byte, error) { return []byte(p.String()), nil } +{{end}}{{/* if or Features.MarshalEnumToText Features.MarshalEnum */}} + +{{- if or Features.MarshalEnumToText Features.UnmarshalEnum}} + func (p *{{$EnumType}}) UnmarshalText(text []byte) error { q, err := {{$EnumType}}FromString(string(text)) if err != nil { @@ -67,7 +71,7 @@ func (p *{{$EnumType}}) UnmarshalText(text []byte) error { *p = q return nil } -{{- end}}{{/* if Features.MarshalEnumToText */}} +{{end}}{{/* if or Features.MarshalEnumToText Features.UnmarshalEnum */}} {{- if Features.ScanValueForEnum}} {{- UseStdLibrary "sql" "driver"}} From 5325945c72ae7516a47a5f3c5a6631a87cedec0b Mon Sep 17 00:00:00 2001 From: tk_sky <63036400+tksky1@users.noreply.github.com> Date: Tue, 24 Oct 2023 19:27:16 +0800 Subject: [PATCH 12/19] feat: add yaml-config support for trimmer (#131) Co-authored-by: Li2CO3 --- generator/golang/backend.go | 2 +- tool/trimmer/args.go | 4 +- tool/trimmer/dirTree.go | 81 ------------------- tool/trimmer/dirTree_test.go | 58 ------------- tool/trimmer/main.go | 41 +++++----- .../test_cases/tests/dir/dir2/test.thrift | 11 +++ .../tests/dir/dir2/trim_config.yaml | 19 +++++ .../tests/dir/dir3/dir4/another.thrift | 16 ++++ tool/trimmer/trim/config.go | 50 ++++++++++++ tool/trimmer/trim/mark.go | 5 ++ tool/trimmer/trim/traversal.go | 4 +- tool/trimmer/trim/trimmer.go | 43 ++++++++-- tool/trimmer/trim/trimmer_test.go | 47 +++++++++++ 13 files changed, 210 insertions(+), 171 deletions(-) delete mode 100644 tool/trimmer/dirTree.go delete mode 100644 tool/trimmer/dirTree_test.go create mode 100644 tool/trimmer/test_cases/tests/dir/dir2/trim_config.yaml create mode 100644 tool/trimmer/test_cases/tests/dir/dir3/dir4/another.thrift create mode 100644 tool/trimmer/trim/config.go diff --git a/generator/golang/backend.go b/generator/golang/backend.go index d96e843d..e653c1db 100644 --- a/generator/golang/backend.go +++ b/generator/golang/backend.go @@ -86,7 +86,7 @@ func (g *GoBackend) Generate(req *plugin.Request, log backend.LogFunc) *plugin.R g.log = log g.prepareUtilities() if g.utils.Features().TrimIDL { - err := trim.TrimAST(req.AST, nil, false) + err := trim.TrimAST(&trim.TrimASTArg{Ast: req.AST, TrimMethods: nil, Preserve: nil}) if err != nil { g.log.Warn("trim error:", err.Error()) } diff --git a/tool/trimmer/args.go b/tool/trimmer/args.go index eee0f637..ca3cea05 100644 --- a/tool/trimmer/args.go +++ b/tool/trimmer/args.go @@ -63,8 +63,8 @@ func (a *Arguments) BuildFlags() *flag.FlagSet { f.Var(&a.Methods, "m", "") f.Var(&a.Methods, "method", "") - f.StringVar(&a.Preserve, "p", "true", "") - f.StringVar(&a.Preserve, "preserve", "true", "") + f.StringVar(&a.Preserve, "p", "", "") + f.StringVar(&a.Preserve, "preserve", "", "") f.Usage = help return f diff --git a/tool/trimmer/dirTree.go b/tool/trimmer/dirTree.go deleted file mode 100644 index c7c39683..00000000 --- a/tool/trimmer/dirTree.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2023 CloudWeGo Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "errors" - "fmt" - "os" - "path/filepath" -) - -// create directory-tree before dump -func createDirTree(sourceDir, destinationDir string) { - err := filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - if info.IsDir() { - if path[len(sourceDir)-1] != filepath.Separator { - path = path + string(filepath.Separator) - } - newDir := filepath.Join(destinationDir, path[len(sourceDir):]) - err := os.MkdirAll(newDir, os.ModePerm) - if err != nil { - return errors.New("create dir tree error:" + err.Error()) - } - } - return nil - }) - if err != nil { - fmt.Printf("manage output error: %v\n", err) - os.Exit(2) - } -} - -// remove empty directory of output dir-tree -func removeEmptyDir(source string) { - files, err := os.ReadDir(source) - if err != nil { - return - } - for _, file := range files { - if file.IsDir() { - removeEmptyDir(source + string(filepath.Separator) + file.Name()) - } - } - empty, err := isDirectoryEmpty(source) - if empty || err != nil { - _ = os.RemoveAll(source) - } -} - -func isDirectoryEmpty(path string) (bool, error) { - dir, err := os.Open(path) - if err != nil { - return false, err - } - defer dir.Close() - - _, err = dir.Readdirnames(1) - if err == nil { - return false, nil - } - - if errors.Is(err, os.ErrNotExist) { - return true, nil - } - return false, err -} diff --git a/tool/trimmer/dirTree_test.go b/tool/trimmer/dirTree_test.go deleted file mode 100644 index 7ae23697..00000000 --- a/tool/trimmer/dirTree_test.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2023 CloudWeGo Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "os" - "path/filepath" - "testing" - - "github.com/cloudwego/thriftgo/pkg/test" -) - -func TestDirTree(t *testing.T) { - _ = os.RemoveAll("trimmer_test") - createDirTree("test_cases", "trimmer_test") - fileCount, dirCount, err := countFilesAndSubdirectories("trimmer_test") - test.Assert(t, err == nil) - test.Assert(t, fileCount == 0) - test.Assert(t, dirCount == 4) - removeEmptyDir("trimmer_test") - _, err = os.ReadDir("trimmer_test") - test.Assert(t, err != nil) -} - -func countFilesAndSubdirectories(dirPath string) (int, int, error) { - var fileCount, dirCount int - files, err := os.ReadDir(dirPath) - if err != nil { - return 0, 0, err - } - for _, file := range files { - if file.IsDir() { - dirCount++ - subDirPath := filepath.Join(dirPath, file.Name()) - subFileCount, subDirCount, err := countFilesAndSubdirectories(subDirPath) - if err != nil { - return 0, 0, err - } - fileCount += subFileCount - dirCount += subDirCount - } else { - fileCount++ - } - } - return fileCount, dirCount, nil -} diff --git a/tool/trimmer/main.go b/tool/trimmer/main.go index b955fda0..fa97afac 100644 --- a/tool/trimmer/main.go +++ b/tool/trimmer/main.go @@ -53,13 +53,14 @@ func main() { os.Exit(0) } - preserve := true + var preserveInput *bool if a.Preserve != "" { - preserve, err = strconv.ParseBool(a.Preserve) + preserve, err := strconv.ParseBool(a.Preserve) if err != nil { help() os.Exit(2) } + preserveInput = &preserve } // parse file to ast @@ -74,7 +75,9 @@ func main() { check(semantic.ResolveSymbols(ast)) // trim ast - check(trim.TrimAST(ast, a.Methods, !preserve)) + check(trim.TrimAST(&trim.TrimASTArg{ + Ast: ast, TrimMethods: a.Methods, Preserve: preserveInput, + })) // dump the trimmed ast to idl idl, err := dump.DumpIDL(ast) @@ -117,16 +120,7 @@ func main() { println("output-dir should be set outside of -r base-dir to avoid overlay") os.Exit(2) } - createDirTree(a.Recurse, a.OutputFile) recurseDump(ast, a.Recurse, a.OutputFile) - relativePath, err := filepath.Rel(a.Recurse, a.IDL) - if err != nil { - println("-r input err, range should cover all the target IDLs;", err.Error()) - os.Exit(2) - } - outputFileUrl := filepath.Join(a.OutputFile, relativePath) - check(writeStringToFile(outputFileUrl, idl)) - removeEmptyDir(a.OutputFile) } else { check(writeStringToFile(a.OutputFile, idl)) } @@ -139,16 +133,21 @@ func recurseDump(ast *parser.Thrift, sourceDir, outDir string) { if ast == nil { return } + out, err := dump.DumpIDL(ast) + check(err) + relativeUrl, err := filepath.Rel(sourceDir, ast.Filename) + if err != nil { + println("-r input err, range should cover all the target IDLs;", err.Error()) + os.Exit(2) + } + outputFileUrl := filepath.Join(outDir, relativeUrl) + err = os.MkdirAll(filepath.Dir(outputFileUrl), os.ModePerm) + if err != nil { + println("mkdir", filepath.Dir(outputFileUrl), "error:", err.Error()) + os.Exit(2) + } + check(writeStringToFile(outputFileUrl, out)) for _, includes := range ast.Includes { - out, err := dump.DumpIDL(includes.Reference) - check(err) - relativeUrl, err := filepath.Rel(sourceDir, includes.Reference.Filename) - if err != nil { - println("-r input err, range should cover all the target IDLs;", err.Error()) - os.Exit(2) - } - outputFileUrl := filepath.Join(outDir, relativeUrl) - check(writeStringToFile(outputFileUrl, out)) recurseDump(includes.Reference, sourceDir, outDir) } } diff --git a/tool/trimmer/test_cases/tests/dir/dir2/test.thrift b/tool/trimmer/test_cases/tests/dir/dir2/test.thrift index a7359c51..804cb5d8 100644 --- a/tool/trimmer/test_cases/tests/dir/dir2/test.thrift +++ b/tool/trimmer/test_cases/tests/dir/dir2/test.thrift @@ -13,8 +13,19 @@ // limitations under the License. include "../../../sample1.thrift" +include "../dir3/dir4/another.thrift" // @preserve struct TestStruct{ 1: sample1.Person person + 2: another.AnotherStruct another +} + +service TestService{ + void func1() + another.AnotherStruct func2() + void func3() +} + +union useless{ } \ No newline at end of file diff --git a/tool/trimmer/test_cases/tests/dir/dir2/trim_config.yaml b/tool/trimmer/test_cases/tests/dir/dir2/trim_config.yaml new file mode 100644 index 00000000..f65338b1 --- /dev/null +++ b/tool/trimmer/test_cases/tests/dir/dir2/trim_config.yaml @@ -0,0 +1,19 @@ +# Copyright 2023 CloudWeGo Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +methods: + - "TestService.func1" + - "TestService.func3" +preserve: true +preserved_structs: + - "useless" \ No newline at end of file diff --git a/tool/trimmer/test_cases/tests/dir/dir3/dir4/another.thrift b/tool/trimmer/test_cases/tests/dir/dir3/dir4/another.thrift new file mode 100644 index 00000000..20aff287 --- /dev/null +++ b/tool/trimmer/test_cases/tests/dir/dir3/dir4/another.thrift @@ -0,0 +1,16 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +struct AnotherStruct{ +} \ No newline at end of file diff --git a/tool/trimmer/trim/config.go b/tool/trimmer/trim/config.go new file mode 100644 index 00000000..28786bab --- /dev/null +++ b/tool/trimmer/trim/config.go @@ -0,0 +1,50 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trim + +import ( + "fmt" + "os" + "path/filepath" + + "gopkg.in/yaml.v3" +) + +var DefaultYamlFileName = "trim_config.yaml" + +type YamlArguments struct { + Methods []string `yaml:"methods,omitempty"` + Preserve *bool `yaml:"preserve,omitempty"` + PreservedStructs []string `yaml:"preserved_structs,omitempty"` +} + +func ParseYamlConfig(path string) *YamlArguments { + cfg := YamlArguments{} + dataBytes, err := os.ReadFile(filepath.Join(path, DefaultYamlFileName)) + if err != nil { + return nil + } + fmt.Println("using trim config:", filepath.Join(path, DefaultYamlFileName)) + err = yaml.Unmarshal(dataBytes, &cfg) + if err != nil { + fmt.Println("unmarshal yaml config fail:", err) + return nil + } + if cfg.Preserve == nil { + t := true + cfg.Preserve = &t + } + return &cfg +} diff --git a/tool/trimmer/trim/mark.go b/tool/trimmer/trim/mark.go index 55adbb1a..8bd5e903 100644 --- a/tool/trimmer/trim/mark.go +++ b/tool/trimmer/trim/mark.go @@ -261,5 +261,10 @@ func (t *Trimmer) checkPreserve(theStruct *parser.StructLike) bool { if t.forceTrimming { return false } + for _, name := range t.preservedStructs { + if name == theStruct.Name { + return true + } + } return t.preserveRegex.MatchString(strings.ToLower(theStruct.ReservedComments)) } diff --git a/tool/trimmer/trim/traversal.go b/tool/trimmer/trim/traversal.go index ddba0cf8..bc4441e0 100644 --- a/tool/trimmer/trim/traversal.go +++ b/tool/trimmer/trim/traversal.go @@ -40,7 +40,7 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { var listUnion []*parser.StructLike for i := range ast.Unions { - if t.marks[filename][ast.Unions[i]] { + if t.marks[filename][ast.Unions[i]] || t.checkPreserve(ast.Unions[i]) { listUnion = append(listUnion, ast.Unions[i]) } } @@ -48,7 +48,7 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { var listException []*parser.StructLike for i := range ast.Exceptions { - if t.marks[filename][ast.Exceptions[i]] { + if t.marks[filename][ast.Exceptions[i]] || t.checkPreserve(ast.Exceptions[i]) { listException = append(listException, ast.Exceptions[i]) } } diff --git a/tool/trimmer/trim/trimmer.go b/tool/trimmer/trim/trimmer.go index 39a143f2..97fe1bdd 100644 --- a/tool/trimmer/trim/trimmer.go +++ b/tool/trimmer/trim/trimmer.go @@ -32,14 +32,44 @@ type Trimmer struct { marks map[string]map[interface{}]bool outDir string // use -m - trimMethods []string - trimMethodValid []bool - preserveRegex *regexp.Regexp - forceTrimming bool + trimMethods []string + trimMethodValid []bool + preserveRegex *regexp.Regexp + forceTrimming bool + preservedStructs []string } -// TrimAST trim the single AST, pass method names if -m specified -func TrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool) error { +type TrimASTArg struct { + Ast *parser.Thrift + TrimMethods []string + Preserve *bool +} + +// TrimAST parse the cfg and trim the single AST +func TrimAST(arg *TrimASTArg) error { + var preservedStructs []string + if wd, err := os.Getwd(); err == nil { + cfg := ParseYamlConfig(wd) + if cfg != nil { + if len(arg.TrimMethods) == 0 && len(cfg.Methods) > 0 { + arg.TrimMethods = cfg.Methods + } + if arg.Preserve == nil && !(*cfg.Preserve) { + preserve := false + arg.Preserve = &preserve + } + preservedStructs = cfg.PreservedStructs + } + } + forceTrim := false + if arg.Preserve != nil { + forceTrim = !*arg.Preserve + } + return doTrimAST(arg.Ast, arg.TrimMethods, forceTrim, preservedStructs) +} + +// doTrimAST trim the single AST, pass method names if -m specified +func doTrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool, preservedStructs []string) error { trimmer, err := newTrimmer(nil, "") if err != nil { return err @@ -59,6 +89,7 @@ func TrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool) error } } } + trimmer.preservedStructs = preservedStructs trimmer.markAST(ast) trimmer.traversal(ast, ast.Filename) if path := parser.CircleDetect(ast); len(path) > 0 { diff --git a/tool/trimmer/trim/trimmer_test.go b/tool/trimmer/trim/trimmer_test.go index 5687a2d1..0dd219ae 100644 --- a/tool/trimmer/trim/trimmer_test.go +++ b/tool/trimmer/trim/trimmer_test.go @@ -85,3 +85,50 @@ func TestInclude(t *testing.T) { test.Assert(t, len(ast.Includes) == 1) test.Assert(t, ast.Includes[0].Used == nil) } + +func TestTrimMethod(t *testing.T) { + filename := filepath.Join("..", "test_cases", "tests", "dir", "dir2", "test.thrift") + ast, err := parser.ParseFile(filename, nil, true) + check(err) + if path := parser.CircleDetect(ast); len(path) > 0 { + check(fmt.Errorf("found include circle:\n\t%s", path)) + } + checker := semantic.NewChecker(semantic.Options{FixWarnings: true}) + _, err = checker.CheckAll(ast) + check(err) + check(semantic.ResolveSymbols(ast)) + + methods := make([]string, 1) + methods[0] = "func1" + + err = TrimAST(&TrimASTArg{ + Ast: ast, + TrimMethods: methods, + Preserve: nil, + }) + check(err) + test.Assert(t, len(ast.Services[0].Functions) == 1) +} + +func TestPreserve(t *testing.T) { + filename := filepath.Join("..", "test_cases", "tests", "dir", "dir2", "test.thrift") + ast, err := parser.ParseFile(filename, nil, true) + check(err) + if path := parser.CircleDetect(ast); len(path) > 0 { + check(fmt.Errorf("found include circle:\n\t%s", path)) + } + checker := semantic.NewChecker(semantic.Options{FixWarnings: true}) + _, err = checker.CheckAll(ast) + check(err) + check(semantic.ResolveSymbols(ast)) + + preserve := false + + err = TrimAST(&TrimASTArg{ + Ast: ast, + TrimMethods: nil, + Preserve: &preserve, + }) + check(err) + test.Assert(t, len(ast.Structs) == 0) +} From 74820aa77b8096489c5ef7771bf1e38dc7e1383b Mon Sep 17 00:00:00 2001 From: tk_sky <63036400+tksky1@users.noreply.github.com> Date: Wed, 1 Nov 2023 20:04:39 +0800 Subject: [PATCH 13/19] feat: add regexp support for method-based trimming (#139) --- go.mod | 1 + go.sum | 2 ++ tool/trimmer/trim/mark.go | 4 ++-- tool/trimmer/trim/trimmer.go | 8 ++++++-- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index 1ed04708..dba1f689 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.13 require ( github.com/apache/thrift v0.13.0 + github.com/dlclark/regexp2 v1.10.0 // indirect golang.org/x/text v0.6.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index a361fe19..6bf2817b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= diff --git a/tool/trimmer/trim/mark.go b/tool/trimmer/trim/mark.go index 8bd5e903..4bbdd190 100644 --- a/tool/trimmer/trim/mark.go +++ b/tool/trimmer/trim/mark.go @@ -44,7 +44,7 @@ func (t *Trimmer) markService(svc *parser.Service, ast *parser.Thrift, filename if t.trimMethods != nil { funcName := svc.Name + "." + function.Name for i, method := range t.trimMethods { - if funcName == method { + if ok, _ := method.MatchString(funcName); ok { t.marks[filename][svc] = true t.markFunction(function, ast, filename) t.trimMethodValid[i] = true @@ -216,7 +216,7 @@ func (t *Trimmer) traceExtendMethod(father, svc *parser.Service, ast *parser.Thr for _, function := range svc.Functions { funcName := father.Name + "." + function.Name for i, method := range t.trimMethods { - if funcName == method { + if ok, _ := method.MatchString(funcName); ok { t.marks[filename][svc] = true t.markFunction(function, ast, filename) t.trimMethodValid[i] = true diff --git a/tool/trimmer/trim/trimmer.go b/tool/trimmer/trim/trimmer.go index 97fe1bdd..fed4983c 100644 --- a/tool/trimmer/trim/trimmer.go +++ b/tool/trimmer/trim/trimmer.go @@ -20,6 +20,8 @@ import ( "regexp" "strings" + "github.com/dlclark/regexp2" + "github.com/cloudwego/thriftgo/parser" "github.com/cloudwego/thriftgo/semantic" ) @@ -32,7 +34,7 @@ type Trimmer struct { marks map[string]map[interface{}]bool outDir string // use -m - trimMethods []string + trimMethods []*regexp2.Regexp trimMethodValid []bool preserveRegex *regexp.Regexp forceTrimming bool @@ -75,7 +77,7 @@ func doTrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool, pre return err } trimmer.asts[ast.Filename] = ast - trimmer.trimMethods = trimMethods + trimmer.trimMethods = make([]*regexp2.Regexp, len(trimMethods)) trimmer.trimMethodValid = make([]bool, len(trimMethods)) trimmer.forceTrimming = forceTrimming for i, method := range trimMethods { @@ -88,6 +90,8 @@ func doTrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool, pre os.Exit(2) } } + trimmer.trimMethods[i], err = regexp2.Compile(trimMethods[i], 0) + check(err) } trimmer.preservedStructs = preservedStructs trimmer.markAST(ast) From 40cc17938b64e2e0b0b1b825059f5156af2e3818 Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Thu, 2 Nov 2023 14:57:34 +0800 Subject: [PATCH 14/19] fix: fix reflection ref import (#140) --- generator/golang/templates/reflection/reflection_ref_tpl.go | 1 - 1 file changed, 1 deletion(-) diff --git a/generator/golang/templates/reflection/reflection_ref_tpl.go b/generator/golang/templates/reflection/reflection_ref_tpl.go index 90c08083..17717922 100644 --- a/generator/golang/templates/reflection/reflection_ref_tpl.go +++ b/generator/golang/templates/reflection/reflection_ref_tpl.go @@ -26,7 +26,6 @@ import ( {{end}} "reflect" - {{.RefPackage}} "{{.RefPath}}" "github.com/cloudwego/thriftgo/thrift_reflection" ) From 47fd70033721b4b3759f5fdeb7c898f2e6b9210a Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Thu, 9 Nov 2023 19:43:29 +0800 Subject: [PATCH 15/19] fix: fix filepath for windows (#142) --- generator/golang/templates/reflection/reflection_ref_tpl.go | 4 ++-- generator/golang/util.go | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/generator/golang/templates/reflection/reflection_ref_tpl.go b/generator/golang/templates/reflection/reflection_ref_tpl.go index 17717922..eb5e355e 100644 --- a/generator/golang/templates/reflection/reflection_ref_tpl.go +++ b/generator/golang/templates/reflection/reflection_ref_tpl.go @@ -45,9 +45,9 @@ func init() { } type x struct{} replacer := &thrift_reflection.FileDescriptorReplacer{ - RemoteGoPkgPath: "{{ .RefPath }}", + RemoteGoPkgPath: {{ backquoted .RefPath }}, CurrentGoPkgPath: reflect.TypeOf(x{}).PkgPath(), - CurrentFilepath: "{{ $IDLPath}}", + CurrentFilepath: {{ backquoted $IDLPath }}, Matcher: "{{ .GetFirstDescriptor }}", } file_{{$IDLName}}_thrift = thrift_reflection.ReplaceFileDescriptor(replacer) diff --git a/generator/golang/util.go b/generator/golang/util.go index a0a7f290..6f1a6e68 100644 --- a/generator/golang/util.go +++ b/generator/golang/util.go @@ -427,6 +427,7 @@ func (cu *CodeUtils) BuildFuncMap() template.FuncMap { }) return ret }, + "backquoted": BackQuoted, } return m } @@ -461,3 +462,7 @@ func JoinPath(elem ...string) string { } return filepath.Join(elem...) } + +func BackQuoted(s string) string { + return "`" + s + "`" +} From d3508eeb6136bc20ba2f79a04ac878a1595c1cc5 Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Tue, 14 Nov 2023 19:58:42 +0800 Subject: [PATCH 16/19] optimize: add an option to control code ref file name (#143) --- generator/golang/backend.go | 15 +++++++++++---- generator/golang/option.go | 1 + version/version.go | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/generator/golang/backend.go b/generator/golang/backend.go index e653c1db..df2514b2 100644 --- a/generator/golang/backend.go +++ b/generator/golang/backend.go @@ -170,6 +170,7 @@ func (g *GoBackend) executeTemplates() { } func (g *GoBackend) renderOneFile(ast *parser.Thrift) error { + keepName := g.utils.Features().KeepCodeRefName path := g.utils.CombineOutputPath(g.req.OutputPath, ast) filename := filepath.Join(path, g.utils.GetFilename(ast)) localScope, refScope, err := BuildRefScope(g.utils, ast) @@ -180,12 +181,12 @@ func (g *GoBackend) renderOneFile(ast *parser.Thrift) error { if err != nil { return err } - err = g.renderByTemplate(refScope, g.refTpl, ToRefFilename(filename)) + err = g.renderByTemplate(refScope, g.refTpl, ToRefFilename(keepName, filename)) if err != nil { return err } if g.utils.Features().WithReflection { - err = g.renderByTemplate(refScope, g.reflectionRefTpl, ToReflectionRefFilename(filename)) + err = g.renderByTemplate(refScope, g.reflectionRefTpl, ToReflectionRefFilename(keepName, filename)) if err != nil { return err } @@ -194,7 +195,10 @@ func (g *GoBackend) renderOneFile(ast *parser.Thrift) error { return nil } -func ToRefFilename(filename string) string { +func ToRefFilename(keepName bool, filename string) string { + if keepName { + return filename + } return strings.TrimSuffix(filename, ".go") + "-ref.go" } @@ -202,7 +206,10 @@ func ToReflectionFilename(filename string) string { return strings.TrimSuffix(filename, ".go") + "-reflection.go" } -func ToReflectionRefFilename(filename string) string { +func ToReflectionRefFilename(keepName bool, filename string) string { + if keepName { + return ToReflectionFilename(filename) + } return strings.TrimSuffix(filename, ".go") + "-reflection-ref.go" } diff --git a/generator/golang/option.go b/generator/golang/option.go index 2e32e801..1d53738b 100644 --- a/generator/golang/option.go +++ b/generator/golang/option.go @@ -53,6 +53,7 @@ type Features struct { EnumAsINT32 bool `enum_as_int_32:"Generate enum type as int32"` CodeRefSlim bool `code_ref_slim:"Genenerate code ref by given idl-ref.yaml with less refs to avoid conflict"` CodeRef bool `code_ref:"Genenerate code ref by given idl-ref.yaml"` + KeepCodeRefName bool `keep_code_ref_name:"Genenerate code ref but still keep file name."` TrimIDL bool `trim_idl:"Simplify IDL to the most concise form before generating code."` } diff --git a/version/version.go b/version/version.go index 49115909..41817c03 100644 --- a/version/version.go +++ b/version/version.go @@ -14,4 +14,4 @@ package version -const ThriftgoVersion = "0.3.2" +const ThriftgoVersion = "0.3.3" From 4fc8a6f413032987c608c0201e05fa0cd2bd4cda Mon Sep 17 00:00:00 2001 From: Yi Duan Date: Tue, 21 Nov 2023 20:50:14 +0800 Subject: [PATCH 17/19] feat: support FieldMask (#122) --- .github/workflows/push-check.yml | 6 +- .gitignore | 1 + CREDITS | 14 + fieldmask/README.md | 50 ++ fieldmask/api_test.go | 443 +++++++++++++++ fieldmask/mapper.go | 38 ++ fieldmask/mask.go | 186 +++++++ fieldmask/path.go | 512 ++++++++++++++++++ fieldmask/storage.go | 243 +++++++++ fieldmask/utils.go | 463 ++++++++++++++++ generator/golang/imports.go | 1 + generator/golang/option.go | 2 + generator/golang/read_write_context.go | 10 + generator/golang/templates/init.go | 1 + .../templates/reflection/reflection_tpl.go | 29 + generator/golang/templates/struct.go | 279 ++++++++-- generator/golang/thrift.go | 20 + generator/golang/util.go | 3 + go.mod | 2 +- internal/test_util/generator.go | 73 +++ test/golang/fieldmask/a.thrift | 90 +++ test/golang/fieldmask/go.mod | 21 + test/golang/fieldmask/go.sum | 39 ++ test/golang/fieldmask/main_test.go | 292 ++++++++++ test/golang/fieldmask/run_test.sh | 37 ++ thrift_reflection/descriptor-extend.go | 33 +- utils/ast_util.go | 69 +++ utils/string_utils.go | 8 + 28 files changed, 2924 insertions(+), 41 deletions(-) create mode 100644 fieldmask/README.md create mode 100644 fieldmask/api_test.go create mode 100644 fieldmask/mapper.go create mode 100644 fieldmask/mask.go create mode 100644 fieldmask/path.go create mode 100644 fieldmask/storage.go create mode 100644 fieldmask/utils.go create mode 100644 internal/test_util/generator.go create mode 100644 test/golang/fieldmask/a.thrift create mode 100644 test/golang/fieldmask/go.mod create mode 100644 test/golang/fieldmask/go.sum create mode 100644 test/golang/fieldmask/main_test.go create mode 100755 test/golang/fieldmask/run_test.sh create mode 100644 utils/ast_util.go diff --git a/.github/workflows/push-check.yml b/.github/workflows/push-check.yml index b34c5e61..3ab8c367 100644 --- a/.github/workflows/push-check.yml +++ b/.github/workflows/push-check.yml @@ -33,12 +33,12 @@ jobs: go install mvdan.cc/gofumpt@v0.2.0 echo "install done!" set -e - if [[ -n "$(gofumpt -l -extra .)" ]]; then + if [[ -n "$(gofumpt -l .)" ]]; then echo "gofumpt found formatting issues." - gofumpt -l -extra . + gofumpt -l . exit 1 fi - test -z "$(gofumpt -l -extra .)" + test -z "$(gofumpt -l .)" echo "test done!" go vet -stdmethods=false $(go list ./...) echo "go vet done!" diff --git a/.gitignore b/.gitignore index 35517d12..ee7f53f8 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ gen-* .vscode .idea tool/trimmer/trimmer_test +output diff --git a/CREDITS b/CREDITS index e69de29b..85ee3e7d 100644 --- a/CREDITS +++ b/CREDITS @@ -0,0 +1,14 @@ +// Copyright 2023 ByteDance Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + diff --git a/fieldmask/README.md b/fieldmask/README.md new file mode 100644 index 00000000..f2e844a1 --- /dev/null +++ b/fieldmask/README.md @@ -0,0 +1,50 @@ + + +# ThriftPath RFC + +## What is thrift path? +A path string represents a arbitrary endpoint of thrift object. It is used for locating data from thrift root message, and defined from-top-to-bottom. +For exapmle, a thrift message defined as below: +```thrift +struct Example { + 1: string Foo, + 2: i64 Bar + 3: Example Self +} +``` +A thrift path `$.Foo` represents the string value of Example.Foo, and `$.Self.Bar` represents the secondary layer i64 value of Example.Self.Bar +Since thrift has four nesting types (LIST/SET/MAP/STRUCT), thrift path should also support locating elements in all these types' object, not only STRUCT. + +## Syntax +Here are basic hypothesis: +- `fieldname` is the field name of a field in a struct, it **MUST ONLY** contain '[a-zA-Z]' alphabet letters, integer numbers and char '_'. +- `index` is the index of a element in a list or set, it **MUST ONLY** contain integer numbers. +- `key` is the string-typed key of a element in a map, it can contain any letters, but it **MUST** be a quoted string. +- `id` is the integer-typed key of a element in a map, it **MUST ONLY** contain integer numbers. + +Here is detailed syntax: + +ThriftPath | Description +-- | -- +$ | the root object,every path must start with it. +.`fieldname` | get the child field of a struct corepsonding to fieldname. For example, `$.FieldA.ChildrenB` +[`index`,index...] | get any number of elements in an List/Set corepsonding to indices. Indices must be integer.For example: `$.FieldList[1,3,4]` .Notice: a index beyond actual list size can written but is useless. +{"key","key"...} | get any number of values corepsonding to key in a string-typed-key map. For example: `$.StrMap{"abcd","1234"}` +{id,id...} | get the child field with specific id in a integer-typed-key map. For example, `$.IntMap{1,2}` +\* | get **ALL** fields/elements, that is: `$.StrMap{*}.FieldX` menas gets all the elements' FieldX in a map Root.StrMap; `$.List[*].FieldX` means get all the elements' FieldX in a list Root.List + + diff --git a/fieldmask/api_test.go b/fieldmask/api_test.go new file mode 100644 index 00000000..05d9c160 --- /dev/null +++ b/fieldmask/api_test.go @@ -0,0 +1,443 @@ +/* + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fieldmask + +import ( + "strings" + "testing" + + "github.com/cloudwego/thriftgo/parser" + "github.com/cloudwego/thriftgo/thrift_reflection" +) + +var baseIDL = ` +namespace go base + +struct TrafficEnv { + 0: string Name = "", + 1: bool Open = false, + 2: string Env = "", + 256: i64 Code, +} + +struct Base { + 0: string Addr = "", + 1: string LogID = "", + 2: string Caller = "", + 5: optional TrafficEnv TrafficEnv, + 6: optional list Extra, + 256: MetaInfo Meta, +} + +struct ExtraInfo { + 1: map IntMap + 2: map StrMap + 3: list List + 4: set Set +} + +struct Val { + 1: string A, + 2: string B, +} + +struct MetaInfo { + 1: map F1, + 2: map F2, + 3: list F3, + 3: Base Base, +} + +struct BaseResp { + 1: string StatusMessage = "", + 2: i32 StatusCode = 0, + 3: optional map Extra, +}` + +func GetDescriptor(IDL string, root string) *thrift_reflection.TypeDescriptor { + ast, err := parser.ParseString("a.thrift", IDL) + if err != nil { + panic(err.Error()) + } + fd := thrift_reflection.RegisterAST(ast) + st := fd.GetStructDescriptor(root) + return &thrift_reflection.TypeDescriptor{ + Filepath: st.Filepath, + Name: st.Name, + } +} + +func TestNewFieldMask(t *testing.T) { + type args struct { + IDL string + rootStruct string + paths []string + inMasks []string + notInMasks []string + err []error + } + tests := []struct { + name string + args args + want *FieldMask + }{ + { + name: "Struct", + args: args{ + IDL: baseIDL, + rootStruct: "Base", + paths: []string{"$.LogID", "$.TrafficEnv.Open", "$.TrafficEnv.Env", "$.Meta"}, + + inMasks: []string{"$.Meta.F1", "$.Meta.F2", "$.Meta.Base.Caller"}, + notInMasks: []string{"$.TrafficEnv.Name", "$.TrafficEnv.Code", "$.Caller", "$.Addr", "$.Extra"}, + }, + }, + { + name: "List/Set", + args: args{ + IDL: baseIDL, + rootStruct: "Base", + paths: []string{"$.Extra[0]", "$.Extra[1].List", "$.Extra[2].Set[0,1]", "$.Extra[4,5].List[*]"}, + + inMasks: []string{"$.Extra[0].List", "$.Extra[2].Set[0].A", "$.Extra[2].Set[1].A", "$.Extra[4].List[0]", "$.Extra[4,5].List[0]", "$.Extra[1,4,5].List"}, + notInMasks: []string{"$.Extra[1].Set", "$.Extra[1].IntMap", "$.Extra[3]", "$.Extra[3,4].Set"}, + }, + }, + { + name: "Int Map", + args: args{ + IDL: baseIDL, + rootStruct: "Base", + paths: []string{"$.Extra[0].IntMap{0}", "$.Extra[0].IntMap{1}.A", "$.Extra[0].IntMap{1}.B", "$.Extra[0].IntMap{2}.A", "$.Extra[0].IntMap{4,5}.A", "$.Meta.F2{*}.TrafficEnv"}, + inMasks: []string{"$.Extra[0].IntMap{0}.A", "$.Extra[0].IntMap{0}.B", "$.Extra[0].IntMap{4}.A", "$.Extra[0].IntMap{5}.A", "$.Meta.F2{0}.TrafficEnv.Env", "$.Meta.F2{*}.TrafficEnv.Env"}, + notInMasks: []string{"$.Extra[0].IntMap{2}.B", "$.Extra[0].IntMap{3}", "$.Extra[0].IntMap{4}.B", "$.Extra[0].IntMap{5}.B", "$.Meta.F2{0}.Addr", "$.Meta.F2{*}.Addr"}, + }, + }, + { + name: "Union", + args: args{ + IDL: baseIDL, + rootStruct: "Base", + paths: []string{"$.Extra[0].List", "$.Extra[*].Set", "$.Meta.F2{0}", "$.Meta.F2{*}.Addr"}, + inMasks: []string{"$.Extra[*].Set[0]", "$.Meta.F2{1}.Addr"}, + notInMasks: []string{"$.Extra[0].List", "$.Meta.F2[0].LogID"}, + }, + }, + { + name: "String Map", + args: args{ + IDL: baseIDL, + rootStruct: "Base", + paths: []string{"$.Extra[0].StrMap{\"x\"}", "$.Extra[0].StrMap{\"a\"}.A", "$.Extra[0].StrMap{\"b\"}.B", "$.Extra[0].StrMap{\"c\",\"d\"}", "$.Extra[0].StrMap{\"e\",\"f\"}.A"}, + inMasks: []string{"$.Extra[0].StrMap{\"x\"}.A", "$.Extra[0].StrMap{\"x\"}.B", "$.Extra[0].StrMap{\"c\"}.A", "$.Extra[0].StrMap{\"c\",\"d\",\"e\",\"f\"}.A"}, + notInMasks: []string{"$.Extra[0].StrMap{\"a\"}.B", "$.Extra[0].StrMap{\"b\"}.A", "$.Extra[0].StrMap{\"s\"}", "$.Extra[0].StrMap{\"s\",\"c\"}", "$.Extra[0].StrMap{\"d\",\"e\"}.B"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // defer func() { + // if v := recover(); v != nil { + // if tt.args.err == nil || v != tt.args.err { + // t.Fatal("panic: ", v) + // } + // } + // }() + + st := GetDescriptor(tt.args.IDL, tt.args.rootStruct) + got, err := NewFieldMask(st, tt.args.paths...) + if tt.args.err != nil { + if err == nil { + t.Fatal(err) + } + return + } + if err != nil { + t.Fatal(err) + } + + retry := true + begin: + + println("fieldmask:") + println(got.String(st)) + // spew.Dump(got) + + if tt.name != "Union" { + for _, path := range tt.args.paths { + println("[paths] ", path) + if !got.PathInMask(st, path) { + t.Fatal(path) + } + } + } + + for _, path := range tt.args.inMasks { + println("[inMasks] ", path) + if !got.PathInMask(st, path) { + t.Fatal(path) + } + } + for _, path := range tt.args.notInMasks { + println("[notInMasks] ", path) + if got.PathInMask(st, path) { + t.Fatal(path) + } + } + + if retry { + got.reset() + if err := got.init(st, tt.args.paths...); err != nil { + t.Fatal(err) + } + retry = false + goto begin + } + }) + } +} + +func TestErrors(t *testing.T) { + type args struct { + IDL string + rootStruct string + path []string + err string + } + tests := []struct { + name string + args args + want *FieldMask + }{ + { + name: "desc struct", + args: args{ + IDL: baseIDL, + rootStruct: "Base", + path: []string{"$.LogID.X"}, + err: `Descriptor "string" isn't STRUCT`, + }, + }, + { + name: "desc list", + args: args{ + IDL: baseIDL, + rootStruct: "Base", + path: []string{"$.LogID[1]"}, + err: `Descriptor "string" isn't LIST or SET`, + }, + }, + { + name: "desc map", + args: args{ + IDL: baseIDL, + rootStruct: "Base", + path: []string{"$.LogID{1}"}, + err: `Descriptor "string" isn't MAP`, + }, + }, + { + name: "desc map key", + args: args{ + IDL: baseIDL, + rootStruct: "ExtraInfo", + path: []string{"$.IntMap{\"a\"}"}, + err: `expect integer but got string`, + }, + }, + { + name: "desc map key", + args: args{ + IDL: baseIDL, + rootStruct: "ExtraInfo", + path: []string{"$.StrMap{1}"}, + err: `expect string but got integer`, + }, + }, + { + name: "syntax index", + args: args{ + IDL: baseIDL, + rootStruct: "ExtraInfo", + path: []string{"$.List[\"1\"]"}, + err: `isn't literal`, + }, + }, + { + name: "fields conflict", + args: args{ + IDL: baseIDL, + rootStruct: "Base", + path: []string{"$.TrafficEnv", "$.TrafficEnv.Env"}, + err: `onflicts with previously-set all (*) fields`, + }, + }, + { + name: "index conflict", + args: args{ + IDL: baseIDL, + rootStruct: "Base", + path: []string{"$.Extra[*]", "$.Extra[1]"}, + err: `onflicts with previously-set all (*) index`, + }, + }, + { + name: "key conflict", + args: args{ + IDL: baseIDL, + rootStruct: "ExtraInfo", + path: []string{"$.IntMap{*}", "$.IntMap{1}"}, + err: `onflicts with previously-set all (*) keys`, + }, + }, + { + name: "empty map set", + args: args{ + IDL: baseIDL, + rootStruct: "ExtraInfo", + path: []string{"$.IntMap{}"}, + err: `empty key set`, + }, + }, + { + name: "empty list set", + args: args{ + IDL: baseIDL, + rootStruct: "ExtraInfo", + path: []string{"$.List[]"}, + err: `empty index set`, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + st := GetDescriptor(tt.args.IDL, tt.args.rootStruct) + _, err := GetFieldMask(st, tt.args.path...) + if err == nil || !strings.Contains(err.Error(), tt.args.err) { + t.Fatal(err) + } + }) + } +} + +func BenchmarkNewFieldMask(b *testing.B) { + st := GetDescriptor(baseIDL, "Base") + if st == nil { + b.Fail() + } + b.Run("new", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + fm, err := NewFieldMask(st, []string{"$.LogID", "$.TrafficEnv.Open", "$.TrafficEnv.Env", "$.Extra[0]", "$.Extra[1].IntMap{0}", "$.Extra[2].StrMap{\"abcd\"}"}...) + if err != nil { + b.Fatal(err) + } + _ = fm + } + }) + b.Run("reuse", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + fm, err := GetFieldMask(st, []string{"$.LogID", "$.TrafficEnv.Open", "$.TrafficEnv.Env", "$.Extra[0]", "$.Extra[1].IntMap{0}", "$.Extra[2].StrMap{\"abcd\"}"}...) + if err != nil { + b.Fatal(err) + } + fm.Recycle() + } + }) +} + +func BenchmarkFieldMask_InMask(b *testing.B) { + st := GetDescriptor(baseIDL, "Base") + if st == nil { + b.Fail() + } + fm, err := NewFieldMask(st, []string{"$.Extra[0]", "$.Extra[1].IntMap{0}", "$.Extra[2].StrMap{\"abcdefghi\"}"}...) + if err != nil { + b.Fatal(err) + } + b.Run("Field", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if next, exist := fm.Field(6); !exist { + b.Fail() + } else { + _ = next + } + } + }) + + b.Run("Index", func(b *testing.B) { + var v *FieldMask + if next, ex := fm.Field(6); !ex { + b.Fail() + } else { + v = next + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if next, ex := v.Int(0); !ex { + b.Fail() + } else { + _ = next + } + } + }) + + b.Run("Int Map", func(b *testing.B) { + var v *FieldMask + if next, ex := fm.Field(6); !ex { + b.Fail() + } else if l, ex := next.Int(1); !ex { + b.Fail() + } else if f, ex := l.Field(1); !ex { + b.Fail() + } else { + v = f + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if next, ex := v.Int(0); !ex { + b.Fail() + } else { + _ = next + } + } + }) + + b.Run("Str Map", func(b *testing.B) { + var v *FieldMask + if next, ex := fm.Field(6); !ex { + b.Fail() + } else if l, ex := next.Int(2); !ex { + b.Fail() + } else if f, ex := l.Field(2); !ex { + b.Fail() + } else { + v = f + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if next, ex := v.Str("abcdefghi"); !ex { + b.Fail() + } else { + _ = next + } + } + }) +} diff --git a/fieldmask/mapper.go b/fieldmask/mapper.go new file mode 100644 index 00000000..8b0c1f03 --- /dev/null +++ b/fieldmask/mapper.go @@ -0,0 +1,38 @@ +/** + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fieldmask + +// ValueMapper is used to mapping values +// type ValueMapper struct { +// // AllowNested indicates if a recursive type (LIST/SET/MAP) is acceptable to this. +// // If it is true, every elem (both key for MAP) will trigger `OnXX` mapping function +// AllowRecurse bool + +// //mapping functions +// OnInt func(isNil bool, val int) (int, bool) +// OnFloat func(isNil bool, val float64) (int, bool) +// OnBool func(isNil bool, val bool) (bool, bool) +// OnString func(isNil bool, val string) (string, bool) +// } + +// PathNapper is the definition of a ValueMapper for specific path +// type PathMapper struct { +// Path string +// Mapper ValueMapper +// } + +// func NewFieldMapper(desc thrift_reflection.TypeDescriptor, maps ...PathMapper) *FieldMask diff --git a/fieldmask/mask.go b/fieldmask/mask.go new file mode 100644 index 00000000..425d681b --- /dev/null +++ b/fieldmask/mask.go @@ -0,0 +1,186 @@ +/* + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fieldmask + +import ( + "fmt" + "strings" + "sync" + + "github.com/cloudwego/thriftgo/thrift_reflection" +) + +type fieldMaskType uint8 + +const ( + ftInvalid fieldMaskType = iota + ftScalar + ftArray + ftStruct + ftStrMap + ftIntMap +) + +// FieldMask represents a collection of thrift pathes +// See +type FieldMask struct { + typ fieldMaskType + + isAll bool + + all *FieldMask + + fdMask *fieldMap + + strMask strMap + + intMask intMap +} + +var fmsPool = sync.Pool{ + New: func() interface{} { + return &FieldMask{} + }, +} + +// NewFieldMask create a new fieldmask +func NewFieldMask(desc *thrift_reflection.TypeDescriptor, pathes ...string) (*FieldMask, error) { + ret := FieldMask{} + err := ret.init(desc, pathes...) + if err != nil { + return nil, err + } + return &ret, nil +} + +// GetFieldMask reuse fieldmask from pool +func GetFieldMask(desc *thrift_reflection.TypeDescriptor, paths ...string) (*FieldMask, error) { + ret := fmsPool.Get().(*FieldMask) + err := ret.init(desc, paths...) + if err != nil { + return nil, err + } + return ret, nil +} + +// Recycle puts fieldmask into pool +func (self *FieldMask) Recycle() { + self.reset() + fmsPool.Put(self) +} + +// reset clears fieldmask's all path +func (self *FieldMask) reset() { + if self == nil { + return + } + self.isAll = false + self.typ = 0 + self.fdMask.Reset() + self.intMask.Reset() + self.strMask.Reset() +} + +func (self *FieldMask) init(desc *thrift_reflection.TypeDescriptor, paths ...string) error { + // horizontal traversal... + for _, path := range paths { + if err := self.addPath(path, desc); err != nil { + return fmt.Errorf("Parsing path %q error: %v", path, err) + } + } + return nil +} + +// String pretty prints the structure a FieldMask represents +// +// For example: +// pathes `[]string{"$.Extra[0].List", "$.Extra[*].Set", "$.Meta.F2{0}", "$.Meta.F2{*}.Addr"}` will print: +// +// (Base) +// .Extra (list) +// [ +// * +// ] +// .Meta (MetaInfo) +// .F2 (map) +// { +// * +// } +// +// WARING: This is unstable API, the printer format is not guaranteed +func (self FieldMask) String(desc *thrift_reflection.TypeDescriptor) string { + buf := strings.Builder{} + buf.WriteString("(") + buf.WriteString(desc.GetName()) + buf.WriteString(")\n") + self.print(&buf, 0, desc) + return buf.String() +} + +// Exist tells if the fieldmask is setted +func (self *FieldMask) Exist() bool { + return self != nil && self.typ != 0 +} + +// Field returns the specific sub mask for a given id, and tells if the id in the mask +func (self *FieldMask) Field(id int16) (*FieldMask, bool) { + if self == nil || self.typ == 0 { + return nil, true + } + if self.isAll { + return self.all, true + } + fm := self.fdMask.Get(fieldID(id)) + return fm, fm != nil +} + +// Int returns the specific sub mask for a given index, and tells if the index in the mask +func (self *FieldMask) Int(id int) (*FieldMask, bool) { + if self == nil || self.typ == 0 { + return nil, true + } + if self.isAll { + return self.all, true + } + fm := self.intMask.Get(id) + return fm, fm != nil +} + +// Field returns the specific sub mask for a given string, and tells if the string in the mask +func (self *FieldMask) Str(id string) (*FieldMask, bool) { + if self == nil || self.typ == 0 { + return nil, true + } + if self.isAll { + return self.all, true + } + fm := self.strMask.Get(id) + return fm, fm != nil +} + +// All tells if the mask allows all elements pass (*) +func (self *FieldMask) All() bool { + if self == nil { + return true + } + switch self.typ { + case ftStruct, ftArray, ftIntMap, ftStrMap: + return self.isAll + default: + return true + } +} diff --git a/fieldmask/path.go b/fieldmask/path.go new file mode 100644 index 00000000..c3dc0734 --- /dev/null +++ b/fieldmask/path.go @@ -0,0 +1,512 @@ +/** + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fieldmask + +import ( + "fmt" + "io" + "strconv" + "unsafe" + + "github.com/cloudwego/thriftgo/thrift_reflection" +) + +type pathType int + +const ( + pathTypeLitStr pathType = 1 + iota + pathTypeLitInt pathType = 1 + iota + pathTypeStr + pathTypeRoot + pathTypeField + pathTypeIndexL + pathTypeIndexR + pathTypeMapL + pathTypeMapR + pathTypeElem + pathTypeAny + + pathTypeEOF pathType = -1 + pathTypeERR pathType = -2 +) + +type pathSep byte + +const ( + pathSepRoot = '$' + pathSepField = '.' + pathSepIndexLeft = '[' + pathSepIndexRight = ']' + pathSepMapLeft = '{' + pathSepMapRight = '}' + pathSepElem = ',' + pathSepAny = '*' + pathSepQuote = '"' + pathSepSlash = '\\' +) + +type pathValue struct { + pv unsafe.Pointer + iv int +} + +func newPathValueStr(val string) pathValue { + if val == "" { + return pathValue{iv: len(val), pv: nil} + } else { + return pathValue{iv: len(val), pv: *(*unsafe.Pointer)(unsafe.Pointer(&val))} + } +} + +func newPathValueInt(val int) pathValue { + return pathValue{iv: val} +} + +func (v pathValue) Str() string { + return *(*string)(unsafe.Pointer(&v)) +} + +func (v pathValue) Int() int { + return v.iv +} + +type pathToken struct { + typ pathType + val pathValue + loc [2]int +} + +func (p pathToken) Type() pathType { + return p.typ +} + +// func (p pathToken) ToInt() (int, bool) { +// if p.typ == pathTypeLitStr || p.typ == pathTypeStr { +// i, e := strconv.ParseInt(p.val.Str(), 10, 64) +// if e != nil { +// return 0, false +// } +// return int(i), true +// } else if p.typ == pathTypeLitInt { +// return p.val.Int(), true +// } else { +// return 0, false +// } +// } + +// func (p pathToken) ToStr() (string, bool) { +// if p.typ == pathTypeLitStr || p.typ == pathTypeStr { +// return p.val.Str(), true +// } else if p.typ == pathTypeLitInt { +// str := strconv.Itoa(p.val.Int()) +// return str, true +// } else { +// return "", false +// } +// } + +func (p pathToken) Pos() (int, int) { + return p.loc[0], p.loc[1] +} + +func (p pathToken) Err() error { + switch p.typ { + case pathTypeEOF: + return io.EOF + default: + return nil + } +} + +func (p pathToken) String() string { + switch p.typ { + case pathTypeEOF: + return fmt.Sprintf("EOF at %d", p.loc[0]) + case pathTypeAny: + return fmt.Sprintf("* at %d", p.loc[0]) + case pathTypeElem: + return fmt.Sprintf(", at %d", p.loc[0]) + case pathTypeField: + return fmt.Sprintf(". at %d", p.loc[0]) + case pathTypeRoot: + return fmt.Sprintf("$ at %d", p.loc[0]) + case pathTypeIndexL: + return fmt.Sprintf("[ at %d", p.loc[0]) + case pathTypeIndexR: + return fmt.Sprintf("] at %d", p.loc[0]) + case pathTypeMapL: + return fmt.Sprintf("{ at %d", p.loc[0]) + case pathTypeMapR: + return fmt.Sprintf("} at %d", p.loc[0]) + // case pathTypeLitInt: + // return fmt.Sprintf("%d(%d:%d)", p.val.Int(), p.loc[0], p.loc[1]) + case pathTypeLitStr: + return fmt.Sprintf("Lit(%s) at %d-%d", p.val.Str(), p.loc[0], p.loc[1]) + case pathTypeLitInt: + return fmt.Sprintf("Lit(%d) at %d-%d", p.val.Int(), p.loc[0], p.loc[1]) + case pathTypeStr: + return fmt.Sprintf("Str(%q) at %d-%d", p.val.Str(), p.loc[0], p.loc[1]) + case pathTypeERR: + return fmt.Sprintf("Err(%s) at %d-%d", p.val.Str(), p.loc[0], p.loc[1]) + default: + return fmt.Sprintf("UnknownToken(%d) at %d:%d", p.typ, p.loc[0], p.loc[1]) + } +} + +func newPathToken(typ pathType, val string, s, e int) pathToken { + switch typ { + case pathTypeEOF: + return pathToken{typ: typ} + case pathTypeStr, pathTypeAny, pathTypeElem, pathTypeField, pathTypeIndexL, pathTypeIndexR, pathTypeLitStr, pathTypeMapR, pathTypeMapL, pathTypeRoot: + return pathToken{typ: typ, val: newPathValueStr(val), loc: [2]int{s, e}} + case pathTypeLitInt: + i, err := strconv.Atoi(val) + if err != nil { + panic(err) + } + return pathToken{typ: typ, val: newPathValueInt(i), loc: [2]int{s, e}} + default: + panic("unspported pathType " + val) + } +} + +type pathIterator struct { + pos int + src string +} + +func newPathIter(src string) pathIterator { + return pathIterator{src: src, pos: 0} +} + +func (p *pathIterator) Pos() int { + return p.pos +} + +func (p *pathIterator) LeftPath() string { + if p.pos >= len(p.src) { + return "" + } + return p.src[p.pos:] +} + +func (p *pathIterator) HasNext() bool { + return p.pos < len(p.src) +} + +func (p *pathIterator) Next() pathToken { + if !p.HasNext() { + return newPathToken(pathTypeEOF, "", p.pos, p.pos) + } + s := p.Pos() + c := p.char() + switch c { + case pathSepRoot: + return newPathToken(pathTypeRoot, "", s, p.Pos()) + case pathSepField: + return newPathToken(pathTypeField, "", s, p.Pos()) + case pathSepIndexLeft: + return newPathToken(pathTypeIndexL, "", s, p.Pos()) + case pathSepIndexRight: + return newPathToken(pathTypeIndexR, "", s, p.Pos()) + case pathSepMapLeft: + return newPathToken(pathTypeMapL, "", s, p.Pos()) + case pathSepMapRight: + return newPathToken(pathTypeMapR, "", s, p.Pos()) + case pathSepElem: + return newPathToken(pathTypeElem, "", s, p.Pos()) + case pathSepAny: + return newPathToken(pathTypeAny, "", s, p.Pos()) + case pathSepQuote: + p.Unwind(s) + v, e := p.str() + if e != nil { + return newPathToken(pathTypeERR, "invalid quote string", s, p.Pos()) + } + return newPathToken(pathTypeStr, v, s, p.Pos()) + default: + p.Unwind(s) + val, isInt := p.lit() + if isInt { + return newPathToken(pathTypeLitInt, val, s, p.Pos()) + } + return newPathToken(pathTypeLitStr, val, s, p.Pos()) + } +} + +func (p *pathIterator) char() byte { + c := p.src[p.pos] + p.pos += 1 + return c +} + +func (p *pathIterator) Unwind(pos int) { + p.pos = pos +} + +func (p *pathIterator) lit() (string, bool) { + i := p.pos + var isInt bool + for ; i < len(p.src); i++ { + switch cc := p.src[i]; cc { + case pathSepElem, pathSepAny, pathSepRoot, pathSepField, pathSepIndexLeft, pathSepIndexRight, pathSepMapLeft, pathSepMapRight, pathSepQuote, pathSepSlash: + goto ret + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + if i == p.pos { + isInt = true + } else { + isInt = isInt && true + } + default: + isInt = false + } + } +ret: + val := p.src[p.pos:i] + p.pos = i + return val, isInt +} + +func (p *pathIterator) str() (string, error) { + i := p.pos + open := false + for ; i < len(p.src); i++ { + switch cc := p.src[i]; cc { + case pathSepSlash: + i += 1 + case pathSepQuote: + open = !open + if !open { + i += 1 + goto ret + } + } + } +ret: + val := p.src[p.pos:i] + p.pos = i + val, err := strconv.Unquote(val) + if err != nil { + return "", err + } + return val, nil +} + +// PathInMask tells if a given path is already in current fieldmask +func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path string) bool { + it := newPathIter(path) + // println("[PathInMask]") + for it.HasNext() { + // NOTICE: desc shoudn't empty here + // println("desc: ", curDesc.Name) + + // NOTICE: empty fm for path means **IN MASK** + if cur == nil { + return true + } + + stok := it.Next() + if stok.Err() != nil { + return false + } + styp := stok.Type() + // println("stoken: ", stok.String()) + + if styp == pathTypeRoot { + continue + } else if styp == pathTypeField { + // get struct descriptor + st, err := curDesc.GetStructDescriptor() + if err != nil { + return false + } + // println("struct: ", st.Name) + if cur.typ != ftStruct { + return false + } + + tok := it.Next() + if tok.Err() != nil { + return false + } + typ := tok.Type() + // println("token", tok.String()) + + var f *thrift_reflection.FieldDescriptor + if typ == pathTypeLitInt { + f = st.GetFieldById(int32(tok.val.Int())) + if f == nil { + return false + } + + } else if typ == pathTypeLitStr { + name := tok.val.Str() + f = st.GetFieldByName(name) + if f == nil { + return false + } + } else if typ == pathTypeAny { + if !cur.All() { + return false + } + } else { + return false + } + + // println("all", all, "FieldInMask:", cur.FieldInMask(int32(f.GetID()))) + // check if name set mask + nextFm, exist := cur.Field(int16(f.GetID())) + if !exist { + return false + } + + // deep to next desc + curDesc = f.GetType() + if curDesc == nil { + return false + } + cur = nextFm + + } else if styp == pathTypeIndexL { + + // get element desc + if !curDesc.IsList() { + return false + } + et := curDesc.GetValueType() + if et == nil { + return false + } + + if cur.typ != ftArray { + return false + } + + all := cur.All() + next := cur.all + // iter indexies... + for it.HasNext() { + tok := it.Next() + typ := tok.Type() + // println("token", tok.String()) + if tok.Err() != nil { + return false + } + + if typ == pathTypeIndexR { + break + } + if all || typ == pathTypeElem { + continue + } + if typ == pathTypeAny { + return false + } + if typ != pathTypeLitInt { + return false + } + + // check mask + v := tok.val.Int() + nextFm, exist := cur.Int(v) + if !exist { + return false + } + // NOTICE: always use last elem's fieldmask + next = nextFm + } + + // next fieldmask + curDesc = et + cur = next + + } else if styp == pathTypeMapL { + // get element and key desc + if !curDesc.IsMap() { + return false + } + et := curDesc.GetValueType() + if et == nil { + return false + } + kt := curDesc.GetKeyType() + if kt == nil { + return false + } + + // println("cur.typ::", cur.typ, "cur::", cur.String(curDesc)) + if cur.typ != ftIntMap && cur.typ != ftStrMap { + return false + } + + next := cur.all + // iter indexies... + for it.HasNext() { + tok := it.Next() + typ := tok.Type() + if tok.Err() != nil { + return false + } + // println("token", tok.String()) + + if typ == pathTypeMapR { + break + } + if cur.All() || typ == pathTypeElem { + continue + } + if typ == pathTypeAny { + return false + } + + if typ == pathTypeLitInt { + if cur.typ != ftIntMap { + return false + } + v := tok.val.Int() + nextFm, exist := cur.Int(v) + if !exist { + return false + } + // NOTICE: always use last elem's fieldmask + next = nextFm + } else if typ == pathTypeStr { + if cur.typ != ftStrMap { + return false + } + v := tok.val.Str() + nextFm, exist := cur.Str(v) + if !exist { + return false + } + // NOTICE: always use last elem's fieldmask + next = nextFm + } else { + return false + } + } + + // next fieldmask + curDesc = et + cur = next + } else { + return false + } + } + + return !it.HasNext() +} diff --git a/fieldmask/storage.go b/fieldmask/storage.go new file mode 100644 index 00000000..8224a6db --- /dev/null +++ b/fieldmask/storage.go @@ -0,0 +1,243 @@ +/* + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fieldmask + +import ( + "github.com/cloudwego/thriftgo/thrift_reflection" +) + +type fieldID int32 + +const _MaxFieldIDHead = 127 + +type fieldMap struct { + head [_MaxFieldIDHead + 1]*FieldMask + tail map[fieldID]*FieldMask +} + +func makeFieldMaskMap(st *thrift_reflection.StructDescriptor) fieldMap { + max := 0 + count := 0 + for _, f := range st.GetFields() { + if max < int(f.GetID()) { + max = int(f.GetID()) + count = 0 + } else { + count += 1 + } + } + return fieldMap{ + tail: make(map[fieldID]*FieldMask, count), + } +} + +func (fm *fieldMap) Reset() { + if fm == nil { + return + } + for _, v := range fm.tail { + v.reset() + } + // memclrNoHeapPointers(unsafe.Pointer(&fm.head), 8*(_MaxFieldIDHead+1)) + for _, v := range fm.head { + v.reset() + } +} + +// func (self *fieldMap) Reset() { +// if self == nil { +// return +// } +// self.tail = self.tail[:0] +// } + +func (self *fieldMap) SetIfNotExist(f fieldID, ft fieldMaskType) (s *FieldMask) { + if f <= _MaxFieldIDHead { + s = self.head[f] + if s == nil { + fm := newFieldMask(ft) + self.head[f] = &fm + return &fm + } + + } else { + s = self.tail[f] + if s == nil { + fm := newFieldMask(ft) + self.tail[f] = &fm + return &fm + } + } + if s.typ == 0 { + s.assign(ft) + } + return s +} + +func (self *fieldMap) Get(f fieldID) (ret *FieldMask) { + if f <= _MaxFieldIDHead { + ret = self.head[f] + } else { + ret = self.tail[f] + } + if ret.Exist() { + return ret + } + return nil +} + +// setFieldID ensure a fieldmask slot for f +func (self *FieldMask) setFieldID(f fieldID, st *thrift_reflection.StructDescriptor) *FieldMask { + if self.fdMask == nil { + // println("new fdmask") + m := makeFieldMaskMap(st) + self.fdMask = &m + } + return self.fdMask.SetIfNotExist(fieldID(f), switchFt(st.GetFieldById(int32(f)).GetType())) +} + +// type fieldMaskBitmap []byte + +// const _BucketBit = 8 + +// func (self *fieldMaskBitmap) Set(f fieldID) { +// b := int(f / _BucketBit) +// i := int(f % _BucketBit) +// c := cap(*self) +// if c <= b+1 { +// tmp := make([]byte, len(*self), (c + b + 1)) +// copy(tmp, *self) +// *self = tmp +// } +// if len(*self) <= b { +// *self = (*self)[:b+1] +// } +// (*self)[b] |= byte(1 << i) +// } + +// func (self *fieldMaskBitmap) Get(f fieldID) bool { +// b := int(f / _BucketBit) +// if len(*self) <= b { +// return false +// } +// i := int(f % _BucketBit) +// return ((*self)[b] & byte(1<") + } + if field.GetType().IsMap() { + buf.WriteString("<") + buf.WriteString(field.GetType().GetKeyType().GetName()) + buf.WriteString(",") + buf.WriteString(field.GetType().GetValueType().GetName()) + buf.WriteString(">") + } + buf.WriteString(")\n") + nd := field.GetType() + next, exist := self.Field(int16(field.GetID())) + if exist { + next.print(buf, indent, nd) + } +} + +func (self *FieldMask) printElem(buf *strings.Builder, indent int, id interface{}, desc *thrift_reflection.TypeDescriptor) { + printIndent(buf, indent, "+") + var next *FieldMask + var e bool + switch v := id.(type) { + case int: + buf.WriteString(strconv.Itoa(v)) + next, e = self.Int(v) + case string: + buf.WriteString(v) + next, e = self.Str(v) + } + buf.WriteString("\n") + if e { + next.print(buf, indent, desc) + } +} diff --git a/generator/golang/imports.go b/generator/golang/imports.go index a4b32456..70ebd923 100644 --- a/generator/golang/imports.go +++ b/generator/golang/imports.go @@ -87,6 +87,7 @@ func (im *importManager) init(cu *CodeUtils, ast *parser.Thrift) { "unknown": DefaultUnknownLib, "meta": DefaultMetaLib, "thrift_reflection": ThriftReflectionLib, + "fieldmask": ThriftFieldMaskLib, } for pkg, path := range std { ns.Add(pkg, path) diff --git a/generator/golang/option.go b/generator/golang/option.go index 1d53738b..0bbb4943 100644 --- a/generator/golang/option.go +++ b/generator/golang/option.go @@ -55,6 +55,7 @@ type Features struct { CodeRef bool `code_ref:"Genenerate code ref by given idl-ref.yaml"` KeepCodeRefName bool `keep_code_ref_name:"Genenerate code ref but still keep file name."` TrimIDL bool `trim_idl:"Simplify IDL to the most concise form before generating code."` + WithFieldMask bool `with_field_mask:"Support field-mask for generated code."` } var defaultFeatures = Features{ @@ -85,6 +86,7 @@ var defaultFeatures = Features{ GenerateReflectionInfo: false, EnumAsINT32: false, TrimIDL: false, + WithFieldMask: false, } type param struct { diff --git a/generator/golang/read_write_context.go b/generator/golang/read_write_context.go index c4408d76..c23c2e0f 100644 --- a/generator/golang/read_write_context.go +++ b/generator/golang/read_write_context.go @@ -37,6 +37,8 @@ type ReadWriteContext struct { NeedDecl bool // Whether a declaration of target is needed ids map[string]int // Prefix => local variable index + + FieldMask string } // GenID returns a local variable with the given name as prefix. @@ -55,6 +57,12 @@ func (c *ReadWriteContext) WithDecl() *ReadWriteContext { return c } +// WithDecl claims that the context needs a variable declaration. +func (c *ReadWriteContext) WithFieldMask(fm string) *ReadWriteContext { + c.FieldMask = fm + return c +} + // WithTarget sets the target name. func (c *ReadWriteContext) WithTarget(t string) *ReadWriteContext { c.Target = t @@ -119,5 +127,7 @@ func mkRWCtx(r *Resolver, s *Scope, t *parser.Type, top *ReadWriteContext) (*Rea return nil, err } } + + ctx.FieldMask = "fm" return ctx, nil } diff --git a/generator/golang/templates/init.go b/generator/golang/templates/init.go index 5cb07fdd..a4fa05f4 100644 --- a/generator/golang/templates/init.go +++ b/generator/golang/templates/init.go @@ -27,6 +27,7 @@ func Alternative() map[string][]string { func Templates() []string { return []string{ File, Imports, Constant, Enum, Typedef, + HandleUnknownFields, StructLike, StructLikeDefault, StructLikeRead, diff --git a/generator/golang/templates/reflection/reflection_tpl.go b/generator/golang/templates/reflection/reflection_tpl.go index 0a017afe..eb2331ac 100644 --- a/generator/golang/templates/reflection/reflection_tpl.go +++ b/generator/golang/templates/reflection/reflection_tpl.go @@ -77,22 +77,51 @@ func GetFileDescriptorFor{{ToCamel $IDLName}}() *thrift_reflection.FileDescripto func (p *{{.GoName}}) GetDescriptor() *thrift_reflection.StructDescriptor{ return file_{{$IDLName}}_thrift.GetStructDescriptor("{{.Name}}") } + +func (p *{{.GoName}}) GetTypeDescriptor() *thrift_reflection.TypeDescriptor{ + ret := thrift_reflection.NewTypeDescriptor() + ret.Filepath = file_{{$IDLName}}_thrift.Filepath + ret.Name = "{{.Name}}" + return ret +} {{- end}} {{- range .Enums}} func (p {{.GoName}}) GetDescriptor() *thrift_reflection.EnumDescriptor{ return file_{{$IDLName}}_thrift.GetEnumDescriptor("{{.Name}}") } + +func (p *{{.GoName}}) GetTypeDescriptor() *thrift_reflection.TypeDescriptor{ + ret := thrift_reflection.NewTypeDescriptor() + ret.Filepath = file_{{$IDLName}}_thrift.Filepath + ret.Name = "{{.Name}}" + return ret +} {{- end}} {{- range .Unions}} func (p *{{.GoName}}) GetDescriptor() *thrift_reflection.StructDescriptor{ return file_{{$IDLName}}_thrift.GetUnionDescriptor("{{.Name}}") } + +func (p *{{.GoName}}) GetTypeDescriptor() *thrift_reflection.TypeDescriptor{ + ret := thrift_reflection.NewTypeDescriptor() + ret.Filepath = file_{{$IDLName}}_thrift.Filepath + ret.Name = "{{.Name}}" + return ret +} {{- end}} {{- range .Exceptions}} func (p *{{.GoName}}) GetDescriptor() *thrift_reflection.StructDescriptor{ return file_{{$IDLName}}_thrift.GetExceptionDescriptor("{{.Name}}") } + +func (p *{{.GoName}}) GetTypeDescriptor() *thrift_reflection.TypeDescriptor{ + ret := thrift_reflection.NewTypeDescriptor() + ret.Filepath = file_{{$IDLName}}_thrift.Filepath + ret.Name = "{{.Name}}" + return ret +} {{- end}} + {{- InsertionPoint "eof"}} ` diff --git a/generator/golang/templates/struct.go b/generator/golang/templates/struct.go index d1fe3481..6b02c23c 100644 --- a/generator/golang/templates/struct.go +++ b/generator/golang/templates/struct.go @@ -32,6 +32,10 @@ type {{$TypeName}} struct { {{- UseStdLibrary "unknown"}} _unknownFields unknown.Fields {{- end}} + {{- if Features.WithFieldMask}} + {{- UseStdLibrary "fieldmask"}} + _fieldmask *fieldmask.FieldMask + {{- end}} } {{- if Features.GenerateTypeMeta}} @@ -77,6 +81,15 @@ func (p *{{$TypeName}}) CarryingUnknownFields() bool { } {{end}}{{/* if Features.KeepUnknownFields */}} +{{if Features.WithFieldMask}} +func (p *{{$TypeName}}) GetFieldMask() *fieldmask.FieldMask { + return p._fieldmask +} +func (p *{{$TypeName}}) SetFieldMask(fm *fieldmask.FieldMask) { + p._fieldmask = fm +} +{{end}}{{/* if Features.WithFieldMask */}} + var fieldIDToName_{{$TypeName}} = map[int16]string{ {{- range .Fields}} {{.ID}}: "{{.Name}}", @@ -156,37 +169,35 @@ func (p *{{$TypeName}}) Read(iprot thrift.TProtocol) (err error) { {{if or (gt (len .Fields) 0) Features.KeepUnknownFields}} switch fieldId { {{- range .Fields}} + {{- $isBaseVal := .Type | IsBaseType}} case {{.ID}}: if fieldTypeId == thrift.{{.Type | GetTypeIDConstant }} { - if err = p.{{.Reader}}(iprot); err != nil { + {{- if Features.WithFieldMask}} + if {{if $isBaseVal}}_{{else}}nfm{{end}}, ex := p._fieldmask.Field(fieldId); ex { + {{- end}} + if err = p.{{.Reader}}(iprot{{if and Features.WithFieldMask (not $isBaseVal)}}, nfm{{end}}); err != nil { goto ReadFieldError } {{- if .Requiredness.IsRequired}} isset{{.GoName}} = true {{- end}} - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError + break + {{- if Features.WithFieldMask}} } + {{- end}} + } + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError } {{- end}}{{/* range .Fields */}} default: - {{- if Features.KeepUnknownFields}} - if err = p._unknownFields.Append(iprot, name, fieldTypeId, fieldId); err != nil { - goto UnknownFieldsAppendError - } - {{- else}} - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - {{- end}}{{/* if Features.KeepUnknownFields */}} + {{- template "HandleUnknownFields"}} } {{- else -}} if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldTypeError } {{- end}}{{/* if len(.Fields) > 0 */}} - if err = iprot.ReadFieldEnd(); err != nil { goto ReadFieldEndError } @@ -239,6 +250,20 @@ RequiredFieldNotSetError: {{- end}}{{/* define "StructLikeRead" */}} ` +var HandleUnknownFields = ` +{{define "HandleUnknownFields"}} +{{- if Features.KeepUnknownFields}} +if err = p._unknownFields.Append(iprot, name, fieldTypeId, fieldId); err != nil { + goto UnknownFieldsAppendError +} +{{- else}} +if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError +} +{{- end}}{{/* if Features.KeepUnknownFields */}} +{{- end}}{{/* define "HandleUnknownFields" */}} +` + // StructLikeReadField . var StructLikeReadField = ` {{define "StructLikeReadField"}} @@ -246,8 +271,9 @@ var StructLikeReadField = ` {{- $TypeName := .GoName}} {{- range .Fields}} {{$FieldName := .GoName}} -func (p *{{$TypeName}}) {{.Reader}}(iprot thrift.TProtocol) error { - {{- $ctx := MkRWCtx .}} +{{- $isBaseVal := .Type | IsBaseType -}} +func (p *{{$TypeName}}) {{.Reader}}(iprot thrift.TProtocol{{if and Features.WithFieldMask (not $isBaseVal)}}, fm *fieldmask.FieldMask{{end}}) error { + {{$ctx := (MkRWCtx .).WithFieldMask "fm"}} {{- template "FieldRead" $ctx}} return nil } @@ -275,12 +301,19 @@ func (p *{{$TypeName}}) Write(oprot thrift.TProtocol) (err error) { } if p != nil { {{- range .Fields}} - if err = p.{{.Writer}}(oprot); err != nil { + {{- $isBaseVal := .Type | IsBaseType}} + {{- if Features.WithFieldMask}} + if {{if $isBaseVal}}_{{else}}nfm{{end}}, ex := p._fieldmask.Field({{.ID}}); ex { + {{- end}} + if err = p.{{.Writer}}(oprot{{if and Features.WithFieldMask (not $isBaseVal)}}, nfm{{end}}); err != nil { fieldId = {{.ID}} goto WriteFieldError } + {{- if Features.WithFieldMask}} + } {{- end}} - {{if Features.KeepUnknownFields}} + {{- end}}{{/* range .Fields */}} + {{- if Features.KeepUnknownFields}} if err = p._unknownFields.Write(oprot); err != nil { goto UnknownFieldsWriteError } @@ -324,14 +357,15 @@ var StructLikeWriteField = ` {{- $FieldName := .GoName}} {{- $IsSetName := .IsSetter}} {{- $TypeID := .Type | GetTypeIDConstant }} -func (p *{{$TypeName}}) {{.Writer}}(oprot thrift.TProtocol) (err error) { +{{- $isBaseVal := .Type | IsBaseType -}} +func (p *{{$TypeName}}) {{.Writer}}(oprot thrift.TProtocol{{if and Features.WithFieldMask (not $isBaseVal)}}, fm *fieldmask.FieldMask{{end}}) (err error) { {{- if .Requiredness.IsOptional}} if p.{{$IsSetName}}() { {{- end}} if err = oprot.WriteFieldBegin("{{.Name}}", thrift.{{$TypeID}}, {{.ID}}); err != nil { goto WriteFieldBeginError } - {{- $ctx := MkRWCtx .}} + {{- $ctx := (MkRWCtx .).WithFieldMask "fm"}} {{- template "FieldWrite" $ctx}} if err = oprot.WriteFieldEnd(); err != nil { goto WriteFieldEndError @@ -465,6 +499,7 @@ var FieldRead = ` var FieldReadStructLike = ` {{define "FieldReadStructLike"}} {{- .Target}} {{if .NeedDecl}}:{{end}}= {{.TypeName.Deref.NewFunc}}() + {{if Features.WithFieldMask}}{{.Target}}.SetFieldMask({{.FieldMask}}){{end}} if err := {{.Target}}.Read(iprot); err != nil { return err } @@ -515,6 +550,10 @@ var FieldReadContainer = ` // FieldReadMap . var FieldReadMap = ` {{define "FieldReadMap"}} +{{- $isIntKey := .KeyCtx.Type | IsIntType -}} +{{- $isStrKey := .KeyCtx.Type | IsStrType -}} +{{- $isBaseVal := .ValCtx.Type | IsBaseType -}} +{{- $curFieldMask := .FieldMask -}} _, _, size, err := iprot.ReadMapBegin() if err != nil { return err @@ -524,9 +563,30 @@ var FieldReadMap = ` {{- $key := .GenID "_key"}} {{- $ctx := .KeyCtx.WithDecl.WithTarget $key}} {{- template "FieldRead" $ctx}} + {{- if Features.WithFieldMask}} + {{- if $isIntKey}} + {{- $curFieldMask = "nfm"}} + if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(int({{$key}})); !ex { + if err := iprot.Skip(thrift.{{.ValCtx.Type | GetTypeIDConstant}}); err != nil { + return err + } + continue + } else { + {{- else if $isStrKey}} + {{- $curFieldMask = "nfm"}} + if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Str(string({{$key}})); !ex { + if err := iprot.Skip(thrift.{{.ValCtx.Type | GetTypeIDConstant}}); err != nil { + return err + } + continue + } else { + {{- else}} + {{$curFieldMask}} = nil + {{- end}} + {{- end}}{{/* end WithFieldMask */}} {{/* line break */}} {{- $val := .GenID "_val"}} - {{- $ctx := .ValCtx.WithDecl.WithTarget $val}} + {{- $ctx := (.ValCtx.WithDecl.WithTarget $val).WithFieldMask $curFieldMask}} {{- template "FieldRead" $ctx}} {{if and .ValCtx.Type.Category.IsStructLike Features.ValueTypeForSIC}} @@ -534,6 +594,9 @@ var FieldReadMap = ` {{end}} {{.Target}}[{{$key}}] = {{$val}} + {{- if and Features.WithFieldMask (or $isIntKey $isStrKey)}} + } + {{- end}} } if err := iprot.ReadMapEnd(); err != nil { return err @@ -544,6 +607,8 @@ var FieldReadMap = ` // FieldReadSet . var FieldReadSet = ` {{define "FieldReadSet"}} +{{- $isBaseVal := .ValCtx.Type | IsBaseType -}} +{{- $curFieldMask := .FieldMask -}} _, size, err := iprot.ReadSetBegin() if err != nil { return err @@ -551,14 +616,26 @@ var FieldReadSet = ` {{.Target}} {{if .NeedDecl}}:{{end}}= make({{.TypeName}}, 0, size) for i := 0; i < size; i++ { {{- $val := .GenID "_elem"}} - {{- $ctx := .ValCtx.WithDecl.WithTarget $val}} - {{- template "FieldRead" $ctx}} + {{- if Features.WithFieldMask}} + {{- $curFieldMask = "nfm"}} + if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(i); !ex { + if err := iprot.Skip(thrift.{{.ValCtx.Type | GetTypeIDConstant}}); err != nil { + return err + } + continue + } else { + {{- end}} + {{- $ctx := (.ValCtx.WithDecl.WithTarget $val).WithFieldMask $curFieldMask}} + {{template "FieldRead" $ctx}} {{if and .ValCtx.Type.Category.IsStructLike Features.ValueTypeForSIC}} {{$val = printf "*%s" $val}} {{end}} {{.Target}} = append({{.Target}}, {{$val}}) + {{- if Features.WithFieldMask}} + } + {{- end}} } if err := iprot.ReadSetEnd(); err != nil { return err @@ -569,6 +646,8 @@ var FieldReadSet = ` // FieldReadList . var FieldReadList = ` {{define "FieldReadList"}} +{{- $isBaseVal := .ValCtx.Type | IsBaseType -}} +{{- $curFieldMask := .FieldMask -}} _, size, err := iprot.ReadListBegin() if err != nil { return err @@ -576,14 +655,26 @@ var FieldReadList = ` {{.Target}} {{if .NeedDecl}}:{{end}}= make({{.TypeName}}, 0, size) for i := 0; i < size; i++ { {{- $val := .GenID "_elem"}} - {{- $ctx := .ValCtx.WithDecl.WithTarget $val}} - {{- template "FieldRead" $ctx}} + {{- if Features.WithFieldMask}} + {{- $curFieldMask = "nfm"}} + if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(i); !ex { + if err := iprot.Skip(thrift.{{.ValCtx.Type | GetTypeIDConstant}}); err != nil { + return err + } + continue + } else { + {{- end}} + {{- $ctx := (.ValCtx.WithDecl.WithTarget $val).WithFieldMask $curFieldMask}} + {{template "FieldRead" $ctx}} {{if and .ValCtx.Type.Category.IsStructLike Features.ValueTypeForSIC}} {{$val = printf "*%s" $val}} {{end}} {{.Target}} = append({{.Target}}, {{$val}}) + {{- if Features.WithFieldMask}} + } + {{- end}} } if err := iprot.ReadListEnd(); err != nil { return err @@ -607,6 +698,11 @@ var FieldWrite = ` // FieldWriteStructLike . var FieldWriteStructLike = ` {{define "FieldWriteStructLike"}} + {{- if Features.WithFieldMask}} + if {{.Target}} != nil { + {{.Target}}.SetFieldMask({{.FieldMask}}) + } + {{- end}} if err := {{.Target}}.Write(oprot); err != nil { return err } @@ -642,17 +738,70 @@ var FieldWriteContainer = ` // FieldWriteMap . var FieldWriteMap = ` {{define "FieldWriteMap"}} +{{- $isIntKey := .KeyCtx.Type | IsIntType -}} +{{- $isStrKey := .KeyCtx.Type | IsStrType -}} +{{- $isBaseVal := .ValCtx.Type | IsBaseType -}} +{{- $curFieldMask := .FieldMask -}} + {{- if and Features.WithFieldMask (or $isIntKey $isStrKey) }} + if !{{.FieldMask}}.All() { + l := len({{.Target}}) + for k := range {{.Target}} { + {{- if $isIntKey}} + if _, ex := {{.FieldMask}}.Int(int(k)); !ex { + l-- + } + {{- else if $isStrKey}} + if _, ex := {{.FieldMask}}.Str(string(k)); !ex { + l-- + } + {{- end}} + } + if err := oprot.WriteMapBegin(thrift. + {{- .KeyCtx.Type | GetTypeIDConstant -}} + , thrift.{{- .ValCtx.Type | GetTypeIDConstant -}} + , l); err != nil { + return err + } + } else { + if err := oprot.WriteMapBegin(thrift. + {{- .KeyCtx.Type | GetTypeIDConstant -}} + , thrift.{{- .ValCtx.Type | GetTypeIDConstant -}} + , len({{.Target}})); err != nil { + return err + } + } + {{- else}} if err := oprot.WriteMapBegin(thrift. {{- .KeyCtx.Type | GetTypeIDConstant -}} , thrift.{{- .ValCtx.Type | GetTypeIDConstant -}} , len({{.Target}})); err != nil { return err } - for k, v := range {{.Target}}{ - {{$ctx := .KeyCtx.WithTarget "k"}} + {{- end}} + for k, v := range {{.Target}} { + {{- if Features.WithFieldMask}} + {{- if $isIntKey}} + {{- $curFieldMask = "nfm"}} + if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(int(k)); !ex { + continue + } else { + {{- else if $isStrKey}} + {{- $curFieldMask = "nfm"}} + ks := string(k) + if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Str(ks); !ex { + continue + } else { + {{- else}} + {{$curFieldMask}} = nil + {{- end}} + {{- end}}{{/* end Features.WithFieldMask */}} + {{- $ctx := .KeyCtx.WithTarget "k" -}} {{- template "FieldWrite" $ctx}} - {{$ctx := .ValCtx.WithTarget "v"}} + {{- $ctx := (.ValCtx.WithTarget "v").WithFieldMask $curFieldMask -}} {{- template "FieldWrite" $ctx}} + {{- if and Features.WithFieldMask (or $isIntKey $isStrKey)}} + } + {{- end}} } if err := oprot.WriteMapEnd(); err != nil { return err @@ -663,11 +812,35 @@ var FieldWriteMap = ` // FieldWriteSet . var FieldWriteSet = ` {{define "FieldWriteSet"}} +{{- $isBaseVal := .ValCtx.Type | IsBaseType -}} +{{- $curFieldMask := .FieldMask -}} + {{- if Features.WithFieldMask}} + if !{{.FieldMask}}.All() { + l := len({{.Target}}) + for i:=0; i < l; i++ { + if _, ex := {{.FieldMask}}.Int(i); !ex { + l-- + } + } + if err := oprot.WriteSetBegin(thrift. + {{- .ValCtx.Type | GetTypeIDConstant -}} + , l); err != nil { + return err + } + } else { + if err := oprot.WriteSetBegin(thrift. + {{- .ValCtx.Type | GetTypeIDConstant -}} + , len({{.Target}})); err != nil { + return err + } + } + {{- else}} if err := oprot.WriteSetBegin(thrift. {{- .ValCtx.Type | GetTypeIDConstant -}} , len({{.Target}})); err != nil { return err } + {{- end}} {{- if Features.ValidateSet}} {{- $ctx := (.ValCtx.WithTarget "tgt").WithSource "src"}} for i := 0; i < len({{.Target}}); i++ { @@ -687,9 +860,18 @@ var FieldWriteSet = ` } } {{- end}} - for _, v := range {{.Target}} { - {{- $ctx := .ValCtx.WithTarget "v"}} + for {{if Features.WithFieldMask}}i{{else}}_{{end}}, v := range {{.Target}} { + {{- if Features.WithFieldMask}} + {{- $curFieldMask = "nfm"}} + if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(i); !ex { + continue + } else { + {{- end}} + {{- $ctx := (.ValCtx.WithTarget "v").WithFieldMask $curFieldMask -}} {{- template "FieldWrite" $ctx}} + {{- if Features.WithFieldMask}} + } + {{- end}} } if err := oprot.WriteSetEnd(); err != nil { return err @@ -700,14 +882,47 @@ var FieldWriteSet = ` // FieldWriteList . var FieldWriteList = ` {{define "FieldWriteList"}} +{{- $isBaseVal := .ValCtx.Type | IsBaseType -}} +{{- $curFieldMask := .FieldMask -}} + {{- if Features.WithFieldMask}} + if !{{.FieldMask}}.All() { + l := len({{.Target}}) + for i:=0; i < l; i++ { + if _, ex := {{.FieldMask}}.Int(i); !ex { + l-- + } + } + if err := oprot.WriteListBegin(thrift. + {{- .ValCtx.Type | GetTypeIDConstant -}} + , l); err != nil { + return err + } + } else { if err := oprot.WriteListBegin(thrift. {{- .ValCtx.Type | GetTypeIDConstant -}} , len({{.Target}})); err != nil { return err } - for _, v := range {{.Target}} { - {{- $ctx := .ValCtx.WithTarget "v"}} + } + {{- else}} + if err := oprot.WriteListBegin(thrift. + {{- .ValCtx.Type | GetTypeIDConstant -}} + , len({{.Target}})); err != nil { + return err + } + {{- end}} + for {{if Features.WithFieldMask}}i{{else}}_{{end}}, v := range {{.Target}} { + {{- if Features.WithFieldMask}} + {{- $curFieldMask = "nfm"}} + if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(i); !ex { + continue + } else { + {{- end}} + {{- $ctx := (.ValCtx.WithTarget "v").WithFieldMask $curFieldMask -}} {{- template "FieldWrite" $ctx}} + {{- if Features.WithFieldMask}} + } + {{- end}} } if err := oprot.WriteListEnd(); err != nil { return err diff --git a/generator/golang/thrift.go b/generator/golang/thrift.go index 0e49cf50..c7b8aea7 100644 --- a/generator/golang/thrift.go +++ b/generator/golang/thrift.go @@ -51,6 +51,26 @@ func IsBaseType(t *parser.Type) bool { return false } +// IsBaseType determines whether the given type is a base type. +func IsIntType(t *parser.Type) bool { + switch t.Category { + case parser.Category_Byte, parser.Category_I16, parser.Category_I32, parser.Category_I64, parser.Category_Enum: + return true + default: + return false + } +} + +// IsBaseType determines whether the given type is a base type. +func IsStrType(t *parser.Type) bool { + switch t.Category { + case parser.Category_String, parser.Category_Binary: + return true + default: + return false + } +} + // NeedRedirect deterimines whether the given field should result in a pointer type. // Condition: struct-like || (optional non-binary base type without default vlaue). func NeedRedirect(f *parser.Field) bool { diff --git a/generator/golang/util.go b/generator/golang/util.go index 6f1a6e68..327ed54e 100644 --- a/generator/golang/util.go +++ b/generator/golang/util.go @@ -43,6 +43,7 @@ const ( DefaultUnknownLib = "github.com/cloudwego/thriftgo/generator/golang/extension/unknown" DefaultMetaLib = "github.com/cloudwego/thriftgo/generator/golang/extension/meta" ThriftReflectionLib = "github.com/cloudwego/thriftgo/thrift_reflection" + ThriftFieldMaskLib = "github.com/cloudwego/thriftgo/fieldmask" ThriftOptionLib = "github.com/cloudwego/thriftgo/option" defaultTemplate = "default" ) @@ -383,6 +384,8 @@ func (cu *CodeUtils) BuildFuncMap() template.FuncMap { "IsFixedLengthType": IsFixedLengthType, "SupportIsSet": SupportIsSet, "GetTypeIDConstant": GetTypeIDConstant, + "IsIntType": IsIntType, + "IsStrType": IsStrType, "UseStdLibrary": func(libs ...string) string { cu.rootScope.imports.UseStdLibrary(libs...) return "" diff --git a/go.mod b/go.mod index dba1f689..d148f8af 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.13 require ( github.com/apache/thrift v0.13.0 - github.com/dlclark/regexp2 v1.10.0 // indirect + github.com/dlclark/regexp2 v1.10.0 golang.org/x/text v0.6.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/internal/test_util/generator.go b/internal/test_util/generator.go new file mode 100644 index 00000000..b1b31d67 --- /dev/null +++ b/internal/test_util/generator.go @@ -0,0 +1,73 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test_util + +import ( + "github.com/cloudwego/thriftgo/generator" + "github.com/cloudwego/thriftgo/generator/backend" + "github.com/cloudwego/thriftgo/generator/golang" + "github.com/cloudwego/thriftgo/parser" + "github.com/cloudwego/thriftgo/plugin" + "github.com/cloudwego/thriftgo/semantic" +) + +func GenerateGolang(idl string, output string, genOpts []plugin.Option, pluginOpts []*plugin.Desc) (generator.Generator, *plugin.Response) { + ast, err := parser.ParseFile(idl, nil, true) + if err != nil { + panic(err) + } + + checker := semantic.NewChecker(semantic.Options{FixWarnings: true}) + resolver, ok := checker.(interface { + ResolveSymbols(t *parser.Thrift) error + }) + if ok { + if err = resolver.ResolveSymbols(ast); err != nil { + panic(err) + } + } + + var gen generator.Generator + if err := gen.RegisterBackend(new(golang.GoBackend)); err != nil { + panic(err) + } + + log := backend.LogFunc{ + Info: func(v ...interface{}) {}, + Warn: func(v ...interface{}) {}, + MultiWarn: func(warns []string) {}, + } + out := &generator.LangSpec{ + Language: "go", + Options: genOpts, + UsedPlugins: pluginOpts, + } + req := &plugin.Request{ + Language: out.Language, + Version: "?", + OutputPath: output, + Recursive: true, + AST: ast, + } + arg := &generator.Arguments{Out: out, Req: req, Log: log} + + res := gen.Generate(arg) + + if v := res.GetError(); v != "" { + panic(v) + } + + return gen, res +} diff --git a/test/golang/fieldmask/a.thrift b/test/golang/fieldmask/a.thrift new file mode 100644 index 00000000..5016377e --- /dev/null +++ b/test/golang/fieldmask/a.thrift @@ -0,0 +1,90 @@ +#! /bin/bash -e + +# Copyright 2022 CloudWeGo Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +namespace go base + +struct TrafficEnv { + 0: string Name = "", + 1: bool Open = false, + 2: string Env = "", + 256: i64 Code, +} + +struct Base { + 0: string Addr = "", + 1: string LogID = "", + 2: string Caller = "", + 5: optional TrafficEnv TrafficEnv, + 255: optional ExtraInfo Extra, + 256: MetaInfo Meta, +} + +struct ExtraInfo { + 1: map F1 + 2: map F2, + 3: list F3 + 4: set F4, + 5: map F5 + 6: map F6 + 7: map> F7 + 8: map> F8 + 9: map>> F9 + 10: map F10 +} + +struct MetaInfo { + 1: map IntMap, + 2: map StrMap, + 3: list List, + 4: set Set, + 11: map> MapList + 12: list>> ListMapList + 255: Base Base, +} + +typedef Val Key + +struct Val { + 1: string id +} + +typedef double Float + +typedef i64 Int + +typedef string Str + +enum Ex { + A = 1, + B = 2, + C = 3 +} + +struct BaseResp { + 1: string StatusMessage = "", + 2: i32 StatusCode = 0, + 3: optional map Extra, + + 4: map F1 + 5: map F2, + 6: list F3 + 7: set F4, + 8: map F5 + 9: map F6 + 10: map F7 + 11: map> F8 + 12: list>> F9 +} diff --git a/test/golang/fieldmask/go.mod b/test/golang/fieldmask/go.mod new file mode 100644 index 00000000..f5535937 --- /dev/null +++ b/test/golang/fieldmask/go.mod @@ -0,0 +1,21 @@ +module github.com/cloudwego/thriftgo/test/golang/fieldmask + +go 1.20 + +replace github.com/cloudwego/thriftgo => ../../../. + +replace github.com/apache/thrift => github.com/apache/thrift v0.13.0 + +require ( + github.com/apache/thrift v0.13.0 + github.com/cloudwego/thriftgo v0.0.0-00010101000000-000000000000 + github.com/stretchr/testify v1.8.4 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/text v0.6.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/test/golang/fieldmask/go.sum b/test/golang/fieldmask/go.sum new file mode 100644 index 00000000..72c5058f --- /dev/null +++ b/test/golang/fieldmask/go.sum @@ -0,0 +1,39 @@ +github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= +github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= +golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/test/golang/fieldmask/main_test.go b/test/golang/fieldmask/main_test.go new file mode 100644 index 00000000..9132762f --- /dev/null +++ b/test/golang/fieldmask/main_test.go @@ -0,0 +1,292 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fieldmask + +import ( + "testing" + + "github.com/apache/thrift/lib/go/thrift" + "github.com/cloudwego/thriftgo/fieldmask" + "github.com/cloudwego/thriftgo/internal/test_util" + "github.com/cloudwego/thriftgo/plugin" + nbase "github.com/cloudwego/thriftgo/test/golang/fieldmask/gen-new/base" + obase "github.com/cloudwego/thriftgo/test/golang/fieldmask/gen-old/base" + "github.com/stretchr/testify/require" +) + +func TestGen(t *testing.T) { + g, r := test_util.GenerateGolang("a.thrift", "gen-old/", nil, nil) + if err := g.Persist(r); err != nil { + panic(err) + } + g, r = test_util.GenerateGolang("a.thrift", "gen-new/", []plugin.Option{ + {"with_field_mask", ""}, + {"with_reflection", ""}, + }, nil) + if err := g.Persist(r); err != nil { + panic(err) + } +} + +func SampleNewBase() *nbase.Base { + obj := nbase.NewBase() + obj.Addr = "abcd" + obj.Caller = "abcd" + obj.LogID = "abcd" + obj.Meta = nbase.NewMetaInfo() + obj.Meta.StrMap = map[string]*nbase.Val{ + "abcd": nbase.NewVal(), + "1234": nbase.NewVal(), + } + obj.Meta.IntMap = map[int64]*nbase.Val{ + 1: nbase.NewVal(), + 2: nbase.NewVal(), + } + v0 := nbase.NewVal() + v0.ID = "a" + v1 := nbase.NewVal() + v1.ID = "b" + obj.Meta.List = []*nbase.Val{v0, v1} + obj.Meta.Set = []*nbase.Val{v0, v1} + obj.Extra = nbase.NewExtraInfo() + obj.TrafficEnv = nbase.NewTrafficEnv() + obj.TrafficEnv.Code = 1 + obj.TrafficEnv.Env = "abcd" + obj.TrafficEnv.Name = "abcd" + obj.TrafficEnv.Open = true + return obj +} + +func SampleOldBase() *obase.Base { + obj := obase.NewBase() + obj.Addr = "abcd" + obj.Caller = "abcd" + obj.LogID = "abcd" + obj.Meta = obase.NewMetaInfo() + obj.Meta.StrMap = map[string]*obase.Val{ + "abcd": obase.NewVal(), + "1234": obase.NewVal(), + } + obj.Meta.IntMap = map[int64]*obase.Val{ + 1: obase.NewVal(), + 2: obase.NewVal(), + } + v0 := obase.NewVal() + v0.ID = "a" + v1 := obase.NewVal() + v1.ID = "b" + obj.Meta.List = []*obase.Val{v0, v1} + obj.Meta.Set = []*obase.Val{v0, v1} + obj.Extra = obase.NewExtraInfo() + obj.TrafficEnv = obase.NewTrafficEnv() + obj.TrafficEnv.Code = 1 + obj.TrafficEnv.Env = "abcd" + obj.TrafficEnv.Name = "abcd" + obj.TrafficEnv.Open = true + return obj +} + +func BenchmarkWriteWithFieldMask(b *testing.B) { + b.Run("old", func(b *testing.B) { + obj := SampleOldBase() + buf := thrift.NewTMemoryBufferLen(1024) + t := thrift.NewTBinaryProtocol(buf, true, true) + + for i := 0; i < b.N; i++ { + if err := obj.Write(t); err != nil { + b.Fatal(err) + } + buf.Reset() + } + }) + + b.Run("new", func(b *testing.B) { + obj := SampleNewBase() + buf := thrift.NewTMemoryBufferLen(1024) + t := thrift.NewTBinaryProtocol(buf, true, true) + + for i := 0; i < b.N; i++ { + if err := obj.Write(t); err != nil { + b.Fatal(err) + } + buf.Reset() + } + }) + + b.Run("new-mask-half", func(b *testing.B) { + obj := SampleNewBase() + buf := thrift.NewTMemoryBufferLen(1024) + t := thrift.NewTBinaryProtocol(buf, true, true) + + fm, err := fieldmask.GetFieldMask(obj.GetTypeDescriptor(), "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.Meta.IntMap{1}", "$.Meta.StrMap{\"1234\"}", "$.Meta.List[1]", "$.Meta.Set[1]") + if err != nil { + b.Fatal(err) + } + for i := 0; i < b.N; i++ { + obj.SetFieldMask(fm) + if err := obj.Write(t); err != nil { + b.Fatal(err) + } + buf.Reset() + } + fm.Recycle() + }) +} + +func BenchmarkReadWithFieldMask(b *testing.B) { + b.Run("old", func(b *testing.B) { + obj := SampleOldBase() + buf := thrift.NewTMemoryBufferLen(1024) + t := thrift.NewTBinaryProtocol(buf, true, true) + if err := obj.Write(t); err != nil { + b.Fatal(err) + } + data := []byte(buf.String()) + obj = obase.NewBase() + + for i := 0; i < b.N; i++ { + buf.Reset() + buf.Write(data) + if err := obj.Read(t); err != nil { + b.Fatal(err) + } + } + }) + + b.Run("new", func(b *testing.B) { + obj := SampleNewBase() + buf := thrift.NewTMemoryBufferLen(1024) + t := thrift.NewTBinaryProtocol(buf, true, true) + if err := obj.Write(t); err != nil { + b.Fatal(err) + } + data := []byte(buf.String()) + obj = nbase.NewBase() + + for i := 0; i < b.N; i++ { + buf.Reset() + buf.Write(data) + if err := obj.Read(t); err != nil { + b.Fatal(err) + } + } + }) + + b.Run("new-mask-half", func(b *testing.B) { + obj := SampleNewBase() + buf := thrift.NewTMemoryBufferLen(1024) + t := thrift.NewTBinaryProtocol(buf, true, true) + if err := obj.Write(t); err != nil { + b.Fatal(err) + } + data := []byte(buf.String()) + obj = nbase.NewBase() + + fm, err := fieldmask.GetFieldMask(obj.GetTypeDescriptor(), "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.Meta.IntMap{1}", "$.Meta.StrMap{\"1234\"}", "$.Meta.List[1]", "$.Meta.Set[1]") + if err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + buf.Reset() + buf.Write(data) + obj.SetFieldMask(fm) + if err := obj.Read(t); err != nil { + b.Fatal(err) + } + } + + fm.Recycle() + }) +} + +func TestFieldmaskWrite(t *testing.T) { + obj := SampleNewBase() + buf := thrift.NewTMemoryBufferLen(1024) + prot := thrift.NewTBinaryProtocol(buf, true, true) + + fm, err := fieldmask.GetFieldMask(obj.GetTypeDescriptor(), + "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.Meta.IntMap{1}", "$.Meta.StrMap{\"1234\"}", "$.Meta.List[1]", "$.Meta.Set[1]") + if err != nil { + t.Fatal(err) + } + obj.SetFieldMask(fm) + if err := obj.Write(prot); err != nil { + t.Fatal(err) + } + + obj2 := nbase.NewBase() + err = obj2.Read(prot) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, obj.Addr, obj2.Addr) + require.Equal(t, obj.LogID, obj2.LogID) + require.Equal(t, "", obj2.Caller) + require.Equal(t, "", obj2.TrafficEnv.Name) + require.Equal(t, false, obj2.TrafficEnv.Open) + require.Equal(t, "", obj2.TrafficEnv.Env) + require.Equal(t, obj.TrafficEnv.Code, obj2.TrafficEnv.Code) + require.Equal(t, obj.Meta.IntMap[1].ID, obj2.Meta.IntMap[1].ID) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.IntMap[0]) + require.Equal(t, obj.Meta.StrMap["1234"].ID, obj2.Meta.StrMap["1234"].ID) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.StrMap["abcd"]) + require.Equal(t, "b", obj2.Meta.List[0].ID) + require.Equal(t, 1, len(obj2.Meta.List)) + require.Equal(t, "b", obj2.Meta.Set[0].ID) + require.Equal(t, 1, len(obj2.Meta.Set)) + fm.Recycle() +} + +func TestFieldmaskRead(t *testing.T) { + obj := SampleNewBase() + buf := thrift.NewTMemoryBufferLen(1024) + prot := thrift.NewTBinaryProtocol(buf, true, true) + + fm, err := fieldmask.GetFieldMask(obj.GetTypeDescriptor(), + "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.Meta.IntMap{1}", "$.Meta.StrMap{\"1234\"}", "$.Meta.List[1]", "$.Meta.Set[1]") + if err != nil { + t.Fatal(err) + } + + if err := obj.Write(prot); err != nil { + t.Fatal(err) + } + + obj2 := nbase.NewBase() + obj2.SetFieldMask(fm) + err = obj2.Read(prot) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, obj.Addr, obj2.Addr) + require.Equal(t, obj.LogID, obj2.LogID) + require.Equal(t, "", obj2.Caller) + require.Equal(t, "", obj2.TrafficEnv.Name) + require.Equal(t, false, obj2.TrafficEnv.Open) + require.Equal(t, "", obj2.TrafficEnv.Env) + require.Equal(t, obj.TrafficEnv.Code, obj2.TrafficEnv.Code) + require.Equal(t, obj.Meta.IntMap[1].ID, obj2.Meta.IntMap[1].ID) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.IntMap[0]) + require.Equal(t, obj.Meta.StrMap["1234"].ID, obj2.Meta.StrMap["1234"].ID) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.StrMap["abcd"]) + require.Equal(t, "b", obj2.Meta.List[0].ID) + require.Equal(t, 1, len(obj2.Meta.List)) + require.Equal(t, "b", obj2.Meta.Set[0].ID) + require.Equal(t, 1, len(obj2.Meta.Set)) + fm.Recycle() +} diff --git a/test/golang/fieldmask/run_test.sh b/test/golang/fieldmask/run_test.sh new file mode 100755 index 00000000..5de97a56 --- /dev/null +++ b/test/golang/fieldmask/run_test.sh @@ -0,0 +1,37 @@ +#! /bin/bash -e + +# Copyright 2022 CloudWeGo Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +generate () { + xxx=$1 + out=gen-${xxx} + opt="go:package_prefix=example.com/test/${out}" + idl=a.thrift + if [ -d ${out} ]; then + rm -rf ${out} + fi + mkdir -p $out + + if [[ $xxx == new ]]; then + opt=$opt,with_field_mask,with_reflection + fi + echo "thriftgo -g $opt -o ${out} $idl" + thriftgo -g "$opt" -o ${out} $idl +} + +generate old +generate new +go mod tidy +go test -v ./... diff --git a/thrift_reflection/descriptor-extend.go b/thrift_reflection/descriptor-extend.go index 4094635c..21e5d908 100644 --- a/thrift_reflection/descriptor-extend.go +++ b/thrift_reflection/descriptor-extend.go @@ -388,12 +388,23 @@ func (td *TypeDescriptor) IsTypedef() bool { if ok { return cacheType == "typedef" } - sd, err := td.GetTypedefDescriptor() - isStruct := err == nil && sd != nil - if isStruct { - td.Extra["type"] = "typedef" + if td.IsContainer() || td.IsBasic() { + return false } - return isStruct + prefix, name := utils.ParseAlias(td.GetName()) + fd := LookupFD(td.Filepath) + if fd == nil { + return false + } + targetFd := fd.GetIncludeFD(prefix) + if targetFd == nil { + return false + } + if targetFd.GetTypedefDescriptor(name) == nil { + return false + } + td.Extra["type"] = "typedef" + return true } func (td *TypeDescriptor) GetTypedefDescriptor() (*TypedefDescriptor, error) { @@ -504,6 +515,18 @@ func (s *StructDescriptor) GetFieldByName(name string) *FieldDescriptor { return nil } +func (s *StructDescriptor) GetFieldById(id int32) *FieldDescriptor { + if s == nil { + return nil + } + for _, f := range s.Fields { + if f.ID == id { + return f + } + } + return nil +} + func (s *ConstValueDescriptor) GetValueAsString() string { t := s.GetType() if t == ConstValueType_INT { diff --git a/utils/ast_util.go b/utils/ast_util.go new file mode 100644 index 00000000..64d6e9a3 --- /dev/null +++ b/utils/ast_util.go @@ -0,0 +1,69 @@ +/** + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "github.com/apache/thrift/lib/go/thrift" + "github.com/cloudwego/thriftgo/parser" +) + +// reuse builtin types +var builtinTypes = map[string]thrift.TType{ + "void": thrift.VOID, + "bool": thrift.BOOL, + "byte": thrift.BYTE, + "i8": thrift.I08, + "i16": thrift.I16, + "i32": thrift.I32, + "i64": thrift.I64, + "double": thrift.DOUBLE, + "string": thrift.STRING, + "binary": thrift.STRING, + "list": thrift.LIST, + "map": thrift.MAP, + "set": thrift.SET, +} + +// TypeToStructLike try to find the defined parser.StructLike of a parser.Type in ast +func GetStructLike(name string, ast *parser.Thrift) *parser.StructLike { + tname := name + if _, ok := builtinTypes[tname]; ok { + return nil + } + typePkg, typeName := SplitSubfix(name) + if typePkg != "" { + ref, ok := ast.GetReference(typePkg) + if !ok { + return nil + } + ast = ref + } + if _, ok := ast.GetEnum(typeName); ok { + return nil + } + if typDef, ok := ast.GetTypedef(typeName); ok { + return GetStructLike(typDef.Type.Name, ast) + } + st, ok := ast.GetStruct(typeName) + if !ok { + st, ok = ast.GetUnion(typeName) + if !ok { + st, _ = ast.GetException(typeName) + } + } + return st +} diff --git a/utils/string_utils.go b/utils/string_utils.go index b531f844..a1c06244 100644 --- a/utils/string_utils.go +++ b/utils/string_utils.go @@ -190,3 +190,11 @@ func ParseKV(str string) (map[string]string, error) { } } } + +func SplitSubfix(t string) (typ, val string) { + idx := strings.LastIndex(t, ".") + if idx == -1 { + return "", t + } + return t[:idx], t[idx+1:] +} From 14e5dc2a52bf824b62298e89a4f4bd2341189565 Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Wed, 22 Nov 2023 14:49:57 +0800 Subject: [PATCH 18/19] fix: fix os.ReadFile to ioutil.ReadFile for go 1.15 (#144) --- tool/trimmer/trim/config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tool/trimmer/trim/config.go b/tool/trimmer/trim/config.go index 28786bab..4779cb80 100644 --- a/tool/trimmer/trim/config.go +++ b/tool/trimmer/trim/config.go @@ -16,7 +16,7 @@ package trim import ( "fmt" - "os" + "io/ioutil" "path/filepath" "gopkg.in/yaml.v3" @@ -32,7 +32,7 @@ type YamlArguments struct { func ParseYamlConfig(path string) *YamlArguments { cfg := YamlArguments{} - dataBytes, err := os.ReadFile(filepath.Join(path, DefaultYamlFileName)) + dataBytes, err := ioutil.ReadFile(filepath.Join(path, DefaultYamlFileName)) if err != nil { return nil } From f8e4f29e8a2c07574e8c4c574beeaf318db0a7b3 Mon Sep 17 00:00:00 2001 From: tk_sky <63036400+tksky1@users.noreply.github.com> Date: Wed, 29 Nov 2023 14:25:56 +0800 Subject: [PATCH 19/19] feat: add trimmed-structure counter for trimmer tool (#148) --- generator/golang/backend.go | 5 +++-- tool/trimmer/main.go | 25 +++++++++++++++++++++++-- tool/trimmer/trim/mark.go | 6 +++--- tool/trimmer/trim/traversal.go | 8 +++++++- tool/trimmer/trim/trimmer.go | 31 +++++++++++++++++++++++++++---- tool/trimmer/trim/trimmer_test.go | 11 +++-------- 6 files changed, 66 insertions(+), 20 deletions(-) diff --git a/generator/golang/backend.go b/generator/golang/backend.go index df2514b2..57ac948f 100644 --- a/generator/golang/backend.go +++ b/generator/golang/backend.go @@ -86,11 +86,12 @@ func (g *GoBackend) Generate(req *plugin.Request, log backend.LogFunc) *plugin.R g.log = log g.prepareUtilities() if g.utils.Features().TrimIDL { - err := trim.TrimAST(&trim.TrimASTArg{Ast: req.AST, TrimMethods: nil, Preserve: nil}) + g.log.Warn("You Are Using IDL Trimmer") + structureTrimmed, fieldTrimmed, err := trim.TrimAST(&trim.TrimASTArg{Ast: req.AST, TrimMethods: nil, Preserve: nil}) if err != nil { g.log.Warn("trim error:", err.Error()) } - g.log.Warn("You Are Using IDL Trimmer") + g.log.Warn(fmt.Sprintf("removed %d unused structures with %d fields", structureTrimmed, fieldTrimmed)) } g.prepareTemplates() g.fillRequisitions() diff --git a/tool/trimmer/main.go b/tool/trimmer/main.go index fa97afac..ebb7611a 100644 --- a/tool/trimmer/main.go +++ b/tool/trimmer/main.go @@ -73,11 +73,13 @@ func main() { _, err = checker.CheckAll(ast) check(err) check(semantic.ResolveSymbols(ast)) + structs, fields := countStructs(ast) // trim ast - check(trim.TrimAST(&trim.TrimASTArg{ + _, _, err = trim.TrimAST(&trim.TrimASTArg{ Ast: ast, TrimMethods: a.Methods, Preserve: preserveInput, - })) + }) + check(err) // dump the trimmed ast to idl idl, err := dump.DumpIDL(ast) @@ -123,6 +125,8 @@ func main() { recurseDump(ast, a.Recurse, a.OutputFile) } else { check(writeStringToFile(a.OutputFile, idl)) + structsNew, fieldsNew := countStructs(ast) + fmt.Printf("removed %d unused structures with %d fields\n", structs-structsNew, fields-fieldsNew) } println("success, dump to", a.OutputFile) @@ -164,3 +168,20 @@ func writeStringToFile(filename, content string) error { } return nil } + +func countStructs(ast *parser.Thrift) (structs, fields int) { + structs += len(ast.Structs) + len(ast.Includes) + len(ast.Services) + len(ast.Unions) + len(ast.Exceptions) + for _, v := range ast.Structs { + fields += len(v.Fields) + } + for _, v := range ast.Services { + fields += len(v.Functions) + } + for _, v := range ast.Unions { + fields += len(v.Fields) + } + for _, v := range ast.Exceptions { + fields += len(v.Fields) + } + return structs, fields +} diff --git a/tool/trimmer/trim/mark.go b/tool/trimmer/trim/mark.go index 4bbdd190..91e6d563 100644 --- a/tool/trimmer/trim/mark.go +++ b/tool/trimmer/trim/mark.go @@ -36,12 +36,12 @@ func (t *Trimmer) markService(svc *parser.Service, ast *parser.Thrift, filename return } - if t.trimMethods == nil { + if len(t.trimMethods) == 0 { t.marks[filename][svc] = true } for _, function := range svc.Functions { - if t.trimMethods != nil { + if len(t.trimMethods) != 0 { funcName := svc.Name + "." + function.Name for i, method := range t.trimMethods { if ok, _ := method.MatchString(funcName); ok { @@ -55,7 +55,7 @@ func (t *Trimmer) markService(svc *parser.Service, ast *parser.Thrift, filename t.markFunction(function, ast, filename) } - if t.trimMethods != nil && (svc.Extends != "" || svc.Reference != nil) { + if len(t.trimMethods) != 0 && (svc.Extends != "" || svc.Reference != nil) { t.traceExtendMethod(svc, svc, ast, filename) } diff --git a/tool/trimmer/trim/traversal.go b/tool/trimmer/trim/traversal.go index bc4441e0..ec6823f8 100644 --- a/tool/trimmer/trim/traversal.go +++ b/tool/trimmer/trim/traversal.go @@ -34,6 +34,7 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { for i := range ast.Structs { if t.marks[filename][ast.Structs[i]] || t.checkPreserve(ast.Structs[i]) { listStruct = append(listStruct, ast.Structs[i]) + t.fieldsTrimmed -= len(ast.Structs[i].Fields) } } ast.Structs = listStruct @@ -42,6 +43,7 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { for i := range ast.Unions { if t.marks[filename][ast.Unions[i]] || t.checkPreserve(ast.Unions[i]) { listUnion = append(listUnion, ast.Unions[i]) + t.fieldsTrimmed -= len(ast.Unions[i].Fields) } } ast.Unions = listUnion @@ -50,6 +52,7 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { for i := range ast.Exceptions { if t.marks[filename][ast.Exceptions[i]] || t.checkPreserve(ast.Exceptions[i]) { listException = append(listException, ast.Exceptions[i]) + t.fieldsTrimmed -= len(ast.Exceptions[i].Fields) } } ast.Exceptions = listException @@ -57,7 +60,7 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { var listService []*parser.Service for i := range ast.Services { if t.marks[filename][ast.Services[i]] { - if t.trimMethods != nil { + if len(t.trimMethods) != 0 { var trimmedMethods []*parser.Function for j := range ast.Services[i].Functions { if t.marks[filename][ast.Services[i].Functions[j]] { @@ -67,6 +70,7 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { ast.Services[i].Functions = trimmedMethods } listService = append(listService, ast.Services[i]) + t.fieldsTrimmed -= len(ast.Services[i].Functions) } } ast.Services = listService @@ -74,4 +78,6 @@ func (t *Trimmer) traversal(ast *parser.Thrift, filename string) { for _, inc := range ast.Includes { inc.Used = nil } + + t.structsTrimmed -= len(ast.Structs) + len(ast.Includes) + len(ast.Services) + len(ast.Unions) + len(ast.Exceptions) } diff --git a/tool/trimmer/trim/trimmer.go b/tool/trimmer/trim/trimmer.go index fed4983c..b7cb3665 100644 --- a/tool/trimmer/trim/trimmer.go +++ b/tool/trimmer/trim/trimmer.go @@ -39,6 +39,8 @@ type Trimmer struct { preserveRegex *regexp.Regexp forceTrimming bool preservedStructs []string + structsTrimmed int + fieldsTrimmed int } type TrimASTArg struct { @@ -48,7 +50,7 @@ type TrimASTArg struct { } // TrimAST parse the cfg and trim the single AST -func TrimAST(arg *TrimASTArg) error { +func TrimAST(arg *TrimASTArg) (structureTrimmed int, fieldTrimmed int, err error) { var preservedStructs []string if wd, err := os.Getwd(); err == nil { cfg := ParseYamlConfig(wd) @@ -71,10 +73,11 @@ func TrimAST(arg *TrimASTArg) error { } // doTrimAST trim the single AST, pass method names if -m specified -func doTrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool, preservedStructs []string) error { +func doTrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool, preservedStructs []string) ( + structureTrimmed int, fieldTrimmed int, err error) { trimmer, err := newTrimmer(nil, "") if err != nil { - return err + return 0, 0, err } trimmer.asts[ast.Filename] = ast trimmer.trimMethods = make([]*regexp2.Regexp, len(trimMethods)) @@ -94,6 +97,7 @@ func doTrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool, pre check(err) } trimmer.preservedStructs = preservedStructs + trimmer.countStructs(ast) trimmer.markAST(ast) trimmer.traversal(ast, ast.Filename) if path := parser.CircleDetect(ast); len(path) > 0 { @@ -111,7 +115,7 @@ func doTrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool, pre } } - return nil + return trimmer.structsTrimmed, trimmer.fieldsTrimmed, nil } // Trim to trim thrift files to remove unused fields @@ -140,6 +144,25 @@ func Trim(files, includeDir []string, outDir string) error { return nil } +func (t *Trimmer) countStructs(ast *parser.Thrift) { + t.structsTrimmed += len(ast.Structs) + len(ast.Includes) + len(ast.Services) + len(ast.Unions) + len(ast.Exceptions) + for _, v := range ast.Structs { + t.fieldsTrimmed += len(v.Fields) + } + for _, v := range ast.Services { + t.fieldsTrimmed += len(v.Functions) + } + for _, v := range ast.Unions { + t.fieldsTrimmed += len(v.Fields) + } + for _, v := range ast.Exceptions { + t.fieldsTrimmed += len(v.Fields) + } + for _, v := range ast.Includes { + t.countStructs(v.Reference) + } +} + // make and init a trimmer with related parameters func newTrimmer(files []string, outDir string) (*Trimmer, error) { trimmer := &Trimmer{ diff --git a/tool/trimmer/trim/trimmer_test.go b/tool/trimmer/trim/trimmer_test.go index 0dd219ae..4cd3d0fe 100644 --- a/tool/trimmer/trim/trimmer_test.go +++ b/tool/trimmer/trim/trimmer_test.go @@ -24,13 +24,8 @@ import ( "github.com/cloudwego/thriftgo/semantic" ) -func TestTrimmer(t *testing.T) { - t.Run("trim AST", testSingleFile) - // t.Run("trim AST - test many", testMany) -} - // test single file ast trimming -func testSingleFile(t *testing.T) { +func TestSingleFile(t *testing.T) { trimmer, err := newTrimmer(nil, "") test.Assert(t, err == nil, err) filename := filepath.Join("..", "test_cases", "sample1.thrift") @@ -101,7 +96,7 @@ func TestTrimMethod(t *testing.T) { methods := make([]string, 1) methods[0] = "func1" - err = TrimAST(&TrimASTArg{ + _, _, err = TrimAST(&TrimASTArg{ Ast: ast, TrimMethods: methods, Preserve: nil, @@ -124,7 +119,7 @@ func TestPreserve(t *testing.T) { preserve := false - err = TrimAST(&TrimASTArg{ + _, _, err = TrimAST(&TrimASTArg{ Ast: ast, TrimMethods: nil, Preserve: &preserve,