From f86a89a3e49f5e3539fd275761c94a471c6376db Mon Sep 17 00:00:00 2001 From: Yi Duan Date: Tue, 5 Dec 2023 14:22:44 +0800 Subject: [PATCH] opt: more fieldmask API (#146) --- .github/workflows/push-check.yml | 10 - fieldmask/README.md | 113 ++++++- fieldmask/api_test.go | 279 ++++++++++++++-- fieldmask/mask.go | 103 +++--- fieldmask/path.go | 24 +- fieldmask/serdes.go | 385 ++++++++++++++++++++++ fieldmask/storage.go | 46 +-- fieldmask/utils.go | 116 ++++--- generator/golang/option.go | 7 +- generator/golang/read_write_context.go | 5 +- generator/golang/templates/processor.go | 4 + generator/golang/templates/struct.go | 115 ++++--- generator/golang/thrift.go | 42 ++- generator/golang/util.go | 18 +- internal/utils/b2s.go | 41 +++ test/golang/cases_and_options/go.mod | 7 +- test/golang/cases_and_options/go.sum | 27 ++ test/golang/cases_and_options/run_test.sh | 1 + test/golang/fieldmask/a.thrift | 42 ++- test/golang/fieldmask/b.thrift | 102 ++++++ test/golang/fieldmask/gen_test.go | 44 +++ test/golang/fieldmask/go.mod | 2 - test/golang/fieldmask/go.sum | 2 - test/golang/fieldmask/main_test.go | 385 ++++++++++++++++------ test/golang/fieldmask/run_test.sh | 29 +- 25 files changed, 1593 insertions(+), 356 deletions(-) create mode 100644 fieldmask/serdes.go create mode 100644 internal/utils/b2s.go create mode 100644 test/golang/fieldmask/b.thrift create mode 100644 test/golang/fieldmask/gen_test.go diff --git a/.github/workflows/push-check.yml b/.github/workflows/push-check.yml index 3ab8c367..ff3e5173 100644 --- a/.github/workflows/push-check.yml +++ b/.github/workflows/push-check.yml @@ -30,16 +30,6 @@ jobs: - name: Lint run: | - go install mvdan.cc/gofumpt@v0.2.0 - echo "install done!" - set -e - if [[ -n "$(gofumpt -l .)" ]]; then - echo "gofumpt found formatting issues." - gofumpt -l . - exit 1 - fi - test -z "$(gofumpt -l .)" - echo "test done!" go vet -stdmethods=false $(go list ./...) echo "go vet done!" diff --git a/fieldmask/README.md b/fieldmask/README.md index f2e844a1..9acd2235 100644 --- a/fieldmask/README.md +++ b/fieldmask/README.md @@ -13,10 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. --> +# Thrift FieldMask RFC -# ThriftPath RFC +## What is thrift fieldmask? +FieldMask is inspired by [Protobuf](https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask) and used to indicates the data that users care about, and filter out useless data, during a RPC call, in order to reducing network package size and accelerating serializing/deserializing process. This tech has been widely used among Protobuf [services](https://netflixtechblog.com/practical-api-design-at-netflix-part-1-using-protobuf-fieldmask-35cfdc606518). -## What is thrift path? +## How to construct a fieldmask? +To construct a fieldmask, you need two things: + - [Thrift Path](#thrift-path) for describing the data you want + - [Type Descriptor](#type-descriptor) for validating the thrift path you pass is compatible with thrift message definition (IDL) + +### Thrift Path + +#### What is thrift path? A path string represents a arbitrary endpoint of thrift object. It is used for locating data from thrift root message, and defined from-top-to-bottom. For exapmle, a thrift message defined as below: ```thrift @@ -29,12 +38,13 @@ struct Example { A thrift path `$.Foo` represents the string value of Example.Foo, and `$.Self.Bar` represents the secondary layer i64 value of Example.Self.Bar Since thrift has four nesting types (LIST/SET/MAP/STRUCT), thrift path should also support locating elements in all these types' object, not only STRUCT. -## Syntax +#### Syntax Here are basic hypothesis: - `fieldname` is the field name of a field in a struct, it **MUST ONLY** contain '[a-zA-Z]' alphabet letters, integer numbers and char '_'. - `index` is the index of a element in a list or set, it **MUST ONLY** contain integer numbers. - `key` is the string-typed key of a element in a map, it can contain any letters, but it **MUST** be a quoted string. - `id` is the integer-typed key of a element in a map, it **MUST ONLY** contain integer numbers. +- except `key`, ThriftPath shouldn't contains any blank chars (\n\r\b\t). Here is detailed syntax: @@ -42,9 +52,98 @@ ThriftPath | Description -- | -- $ | the root object,every path must start with it. .`fieldname` | get the child field of a struct corepsonding to fieldname. For example, `$.FieldA.ChildrenB` -[`index`,index...] | get any number of elements in an List/Set corepsonding to indices. Indices must be integer.For example: `$.FieldList[1,3,4]` .Notice: a index beyond actual list size can written but is useless. -{"key","key"...} | get any number of values corepsonding to key in a string-typed-key map. For example: `$.StrMap{"abcd","1234"}` -{id,id...} | get the child field with specific id in a integer-typed-key map. For example, `$.IntMap{1,2}` -\* | get **ALL** fields/elements, that is: `$.StrMap{*}.FieldX` menas gets all the elements' FieldX in a map Root.StrMap; `$.List[*].FieldX` means get all the elements' FieldX in a list Root.List +[`index`,`index`...] | get any number of elements in an List/Set corepsonding to indices. Indices must be integer.For example: `$.FieldList[1,3,4]` .Notice: a index beyond actual list size can written but is useless. +{"`key`","`key`"...} | get any number of values corepsonding to key in a string-typed-key map. For example: `$.StrMap{"abcd","1234"}` +{`id`,`id`...} | get the child field with specific id in a integer-typed-key map. For example, `$.IntMap{1,2}` +\* | get **ALL** fields/elements, that is: `$.StrMap{*}.FieldX` menas gets all the elements' FieldX in a map Root.StrMap; `$.List[*].FieldX` means get all the elements' FieldX in a list Root.List. +#### Agreement Of Implementation +- A field in mask means "PASS" (**will be** serialized/deserialized), and the other field not in mask means "Filtered" ((**won't be** serialized/deserialized)) +- A empty mask means "PASS ALL" (all field is "PASS") +- For map of neither-string-nor-integer typed key, only syntax token of all '*' (see above) is allowed in. +- Required fields CAN be not in mask ("Filtered") while they will still be written as zero values. +- FieldMask settings must start from the root object. + - Tips: If you want to set FieldMask from a non-root object and make it effective, you need to add `field_mask_halfway` option and regenerate the codes. However, there is a latent risk: if different parent objects reference the same child object, and these two parent objects set different fieldmasks, only one parent object's fieldmask relative to this child object will be effective. + +### Type Descriptor +Type descriptor is the runtime representation of a message definition, in aligned with [Protobuf Descriptor](https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/descriptor.proto). To get a type descriptor, you must enable thrift reflection feature first, which was introduced in thriftgo [v0.3.0](https://github.com/cloudwego/thriftgo/pull/83). you can generate related codes for this feature using option `with_reflection`. + +## How to use fieldmask? +1. First, you must generates codes for this feature using two options `with_fieldmask` and `with_reflection` +``` +$ thriftgo -g with_field_mask,with_reflection ${your_idl} +``` +2. Create a fieldmask in the initializing phase of your application (recommanded), or just in the bizhandler before you return a response +```go +import ( + "sync" + "github.com/cloudwego/thriftgo/fieldmask" + nbase "github.com/cloudwego/thriftgo/test/golang/fieldmask/gen-new/base" +) + +var fieldmaskCache sync.Map + +func init() { + // new a obj to get its TypeDescriptor + obj := nbase.NewBase() + desc := obj.GetTypeDescriptor() + + // construct a fieldmask with TypeDescriptor and thrift pathes + fm, err := fieldmask.NewFieldMask(desc, + "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.Meta.IntMap{1}", "$.Meta.StrMap{\"1234\"}", "$.Meta.List[1]", "$.Meta.Set[1]") + if err != nil { + panic(err) + } + + // cache it for future usage of nbase.Base + fieldmaskCache.Store("Mask1ForBase", fm) +} +``` +3. Now you can use fieldmask in either client-side or server-side + - For server-side, you can set fieldmask with generated API `Set_FieldMask()` on your response object. Then the object itself will notice the fieldmask and using it during its serialization + ```go + func bizHandler(req any) (*nbase.Base) { + // handle request ... + + // biz logic: handle and get final response object + obj := bizBase() + + // Load fieldmask from cache + fm, _ := fieldmaskCache.Load("Mask1ForBase") + if fm != nil { + // load ok, set fieldmask onto the object using codegen API + obj.Set_FieldMask(fm.(*fieldmask.FieldMask)) + } + + return obj + } + ``` + - For client-side: related to the deserialization process of framework. For kitex, it's WIP. + +## How to pass fieldmask between programs? +Generally, you can add one binary field on your request definition to carry fieldmask, and explicitly serialize/deserialize the fieldmask you are using into/from this field. We provide two encapsulated API for serialization/deserialization: +- [FieldMask.MarshalJSON()/UnmarshalJSON()](serdes.go): Object methods, serialize/deserialize fieldmask into/from json bytes +- [thriftgo/fieldmask.Marshal()/Unmarshal()](serdes.go): Package functions, serialize/deserialize fieldmask into/from binary bytes. We recommand you to use this API rather than the last one, because it is **much faster** due to using cache -- Unless your application is lack of memory. + + +## Benchmark +See [(main_test.go)](../test/golang/fieldmask/main_test.go) +``` +goos: darwin +goarch: amd64 +pkg: github.com/cloudwego/thriftgo/test/golang/fieldmask +cpu: Intel(R) Core(TM) i9-9880H CPU @ 2.30GHz +BenchmarkWriteWithFieldMask/old-16 2188 ns/op 0 B/op 0 allocs/op +BenchmarkWriteWithFieldMask/new-16 2281 ns/op 0 B/op 0 allocs/op +BenchmarkWriteWithFieldMask/new-mask-half-16 1055 ns/op 0 B/op 0 allocs/op +BenchmarkReadWithFieldMask/old-16 6187 ns/op 2124 B/op 41 allocs/op +BenchmarkReadWithFieldMask/new-16 5675 ns/op 2268 B/op 41 allocs/op +BenchmarkReadWithFieldMask/new-mask-half-16 4762 ns/op 1564 B/op 31 allocs/op +``` +Explain case names: +- Write: serialization test +- Read: deserializtion test +- old: not generate with_fieldmask API +- new: generate with_fieldmask API, but not use fieldmask +- new-mask-half: generate with_fieldmask API and use fieldmask to mask half of the data \ No newline at end of file diff --git a/fieldmask/api_test.go b/fieldmask/api_test.go index 05d9c160..f0904869 100644 --- a/fieldmask/api_test.go +++ b/fieldmask/api_test.go @@ -17,6 +17,9 @@ package fieldmask import ( + "encoding/json" + "runtime" + "strconv" "strings" "testing" @@ -62,11 +65,36 @@ struct MetaInfo { 3: Base Base, } +typedef Val Key + struct BaseResp { - 1: string StatusMessage = "", - 2: i32 StatusCode = 0, - 3: optional map Extra, -}` + 1: required string StatusMessage = "", + 2: required i32 StatusCode = 0, + 3: required bool R3, + 4: required byte R4, + 5: required i16 R5, + 6: required i64 R6, + 7: required double R7, + 8: required string R8, + 9: required Ex R9, + 10: required list R10, + 11: required set R11, + 12: required TrafficEnv R12, + 13: required map R13, + 0: required Key R0, + + 14: map F1 + 15: map F2, + 16: list F3 + 17: set F4, + 18: map F5 + 19: map F6 + 110: map F7 + 111: map> F8 + 112: list>> F9 + 113: map F10 +} +` func GetDescriptor(IDL string, root string) *thrift_reflection.TypeDescriptor { ast, err := parser.ParseString("a.thrift", IDL) @@ -95,6 +123,16 @@ func TestNewFieldMask(t *testing.T) { args args want *FieldMask }{ + { + name: "Neither-string-nor-integer-key Map", + args: args{ + IDL: baseIDL, + rootStruct: "BaseResp", + paths: []string{"$.F10{*}.A", "$.F5{*}.A"}, + inMasks: []string{"$.F10{\"a\"}.A", "$.F5{0}.A"}, + notInMasks: []string{`$.F10{"a"}.B`, "$.F10{*}.B", "$.F5{0}.B", "$.F5{*}.B"}, + }, + }, { name: "Struct", args: args{ @@ -137,6 +175,16 @@ func TestNewFieldMask(t *testing.T) { notInMasks: []string{"$.Extra[0].List", "$.Meta.F2[0].LogID"}, }, }, + { + name: "Repeated *", + args: args{ + IDL: baseIDL, + rootStruct: "Base", + paths: []string{"$.Extra[*].List", "$.Extra[*].Set", "$.Meta.F2{*}.Caller", "$.Meta.F2{*}.Addr"}, + inMasks: []string{"$.Extra[*].Set[0]", "$.Meta.F2{1}.Addr"}, + notInMasks: []string{"$.Meta.F2[0].LogID"}, + }, + }, { name: "String Map", args: args{ @@ -173,16 +221,36 @@ func TestNewFieldMask(t *testing.T) { retry := true begin: - println("fieldmask:") - println(got.String(st)) + // println("fieldmask:") + // println(got.String(st)) // spew.Dump(got) + // test marshal json + // println("marshal:") + out, err := got.MarshalJSON() + if err != nil { + t.Fatal(err) + } + // println(string(out)) + if !json.Valid(out) { + t.Fatal("not invalid json") + } + + // test unmarshal json + nn := &FieldMask{} + if err := nn.UnmarshalJSON(out); err != nil { + t.Fatal(err) + } + if tt.name != "Union" { for _, path := range tt.args.paths { println("[paths] ", path) if !got.PathInMask(st, path) { t.Fatal(path) } + if !nn.PathInMask(st, path) { + t.Fatal(path) + } } } @@ -191,12 +259,18 @@ func TestNewFieldMask(t *testing.T) { if !got.PathInMask(st, path) { t.Fatal(path) } + if !nn.PathInMask(st, path) { + t.Fatal(path) + } } for _, path := range tt.args.notInMasks { println("[notInMasks] ", path) if got.PathInMask(st, path) { t.Fatal(path) } + if nn.PathInMask(st, path) { + t.Fatal(path) + } } if retry { @@ -211,6 +285,22 @@ func TestNewFieldMask(t *testing.T) { } } +func TestMarshalJSONStable(t *testing.T) { + st := GetDescriptor(baseIDL, "MetaInfo") + fm, err := NewFieldMask(st, "$.F2{4,1,3}", "$.F2{0,2}", `$.F1{"c","d","b"}`, `$.F1{"a"}`) + if err != nil { + t.Fatal(err) + } + jo, err := fm.MarshalJSON() + if err != nil { + t.Fatal(err) + } + println(string(jo)) + if string(jo) != (`{"path":"$","type":"Struct","children":[{"path":1,"type":"StrMap","children":[{"path":"a","type":"Struct"},{"path":"b","type":"Struct"},{"path":"c","type":"Struct"},{"path":"d","type":"Struct"}]},{"path":2,"type":"IntMap","children":[{"path":0,"type":"Struct"},{"path":1,"type":"Struct"},{"path":2,"type":"Struct"},{"path":3,"type":"Struct"},{"path":4,"type":"Struct"}]}]}`) { + t.Fatal(string(jo)) + } +} + func TestErrors(t *testing.T) { type args struct { IDL string @@ -224,7 +314,7 @@ func TestErrors(t *testing.T) { want *FieldMask }{ { - name: "desc struct", + name: "desc expect struct", args: args{ IDL: baseIDL, rootStruct: "Base", @@ -233,7 +323,7 @@ func TestErrors(t *testing.T) { }, }, { - name: "desc list", + name: "desc expect list", args: args{ IDL: baseIDL, rootStruct: "Base", @@ -242,7 +332,7 @@ func TestErrors(t *testing.T) { }, }, { - name: "desc map", + name: "desc expect map", args: args{ IDL: baseIDL, rootStruct: "Base", @@ -251,7 +341,7 @@ func TestErrors(t *testing.T) { }, }, { - name: "desc map key", + name: "desc expect map int key", args: args{ IDL: baseIDL, rootStruct: "ExtraInfo", @@ -260,7 +350,7 @@ func TestErrors(t *testing.T) { }, }, { - name: "desc map key", + name: "desc expect map string key", args: args{ IDL: baseIDL, rootStruct: "ExtraInfo", @@ -283,7 +373,7 @@ func TestErrors(t *testing.T) { IDL: baseIDL, rootStruct: "Base", path: []string{"$.TrafficEnv", "$.TrafficEnv.Env"}, - err: `onflicts with previously-set all (*) fields`, + err: `field conflicts with previously settled '*'`, }, }, { @@ -292,7 +382,7 @@ func TestErrors(t *testing.T) { IDL: baseIDL, rootStruct: "Base", path: []string{"$.Extra[*]", "$.Extra[1]"}, - err: `onflicts with previously-set all (*) index`, + err: `id conflicts with previously settled '*'`, }, }, { @@ -301,7 +391,16 @@ func TestErrors(t *testing.T) { IDL: baseIDL, rootStruct: "ExtraInfo", path: []string{"$.IntMap{*}", "$.IntMap{1}"}, - err: `onflicts with previously-set all (*) keys`, + err: `key conflicts with previous settled '*'`, + }, + }, + { + name: "key conflict2", + args: args{ + IDL: baseIDL, + rootStruct: "BaseResp", + path: []string{"$.F5{*}", "$.F5{1}"}, + err: `key conflicts with previous settled '*'`, }, }, { @@ -327,7 +426,7 @@ func TestErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { st := GetDescriptor(tt.args.IDL, tt.args.rootStruct) - _, err := GetFieldMask(st, tt.args.path...) + _, err := NewFieldMask(st, tt.args.path...) if err == nil || !strings.Contains(err.Error(), tt.args.err) { t.Fatal(err) } @@ -350,16 +449,16 @@ func BenchmarkNewFieldMask(b *testing.B) { _ = fm } }) - b.Run("reuse", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - fm, err := GetFieldMask(st, []string{"$.LogID", "$.TrafficEnv.Open", "$.TrafficEnv.Env", "$.Extra[0]", "$.Extra[1].IntMap{0}", "$.Extra[2].StrMap{\"abcd\"}"}...) - if err != nil { - b.Fatal(err) - } - fm.Recycle() - } - }) + // b.Run("reuse", func(b *testing.B) { + // b.ResetTimer() + // for i := 0; i < b.N; i++ { + // fm, err := GetFieldMask(st, []string{"$.LogID", "$.TrafficEnv.Open", "$.TrafficEnv.Env", "$.Extra[0]", "$.Extra[1].IntMap{0}", "$.Extra[2].StrMap{\"abcd\"}"}...) + // if err != nil { + // b.Fatal(err) + // } + // fm.Recycle() + // } + // }) } func BenchmarkFieldMask_InMask(b *testing.B) { @@ -441,3 +540,133 @@ func BenchmarkFieldMask_InMask(b *testing.B) { } }) } + +func BenchmarkMarshal(b *testing.B) { + st := GetDescriptor(baseIDL, "Base") + got, err := NewFieldMask(st, "$.Extra[0].List", "$.Extra[*].Set", "$.Meta.F2{0}", "$.Meta.F2{*}.Addr") + if err != nil { + b.Fatal(err) + } + j, err := got.MarshalJSON() + if err != nil { + b.Fatal(err) + } + if !json.Valid(j) { + b.Fatal("invalid json:", string(j)) + } + j2, e2 := Marshal(got) + if e2 != nil { + b.Fatal(e2) + } + if !json.Valid(j2) { + b.Fatal("invalid json2", string(j2)) + } + + b.Run("MarshalJSON", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = got.MarshalJSON() + } + }) + + b.Run("Marshal", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = Marshal(got) + } + }) +} + +func BenchmarkUnmarshal(b *testing.B) { + st := GetDescriptor(baseIDL, "Base") + got, err := NewFieldMask(st, "$.Extra[0].List", "$.Extra[*].Set", "$.Meta.F2{0}", "$.Meta.F2{*}.Addr") + if err != nil { + b.Fatal(err) + } + j, err := got.MarshalJSON() + if err != nil { + b.Fatal(err) + } + if !json.Valid(j) { + b.Fatal("invalid json:", string(j)) + } + act := new(FieldMask) + if err := act.UnmarshalJSON(j); err != nil { + b.Fatal(err) + } + // if !reflect.DeepEqual(got, act) { + // b.Fatal() + // } + + _, err = Unmarshal(j) + if err != nil { + b.Fatal(err) + } + // if !reflect.DeepEqual(got, act2) { + // b.Fatal() + // } + + b.Run("UnmarshalJSON", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + act := new(FieldMask) + _ = act.UnmarshalJSON(j) + } + }) + + b.Run("Umarshal", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = Unmarshal(j) + } + }) +} + +func BenchmarkMemory(b *testing.B) { + st := GetDescriptor(baseIDL, "Base") + _, err := NewFieldMask(st, []string{"$.Extra[0].List", "$.Meta.F2{0}", "$.Meta.F2{*}.Addr"}...) + if err != nil { + b.Fatal(err) + } + + go func() { + for { + runtime.GC() + } + }() + + tester := func(X int, b *testing.B) { + for i := 0; i < b.N; i++ { + for x := 0; x < X; x++ { + tt, err := NewFieldMask(st, "$.Extra["+strconv.Itoa(x)+"].List", "$.Meta.F2{0}", "$.Meta.F2{*}.Addr") + if err != nil { + b.Fatal(err) + } + j, err := Marshal(tt) + if err != nil { + b.Fatal(err) + } + _, err = Unmarshal(j) + if err != nil { + b.Fatal(err) + } + } + } + } + + b.Run("10", func(b *testing.B) { + tester(10, b) + }) + + b.Run("100", func(b *testing.B) { + tester(100, b) + }) + + b.Run("1000", func(b *testing.B) { + tester(1000, b) + }) + + b.Run("10000", func(b *testing.B) { + tester(10000, b) + }) +} diff --git a/fieldmask/mask.go b/fieldmask/mask.go index 425d681b..488875c8 100644 --- a/fieldmask/mask.go +++ b/fieldmask/mask.go @@ -19,29 +19,74 @@ package fieldmask import ( "fmt" "strings" - "sync" + "github.com/cloudwego/thriftgo/internal/utils" "github.com/cloudwego/thriftgo/thrift_reflection" ) -type fieldMaskType uint8 +// FieldMaskType indicates the corresponding thrift message type for a fieldmask +type FieldMaskType uint8 + +// MarshalText implements encoding.TextMarshaler +func (ft FieldMaskType) MarshalText() ([]byte, error) { + switch ft { + case FtScalar: + return utils.S2B("Scalar"), nil + case FtList: + return utils.S2B("List"), nil + case FtStruct: + return utils.S2B("Struct"), nil + case FtStrMap: + return utils.S2B("StrMap"), nil + case FtIntMap: + return utils.S2B("IntMap"), nil + default: + return utils.S2B("Invalid"), nil + } +} +// UnmarshalText implements encoding.TextUnmarshaler +func (ft *FieldMaskType) UnmarshalText(in []byte) error { + switch utils.B2S(in) { + case "Scalar": + *ft = FtScalar + case "List": + *ft = FtList + case "Struct": + *ft = FtStruct + case "StrMap": + *ft = FtStrMap + case "IntMap": + *ft = FtIntMap + default: + *ft = FtInvalid + } + return nil +} + +// FieldMaskType Enums const ( - ftInvalid fieldMaskType = iota - ftScalar - ftArray - ftStruct - ftStrMap - ftIntMap + // Invalid or unsupported thrift type + FtInvalid FieldMaskType = iota + // thrift scalar types, including BOOL/I8/I16/I32/I64/DOUBLE/STRING/BINARY, or neither-string-nor-integer-typed-key MAP + FtScalar + // thrift LIST/SET + FtList + // thrift STRUCT + FtStruct + // thrift MAP with string-typed key + FtStrMap + // thrift MAP with integer-typed key + FtIntMap ) // FieldMask represents a collection of thrift pathes // See type FieldMask struct { - typ fieldMaskType - isAll bool + typ FieldMaskType + all *FieldMask fdMask *fieldMap @@ -51,12 +96,6 @@ type FieldMask struct { intMask intMap } -var fmsPool = sync.Pool{ - New: func() interface{} { - return &FieldMask{} - }, -} - // NewFieldMask create a new fieldmask func NewFieldMask(desc *thrift_reflection.TypeDescriptor, pathes ...string) (*FieldMask, error) { ret := FieldMask{} @@ -67,22 +106,6 @@ func NewFieldMask(desc *thrift_reflection.TypeDescriptor, pathes ...string) (*Fi return &ret, nil } -// GetFieldMask reuse fieldmask from pool -func GetFieldMask(desc *thrift_reflection.TypeDescriptor, paths ...string) (*FieldMask, error) { - ret := fmsPool.Get().(*FieldMask) - err := ret.init(desc, paths...) - if err != nil { - return nil, err - } - return ret, nil -} - -// Recycle puts fieldmask into pool -func (self *FieldMask) Recycle() { - self.reset() - fmsPool.Put(self) -} - // reset clears fieldmask's all path func (self *FieldMask) reset() { if self == nil { @@ -107,20 +130,6 @@ func (self *FieldMask) init(desc *thrift_reflection.TypeDescriptor, paths ...str // String pretty prints the structure a FieldMask represents // -// For example: -// pathes `[]string{"$.Extra[0].List", "$.Extra[*].Set", "$.Meta.F2{0}", "$.Meta.F2{*}.Addr"}` will print: -// -// (Base) -// .Extra (list) -// [ -// * -// ] -// .Meta (MetaInfo) -// .F2 (map) -// { -// * -// } -// // WARING: This is unstable API, the printer format is not guaranteed func (self FieldMask) String(desc *thrift_reflection.TypeDescriptor) string { buf := strings.Builder{} @@ -178,7 +187,7 @@ func (self *FieldMask) All() bool { return true } switch self.typ { - case ftStruct, ftArray, ftIntMap, ftStrMap: + case FtStruct, FtList, FtIntMap, FtStrMap: return self.isAll default: return true diff --git a/fieldmask/path.go b/fieldmask/path.go index c3dc0734..b59d7653 100644 --- a/fieldmask/path.go +++ b/fieldmask/path.go @@ -59,6 +59,11 @@ const ( pathSepSlash = '\\' ) +const ( + jsonPathAny = `*` + jsonPathRoot = `$` +) + type pathValue struct { pv unsafe.Pointer iv int @@ -325,6 +330,8 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path } styp := stok.Type() // println("stoken: ", stok.String()) + // j, _ := cur.MarshalJSON() + // println("cur mask: ", string(j), cur.isAll, cur.all) if styp == pathTypeRoot { continue @@ -335,7 +342,7 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path return false } // println("struct: ", st.Name) - if cur.typ != ftStruct { + if cur.typ != FtStruct { return false } @@ -352,7 +359,6 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path if f == nil { return false } - } else if typ == pathTypeLitStr { name := tok.val.Str() f = st.GetFieldByName(name) @@ -369,8 +375,10 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path // println("all", all, "FieldInMask:", cur.FieldInMask(int32(f.GetID()))) // check if name set mask + // println("field ", f.GetID()) nextFm, exist := cur.Field(int16(f.GetID())) if !exist { + // println("return false") return false } @@ -392,7 +400,7 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path return false } - if cur.typ != ftArray { + if cur.typ != FtList { return false } @@ -433,7 +441,6 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path // next fieldmask curDesc = et cur = next - } else if styp == pathTypeMapL { // get element and key desc if !curDesc.IsMap() { @@ -449,10 +456,10 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path } // println("cur.typ::", cur.typ, "cur::", cur.String(curDesc)) - if cur.typ != ftIntMap && cur.typ != ftStrMap { + if cur.typ != FtIntMap && cur.typ != FtStrMap && cur.typ != FtScalar { return false } - + // spew.Dump("cur ", cur) next := cur.all // iter indexies... for it.HasNext() { @@ -474,7 +481,7 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path } if typ == pathTypeLitInt { - if cur.typ != ftIntMap { + if cur.typ != FtIntMap { return false } v := tok.val.Int() @@ -485,7 +492,7 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path // NOTICE: always use last elem's fieldmask next = nextFm } else if typ == pathTypeStr { - if cur.typ != ftStrMap { + if cur.typ != FtStrMap { return false } v := tok.val.Str() @@ -503,6 +510,7 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path // next fieldmask curDesc = et cur = next + // spew.Dump("next ", cur) } else { return false } diff --git a/fieldmask/serdes.go b/fieldmask/serdes.go new file mode 100644 index 00000000..34f33e1f --- /dev/null +++ b/fieldmask/serdes.go @@ -0,0 +1,385 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fieldmask + +import ( + "encoding/json" + "errors" + "fmt" + "sort" + "strconv" + "sync" + + "github.com/cloudwego/thriftgo/internal/utils" +) + +var bytesPool = sync.Pool{ + New: func() interface{} { + b := make([]byte, 0, 4096) + return &b + }, +} + +// MarshalJSON marshals the fieldmask into json. +// +// For example: +// - pathes `[]string{"$.Extra[0].List", "$.Extra[*].Set", "$.Meta.F2{0}", "$.Meta.F2{*}.Addr"}` will produces: +// - `{"path":"$","type":"Struct","children":[{"path":6,"type":"List","children":[{"path":"*","type":"Struct","children":[{"path":4,"type":"List"}]}]},{"path":256,"type":"Struct","children":[{"path":2,"type":"IntMap","children":[{"path":"*","type":"Struct","children":[{"path":0,"type":"Scalar"}]}]}]}]}` +// +// For details: +// - `path` is the path segment of current fieldmask layer +// - `type` is the `FieldMaskType` of the fieldmask +// -`children` is the chidlren of subsequent pathes +// - each fieldmask always starts with root path "$" +// - path "*" indicates all subsequent path of the fieldmask shares the same sub fieldmask +func (fm *FieldMask) MarshalJSON() ([]byte, error) { + buf := bytesPool.Get().(*[]byte) + + err := fm.marshalBegin(buf) + if err != nil { + (*buf) = (*buf)[:0] + bytesPool.Put(buf) + return nil, err + } + + ret := make([]byte, len(*buf)) + copy(ret, *buf) + (*buf) = (*buf)[:0] + bytesPool.Put(buf) + return ret, nil +} + +func write(buf *[]byte, str string) { + *buf = append(*buf, str...) +} + +func (self *FieldMask) marshalBegin(buf *[]byte) error { + if self == nil { + write(buf, "{}") + return nil + } + write(buf, `{"path":"$","type":"`) + out, _ := self.typ.MarshalText() + *buf = append(*buf, out...) + write(buf, `"`) + return self.marshalRec(buf) +} + +type ivalue struct { + id int + fm *FieldMask +} + +type isorter []ivalue + +func (self isorter) Len() int { + return len(self) +} + +func (self isorter) Less(i, j int) bool { + return self[i].id < self[j].id +} + +func (self isorter) Swap(i, j int) { + self[i], self[j] = self[j], self[i] +} + +type svalue struct { + id string + fm *FieldMask +} + +type ssorter []svalue + +func (self ssorter) Len() int { + return len(self) +} + +func (self ssorter) Less(i, j int) bool { + return self[i].id < self[j].id +} + +func (self ssorter) Swap(i, j int) { + self[i], self[j] = self[j], self[i] +} + +func (self *FieldMask) marshalRec(buf *[]byte) error { + if self.All() && self.all == nil { + write(buf, "}") + return nil + } + + var start bool + var writer = func(path interface{}, f *FieldMask) (bool, error) { + if !f.Exist() { + return true, nil + } + if start { + write(buf, ",") + } + + // write path + write(buf, `{"path":`) + switch v := path.(type) { + case int: + *buf = strconv.AppendInt(*buf, int64(v), 10) + case string: + *buf = strconv.AppendQuote(*buf, v) + } + write(buf, ",") + + // write type + write(buf, `"type":"`) + typ, _ := f.typ.MarshalText() + *buf = append(*buf, typ...) + write(buf, `"`) + + if err := f.marshalRec(buf); err != nil { + return false, err + } + + start = true + return true, nil + } + + // write children + write(buf, `,"children":[`) + + if self.All() { + _, err := writer(jsonPathAny, self.all) + if err != nil { + return err + } + + } else if self.typ == FtStruct { + fds := make(isorter, 0, len(self.fdMask.tail)*2) + for id, f := range self.fdMask.head { + if !f.Exist() { + continue + } + fds = append(fds, ivalue{id, f}) + } + for id, f := range self.fdMask.tail { + if !f.Exist() { + continue + } + fds = append(fds, ivalue{int(id), f}) + } + sort.Stable(fds) + for _, v := range fds { + cont, err := writer(v.id, v.fm) + if err != nil { + return err + } + if !cont { + break + } + } + + } else if self.typ == FtList || self.typ == FtIntMap { + fds := make(isorter, 0, len(self.intMask)) + for k, f := range self.intMask { + if !f.Exist() { + continue + } + fds = append(fds, ivalue{int(k), f}) + } + sort.Stable(fds) + for _, v := range fds { + cont, err := writer(v.id, v.fm) + if err != nil { + return err + } + if !cont { + break + } + } + + } else if self.typ == FtStrMap { + fds := make(ssorter, 0, len(self.strMask)) + for k, f := range self.strMask { + if !f.Exist() { + continue + } + fds = append(fds, svalue{k, f}) + } + sort.Stable(fds) + for _, v := range fds { + cont, err := writer(v.id, v.fm) + if err != nil { + return err + } + if !cont { + break + } + } + + } else { + return errors.New("invalid fieldmask type") + } + + write(buf, "]}") + return nil +} + +type shadowFieldMask struct { + Path interface{} `json:"path"` + Type FieldMaskType `json:"type"` + Children []shadowFieldMask `json:"children"` +} + +// UnmarshalJSON unmarshal the fieldmask from json. +// +// Input JSON **MUST** be according to the schema of `FieldMask.MarshalJSON()` +func (self *FieldMask) UnmarshalJSON(in []byte) error { + if self == nil { + return errors.New("nil memory address") + } + var s = new(shadowFieldMask) + if err := json.Unmarshal(in, &s); err != nil { + return err + } + // spew.Dump(s) + if s.Path != jsonPathRoot { + return errors.New("fieldmask must begin with root path '$'") + } + return self.fromShadow(s) +} + +func (self *FieldMask) fromShadow(s *shadowFieldMask) error { + if s == nil || s.Type == FtInvalid { + return errors.New("invalid fieldmask type") + } + self.typ = s.Type + + if len(s.Children) == 0 { + self.isAll = true + return nil + } + + if s.Type == FtScalar { + is, err := self.checkAll(&s.Children[0]) + if err != nil { + return err + } + if !is { + return errors.New("expect * for the child") + } + return nil + } else if s.Type == FtStruct { + for _, n := range s.Children { + if is, err := self.checkAll(&n); err != nil { + return err + } else if is { + return nil + } + id, ok := n.Path.(float64) + if !ok { + return fmt.Errorf("expect number but got %#v", n.Path) + } + next := self.setFieldID(fieldID(id), n.Type) + if err := next.fromShadow(&n); err != nil { + return err + } + } + + } else if s.Type == FtList || s.Type == FtIntMap { + for _, n := range s.Children { + if is, err := self.checkAll(&n); err != nil { + return err + } else if is { + return nil + } + id, ok := n.Path.(float64) + if !ok { + return fmt.Errorf("expect number but got %#v", n.Path) + } + next := self.setInt(int(id), n.Type, len(s.Children)) + if err := next.fromShadow(&n); err != nil { + return err + } + } + + } else if s.Type == FtStrMap { + for _, n := range s.Children { + if is, err := self.checkAll(&n); err != nil { + return err + } else if is { + return nil + } + id, ok := n.Path.(string) + if !ok { + return fmt.Errorf("expect string but got %#v", n.Path) + } + next := self.setStr(id, n.Type, len(s.Children)) + if err := next.fromShadow(&n); err != nil { + return err + } + } + } + + return nil +} + +func (self *FieldMask) checkAll(s *shadowFieldMask) (bool, error) { + if s.Path == "*" { + self.isAll = true + self.all = &FieldMask{} + return true, self.all.fromShadow(s) + } + return false, nil +} + +var ( + fm2json sync.Map + json2fm sync.Map +) + +// Marshal serializes a fieldmask into bytes. +// +// Notice: This API uses cache to accelerate processing, +// at the cost of increasing memory usage +func Marshal(fm *FieldMask) ([]byte, error) { + // fast-path: load from cache + if j, ok := fm2json.Load(fm); ok { + return j.([]byte), nil + } + // slow-path: marshal from object + nj, err := fm.MarshalJSON() + if err != nil { + return nil, err + } + fm2json.Store(fm, nj) + return nj, nil +} + +// Marshal deserializes a fieldmask from bytes. +// +// Notice: This API uses cache to accelerate processing, +// at the cost of increasing memory usage +func Unmarshal(data []byte) (*FieldMask, error) { + // fast-path: load from cache + if fm, ok := json2fm.Load(utils.B2S(data)); ok { + return fm.(*FieldMask), nil + } + // slow-path: unmarshal from json + var fm = new(FieldMask) + err := fm.UnmarshalJSON(data) + if err != nil { + return nil, err + } + json2fm.Store(string(data), fm) + return fm, nil +} diff --git a/fieldmask/storage.go b/fieldmask/storage.go index 8224a6db..ed70e622 100644 --- a/fieldmask/storage.go +++ b/fieldmask/storage.go @@ -16,32 +16,18 @@ package fieldmask -import ( - "github.com/cloudwego/thriftgo/thrift_reflection" -) - type fieldID int32 -const _MaxFieldIDHead = 127 +const _MaxFieldIDHead = 63 type fieldMap struct { head [_MaxFieldIDHead + 1]*FieldMask tail map[fieldID]*FieldMask } -func makeFieldMaskMap(st *thrift_reflection.StructDescriptor) fieldMap { - max := 0 - count := 0 - for _, f := range st.GetFields() { - if max < int(f.GetID()) { - max = int(f.GetID()) - count = 0 - } else { - count += 1 - } - } +func makeFieldMaskMap() fieldMap { return fieldMap{ - tail: make(map[fieldID]*FieldMask, count), + tail: make(map[fieldID]*FieldMask), } } @@ -65,7 +51,7 @@ func (fm *fieldMap) Reset() { // self.tail = self.tail[:0] // } -func (self *fieldMap) SetIfNotExist(f fieldID, ft fieldMaskType) (s *FieldMask) { +func (self *fieldMap) SetIfNotExist(f fieldID, ft FieldMaskType) (s *FieldMask) { if f <= _MaxFieldIDHead { s = self.head[f] if s == nil { @@ -101,13 +87,13 @@ func (self *fieldMap) Get(f fieldID) (ret *FieldMask) { } // setFieldID ensure a fieldmask slot for f -func (self *FieldMask) setFieldID(f fieldID, st *thrift_reflection.StructDescriptor) *FieldMask { +func (self *FieldMask) setFieldID(f fieldID, ft FieldMaskType) *FieldMask { if self.fdMask == nil { // println("new fdmask") - m := makeFieldMaskMap(st) + m := makeFieldMaskMap() self.fdMask = &m } - return self.fdMask.SetIfNotExist(fieldID(f), switchFt(st.GetFieldById(int32(f)).GetType())) + return self.fdMask.SetIfNotExist(fieldID(f), ft) } // type fieldMaskBitmap []byte @@ -138,10 +124,10 @@ func (self *FieldMask) setFieldID(f fieldID, st *thrift_reflection.StructDescrip // return ((*self)[b] & byte(1< 0 { + self.all.print(buf, indent+2, fs[0].GetType()) + } + return + } for _, f := range st.GetFields() { if _, exist := self.Field(int16(f.GetID())); !exist { continue } self.printField(buf, indent+2, f) } - } else if self.typ == ftArray { + } else if self.typ == FtList || self.typ == FtIntMap { + if self.All() { + printIndent(buf, indent+2, "*\n") + self.all.printElem(buf, indent+2, 0, desc.GetValueType()) + return + } for k, v := range self.intMask { if v.typ == 0 { continue } self.printElem(buf, indent+2, k, desc.GetValueType()) } - printIndent(buf, indent, "]\n") - } else if self.typ == ftIntMap || self.typ == ftStrMap { - for k, v := range self.intMask { + } else if self.typ == FtStrMap { + if self.All() { + printIndent(buf, indent+2, "*") + self.printElem(buf, indent+2, "", desc.GetValueType()) + return + } + for k, v := range self.strMask { if v.typ == 0 { continue } self.printElem(buf, indent+2, k, desc.GetValueType()) } - printIndent(buf, indent, "}\n") + } else if self.typ == FtScalar { + buf.WriteString(" (") + buf.WriteString(desc.GetName()) + buf.WriteString(")\n") } else { printIndent(buf, indent, "Unknown Fieldmask") } diff --git a/generator/golang/option.go b/generator/golang/option.go index 45553456..845fffb2 100644 --- a/generator/golang/option.go +++ b/generator/golang/option.go @@ -56,8 +56,10 @@ type Features struct { KeepCodeRefName bool `keep_code_ref_name:"Genenerate code ref but still keep file name."` TrimIDL bool `trim_idl:"Simplify IDL to the most concise form before generating code."` - JSONStringer bool `json_stringer:"Generate the JSON marshal method in String() method."` - WithFieldMask bool `with_field_mask:"Support field-mask for generated code."` + JSONStringer bool `json_stringer:"Generate the JSON marshal method in String() method."` + + WithFieldMask bool `with_field_mask:"Support field-mask for generated code."` + FieldMaskHalfway bool `field_mask_halfway:"Support set field-mask on not-root struct."` } var defaultFeatures = Features{ @@ -90,6 +92,7 @@ var defaultFeatures = Features{ TrimIDL: false, JSONStringer: false, WithFieldMask: false, + FieldMaskHalfway: false, } type param struct { diff --git a/generator/golang/read_write_context.go b/generator/golang/read_write_context.go index c23c2e0f..045d4474 100644 --- a/generator/golang/read_write_context.go +++ b/generator/golang/read_write_context.go @@ -63,6 +63,10 @@ func (c *ReadWriteContext) WithFieldMask(fm string) *ReadWriteContext { return c } +func (c *ReadWriteContext) NeedFieldMask() bool { + return c.FieldMask != "" +} + // WithTarget sets the target name. func (c *ReadWriteContext) WithTarget(t string) *ReadWriteContext { c.Target = t @@ -128,6 +132,5 @@ func mkRWCtx(r *Resolver, s *Scope, t *parser.Type, top *ReadWriteContext) (*Rea } } - ctx.FieldMask = "fm" return ctx, nil } diff --git a/generator/golang/templates/processor.go b/generator/golang/templates/processor.go index 5f10acf5..542d1720 100644 --- a/generator/golang/templates/processor.go +++ b/generator/golang/templates/processor.go @@ -170,10 +170,14 @@ func (p *{{$ProcessName}}) Process(ctx context.Context, seqId int32, iprot, opro {{- range .Functions}} {{$ArgsType := .ArgType}} +{{- $withFieldMask := (SetWithFieldMask false) }} {{template "StructLike" $ArgsType}} +{{- $_ := (SetWithFieldMask $withFieldMask) }} {{- if not .Oneway}} {{$ResType := .ResType}} + {{- $withFieldMask := (SetWithFieldMask false) }} {{template "StructLike" $ResType}} + {{- $_ := (SetWithFieldMask $withFieldMask) }} {{- end}} {{- end}}{{/* range .Functions */}} {{- end}}{{/* define "Processor" */}} diff --git a/generator/golang/templates/struct.go b/generator/golang/templates/struct.go index b19ce7a6..9e21efc5 100644 --- a/generator/golang/templates/struct.go +++ b/generator/golang/templates/struct.go @@ -82,12 +82,28 @@ func (p *{{$TypeName}}) CarryingUnknownFields() bool { {{end}}{{/* if Features.KeepUnknownFields */}} {{if Features.WithFieldMask}} -func (p *{{$TypeName}}) GetFieldMask() *fieldmask.FieldMask { +func (p *{{$TypeName}}) Get_FieldMask() *fieldmask.FieldMask { + if p == nil { + return nil + } return p._fieldmask } -func (p *{{$TypeName}}) SetFieldMask(fm *fieldmask.FieldMask) { + +func (p *{{$TypeName}}) Set_FieldMask(fm *fieldmask.FieldMask) { + if p == nil { + return + } + p._fieldmask = fm +} + +{{- if Features.FieldMaskHalfway}} +func (p *{{$TypeName}}) Pass_FieldMask(fm *fieldmask.FieldMask) { + if p == nil || p._fieldmask != nil { + return + } p._fieldmask = fm } +{{- end}} {{end}}{{/* if Features.WithFieldMask */}} var fieldIDToName_{{$TypeName}} = map[int16]string{ @@ -179,21 +195,13 @@ func (p *{{$TypeName}}) Read(iprot thrift.TProtocol) (err error) { {{- $isBaseVal := .Type | IsBaseType}} case {{.ID}}: if fieldTypeId == thrift.{{.Type | GetTypeIDConstant }} { - {{- if Features.WithFieldMask}} - if {{if $isBaseVal}}_{{else}}nfm{{end}}, ex := p._fieldmask.Field(fieldId); ex { - {{- end}} - if err = p.{{.Reader}}(iprot{{if and Features.WithFieldMask (not $isBaseVal)}}, nfm{{end}}); err != nil { + if err = p.{{.Reader}}(iprot); err != nil { goto ReadFieldError } {{- if .Requiredness.IsRequired}} isset{{.GoName}} = true {{- end}} - break - {{- if Features.WithFieldMask}} - } - {{- end}} - } - if err = iprot.Skip(fieldTypeId); err != nil { + } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } {{- end}}{{/* range .Fields */}} @@ -279,9 +287,17 @@ var StructLikeReadField = ` {{- range .Fields}} {{$FieldName := .GoName}} {{- $isBaseVal := .Type | IsBaseType -}} -func (p *{{$TypeName}}) {{.Reader}}(iprot thrift.TProtocol{{if and Features.WithFieldMask (not $isBaseVal)}}, fm *fieldmask.FieldMask{{end}}) error { +func (p *{{$TypeName}}) {{.Reader}}(iprot thrift.TProtocol) error { + {{- if Features.WithFieldMask}} + if {{if $isBaseVal}}_{{else}}fm{{end}}, ex := p._fieldmask.Field({{.ID}}); ex { + {{- end}} {{$ctx := (MkRWCtx .).WithFieldMask "fm"}} {{- template "FieldRead" $ctx}} + {{- if Features.WithFieldMask}} + } else if err := iprot.Skip(thrift.{{.Type | GetTypeIDConstant}}); err != nil { + return err + } + {{- end}} return nil } {{- end}}{{/* range .Fields */}} @@ -308,17 +324,11 @@ func (p *{{$TypeName}}) Write(oprot thrift.TProtocol) (err error) { } if p != nil { {{- range .Fields}} - {{- $isBaseVal := .Type | IsBaseType}} - {{- if Features.WithFieldMask}} - if {{if $isBaseVal}}_{{else}}nfm{{end}}, ex := p._fieldmask.Field({{.ID}}); ex { - {{- end}} - if err = p.{{.Writer}}(oprot{{if and Features.WithFieldMask (not $isBaseVal)}}, nfm{{end}}); err != nil { + if err = p.{{.Writer}}(oprot); err != nil { fieldId = {{.ID}} goto WriteFieldError } - {{- if Features.WithFieldMask}} - } - {{- end}} + {{- end}}{{/* range .Fields */}} {{- if Features.KeepUnknownFields}} if err = p._unknownFields.Write(oprot); err != nil { @@ -364,11 +374,14 @@ var StructLikeWriteField = ` {{- $FieldName := .GoName}} {{- $IsSetName := .IsSetter}} {{- $TypeID := .Type | GetTypeIDConstant }} -{{- $isBaseVal := .Type | IsBaseType -}} -func (p *{{$TypeName}}) {{.Writer}}(oprot thrift.TProtocol{{if and Features.WithFieldMask (not $isBaseVal)}}, fm *fieldmask.FieldMask{{end}}) (err error) { +{{- $isBaseVal := .Type | IsBaseType }} +func (p *{{$TypeName}}) {{.Writer}}(oprot thrift.TProtocol) (err error) { {{- if .Requiredness.IsOptional}} if p.{{$IsSetName}}() { {{- end}} + {{- if Features.WithFieldMask}} + if {{if $isBaseVal}}_{{else}}fm{{end}}, ex := p._fieldmask.Field({{.ID}}); ex { + {{- end}} if err = oprot.WriteFieldBegin("{{.Name}}", thrift.{{$TypeID}}, {{.ID}}); err != nil { goto WriteFieldBeginError } @@ -377,6 +390,21 @@ func (p *{{$TypeName}}) {{.Writer}}(oprot thrift.TProtocol{{if and Features.With if err = oprot.WriteFieldEnd(); err != nil { goto WriteFieldEndError } + {{- if Features.WithFieldMask}} + {{- if .Requiredness.IsRequired}} + } else { + if err = oprot.WriteFieldBegin("{{.Name}}", thrift.{{$TypeID}}, {{.ID}}); err != nil { + goto WriteFieldBeginError + } + {{ ZeroWriter .Type "oprot" "WriteFieldBeginError" }} + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + {{- else}} + } + {{- end}} + {{- end}} {{- if .Requiredness.IsOptional}} } {{- end}} @@ -506,7 +534,13 @@ var FieldRead = ` var FieldReadStructLike = ` {{define "FieldReadStructLike"}} {{- .Target}} {{if .NeedDecl}}:{{end}}= {{.TypeName.Deref.NewFunc}}() - {{if Features.WithFieldMask}}{{.Target}}.SetFieldMask({{.FieldMask}}){{end}} + {{- if and (Features.WithFieldMask) .NeedFieldMask}} + {{- if Features.FieldMaskHalfway}} + {{.Target}}.Pass_FieldMask({{.FieldMask}}) + {{- else}} + {{.Target}}.Set_FieldMask({{.FieldMask}}) + {{- end}} + {{- end}} if err := {{.Target}}.Read(iprot); err != nil { return err } @@ -571,8 +605,8 @@ var FieldReadMap = ` {{- $ctx := .KeyCtx.WithDecl.WithTarget $key}} {{- template "FieldRead" $ctx}} {{- if Features.WithFieldMask}} - {{- if $isIntKey}} {{- $curFieldMask = "nfm"}} + {{- if $isIntKey}} if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(int({{$key}})); !ex { if err := iprot.Skip(thrift.{{.ValCtx.Type | GetTypeIDConstant}}); err != nil { return err @@ -580,7 +614,6 @@ var FieldReadMap = ` continue } else { {{- else if $isStrKey}} - {{- $curFieldMask = "nfm"}} if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Str(string({{$key}})); !ex { if err := iprot.Skip(thrift.{{.ValCtx.Type | GetTypeIDConstant}}); err != nil { return err @@ -588,7 +621,12 @@ var FieldReadMap = ` continue } else { {{- else}} - {{$curFieldMask}} = nil + if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(0); !ex { + if err := iprot.Skip(thrift.{{.ValCtx.Type | GetTypeIDConstant}}); err != nil { + return err + } + continue + } else { {{- end}} {{- end}}{{/* end WithFieldMask */}} {{/* line break */}} @@ -601,7 +639,7 @@ var FieldReadMap = ` {{end}} {{.Target}}[{{$key}}] = {{$val}} - {{- if and Features.WithFieldMask (or $isIntKey $isStrKey)}} + {{- if and Features.WithFieldMask}} } {{- end}} } @@ -705,10 +743,12 @@ var FieldWrite = ` // FieldWriteStructLike . var FieldWriteStructLike = ` {{define "FieldWriteStructLike"}} - {{- if Features.WithFieldMask}} - if {{.Target}} != nil { - {{.Target}}.SetFieldMask({{.FieldMask}}) - } + {{- if and (Features.WithFieldMask) .NeedFieldMask}} + {{- if Features.FieldMaskHalfway}} + {{.Target}}.Pass_FieldMask({{.FieldMask}}) + {{- else}} + {{.Target}}.Set_FieldMask({{.FieldMask}}) + {{- end}} {{- end}} if err := {{.Target}}.Write(oprot); err != nil { return err @@ -749,7 +789,7 @@ var FieldWriteMap = ` {{- $isStrKey := .KeyCtx.Type | IsStrType -}} {{- $isBaseVal := .ValCtx.Type | IsBaseType -}} {{- $curFieldMask := .FieldMask -}} - {{- if and Features.WithFieldMask (or $isIntKey $isStrKey) }} + {{- if and Features.WithFieldMask (or $isStrKey $isIntKey) }} if !{{.FieldMask}}.All() { l := len({{.Target}}) for k := range {{.Target}} { @@ -787,26 +827,27 @@ var FieldWriteMap = ` {{- end}} for k, v := range {{.Target}} { {{- if Features.WithFieldMask}} - {{- if $isIntKey}} {{- $curFieldMask = "nfm"}} + {{- if $isIntKey}} if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(int(k)); !ex { continue } else { {{- else if $isStrKey}} - {{- $curFieldMask = "nfm"}} ks := string(k) if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Str(ks); !ex { continue } else { {{- else}} - {{$curFieldMask}} = nil + if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(0); !ex { + continue + } else { {{- end}} {{- end}}{{/* end Features.WithFieldMask */}} {{- $ctx := .KeyCtx.WithTarget "k" -}} {{- template "FieldWrite" $ctx}} {{- $ctx := (.ValCtx.WithTarget "v").WithFieldMask $curFieldMask -}} {{- template "FieldWrite" $ctx}} - {{- if and Features.WithFieldMask (or $isIntKey $isStrKey)}} + {{- if and Features.WithFieldMask }} } {{- end}} } diff --git a/generator/golang/thrift.go b/generator/golang/thrift.go index c7b8aea7..d8be4faf 100644 --- a/generator/golang/thrift.go +++ b/generator/golang/thrift.go @@ -51,7 +51,47 @@ func IsBaseType(t *parser.Type) bool { return false } +func checkErrorTPL(assign string, err string) string { + return "if err := " + assign + "; err != nil {\n goto " + err + "\n}\n" +} + // IsBaseType determines whether the given type is a base type. +func ZeroWriter(t *parser.Type, oprot string, err string) string { + switch t.GetCategory() { + case parser.Category_Bool: + return checkErrorTPL(oprot+".WriteBool(false)", err) + case parser.Category_Byte: + return checkErrorTPL(oprot+".WriteByte(0)", err) + case parser.Category_I16: + return checkErrorTPL(oprot+".WriteI16(0)", err) + case parser.Category_Enum, parser.Category_I32: + return checkErrorTPL(oprot+".WriteI32(0)", err) + case parser.Category_I64: + return checkErrorTPL(oprot+".WriteI64(0)", err) + case parser.Category_Double: + return checkErrorTPL(oprot+".WriteDouble(0)", err) + case parser.Category_String: + return checkErrorTPL(oprot+".WriteString(\"\")", err) + case parser.Category_Binary: + return checkErrorTPL(oprot+".WriteBinary([]byte{})", err) + case parser.Category_Map: + return checkErrorTPL(oprot+".WriteMapBegin(thrift."+GetTypeIDConstant(t.GetKeyType())+ + ",thrift."+GetTypeIDConstant(t.GetValueType())+",0)", err) + checkErrorTPL(oprot+".WriteMapEnd()", err) + case parser.Category_List: + return checkErrorTPL(oprot+".WriteListBegin(thrift."+GetTypeIDConstant(t.GetValueType())+ + ",0)", err) + checkErrorTPL(oprot+".WriteListEnd()", err) + case parser.Category_Set: + return checkErrorTPL(oprot+".WriteSetBegin(thrift."+GetTypeIDConstant(t.GetValueType())+ + ",0)", err) + checkErrorTPL(oprot+".WriteSetEnd()", err) + case parser.Category_Struct: + return checkErrorTPL(oprot+".WriteStructBegin(\"\")", err) + checkErrorTPL(oprot+".WriteFieldStop()", err) + + checkErrorTPL(oprot+".WriteStructEnd()", err) + default: + panic("unsuported type zero writer for" + t.Name) + } +} + +// IsIntType determines whether the given type is a Int type. func IsIntType(t *parser.Type) bool { switch t.Category { case parser.Category_Byte, parser.Category_I16, parser.Category_I32, parser.Category_I64, parser.Category_Enum: @@ -61,7 +101,7 @@ func IsIntType(t *parser.Type) bool { } } -// IsBaseType determines whether the given type is a base type. +// IsStrType determines whether the given type is a Str type. func IsStrType(t *parser.Type) bool { switch t.Category { case parser.Category_String, parser.Category_Binary: diff --git a/generator/golang/util.go b/generator/golang/util.go index 6ddced83..7c56f525 100644 --- a/generator/golang/util.go +++ b/generator/golang/util.go @@ -85,6 +85,12 @@ func (cu *CodeUtils) SetFeatures(fs Features) { cu.features = fs } +func (cu *CodeUtils) SetWithFieldMask(enable bool) bool { + ret := cu.features.WithFieldMask + cu.features.WithFieldMask = enable + return ret +} + // Features returns the current settings of generator features. func (cu *CodeUtils) Features() Features { return cu.features @@ -371,16 +377,18 @@ func (cu *CodeUtils) BuildFuncMap() template.FuncMap { "InsertionPoint": plugin.InsertionPoint, "Unexport": common.Unexport, - "Debug": cu.Debug, - "Features": cu.Features, - "GetPackageName": cu.GetPackageName, - "GenTags": cu.GenTags, - "GenFieldTags": cu.GenFieldTags, + "Debug": cu.Debug, + "Features": cu.Features, + "SetWithFieldMask": cu.SetWithFieldMask, + "GetPackageName": cu.GetPackageName, + "GenTags": cu.GenTags, + "GenFieldTags": cu.GenFieldTags, "MkRWCtx": func(f *Field) (*ReadWriteContext, error) { return cu.MkRWCtx(cu.rootScope, f) }, "IsBaseType": IsBaseType, + "ZeroWriter": ZeroWriter, "NeedRedirect": NeedRedirect, "IsFixedLengthType": IsFixedLengthType, "SupportIsSet": SupportIsSet, diff --git a/internal/utils/b2s.go b/internal/utils/b2s.go new file mode 100644 index 00000000..8fb5fc5b --- /dev/null +++ b/internal/utils/b2s.go @@ -0,0 +1,41 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "unsafe" + +type sliceHeader struct { + p unsafe.Pointer + l int + c int +} + +type stringHeader struct { + p unsafe.Pointer + l int +} + +func B2S(b []byte) (s string) { + (*stringHeader)(unsafe.Pointer(&s)).p = (*sliceHeader)(unsafe.Pointer(&b)).p + (*stringHeader)(unsafe.Pointer(&s)).l = (*sliceHeader)(unsafe.Pointer(&b)).l + return +} + +func S2B(s string) (b []byte) { + (*sliceHeader)(unsafe.Pointer(&b)).p = (*stringHeader)(unsafe.Pointer(&s)).p + (*sliceHeader)(unsafe.Pointer(&b)).l = (*stringHeader)(unsafe.Pointer(&s)).l + (*sliceHeader)(unsafe.Pointer(&b)).c = (*stringHeader)(unsafe.Pointer(&s)).l + return +} diff --git a/test/golang/cases_and_options/go.mod b/test/golang/cases_and_options/go.mod index 34e0bf8a..707d2c7f 100644 --- a/test/golang/cases_and_options/go.mod +++ b/test/golang/cases_and_options/go.mod @@ -2,4 +2,9 @@ module example.com/test go 1.17 -require github.com/apache/thrift v0.13.0 +require ( + github.com/apache/thrift v0.13.0 + github.com/cloudwego/thriftgo v0.0.0-00010101000000-000000000000 +) + +replace github.com/cloudwego/thriftgo => ../../.. diff --git a/test/golang/cases_and_options/go.sum b/test/golang/cases_and_options/go.sum index 171755b6..cbc2820a 100644 --- a/test/golang/cases_and_options/go.sum +++ b/test/golang/cases_and_options/go.sum @@ -1,2 +1,29 @@ github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/test/golang/cases_and_options/run_test.sh b/test/golang/cases_and_options/run_test.sh index b670bee1..53aee5a5 100755 --- a/test/golang/cases_and_options/run_test.sh +++ b/test/golang/cases_and_options/run_test.sh @@ -71,6 +71,7 @@ run_cases() { } run() { + go mod edit -replace github.com/cloudwego/thriftgo=../../.. out=$(run_cases 2>$errors) if [ -n "$out" ]; then echo diff --git a/test/golang/fieldmask/a.thrift b/test/golang/fieldmask/a.thrift index 5016377e..72b97d35 100644 --- a/test/golang/fieldmask/a.thrift +++ b/test/golang/fieldmask/a.thrift @@ -1,5 +1,3 @@ -#! /bin/bash -e - # Copyright 2022 CloudWeGo Authors # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,7 +22,7 @@ struct TrafficEnv { } struct Base { - 0: string Addr = "", + 0: required string Addr = "", 1: string LogID = "", 2: string Caller = "", 5: optional TrafficEnv TrafficEnv, @@ -59,6 +57,7 @@ typedef Val Key struct Val { 1: string id + 2: string name } typedef double Float @@ -74,17 +73,30 @@ enum Ex { } struct BaseResp { - 1: string StatusMessage = "", - 2: i32 StatusCode = 0, - 3: optional map Extra, + 1: required string StatusMessage = "", + 2: required i32 StatusCode = 0, + 3: required bool R3, + 4: required byte R4, + 5: required i16 R5, + 6: required i64 R6, + 7: required double R7, + 8: required string R8, + 9: required Ex R9, + 10: required list R10, + 11: required set R11, + 12: required TrafficEnv R12, + 13: required map R13, + 0: required Key R0, - 4: map F1 - 5: map F2, - 6: list F3 - 7: set F4, - 8: map F5 - 9: map F6 - 10: map F7 - 11: map> F8 - 12: list>> F9 + 14: map F1 + 15: map F2, + 16: list F3 + 17: set F4, + 18: map F5 + 19: map F6 + 110: map F7 + 111: map> F8 + 112: list>> F9 + 113: map F10 } + diff --git a/test/golang/fieldmask/b.thrift b/test/golang/fieldmask/b.thrift new file mode 100644 index 00000000..72b97d35 --- /dev/null +++ b/test/golang/fieldmask/b.thrift @@ -0,0 +1,102 @@ +# Copyright 2022 CloudWeGo Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +namespace go base + +struct TrafficEnv { + 0: string Name = "", + 1: bool Open = false, + 2: string Env = "", + 256: i64 Code, +} + +struct Base { + 0: required string Addr = "", + 1: string LogID = "", + 2: string Caller = "", + 5: optional TrafficEnv TrafficEnv, + 255: optional ExtraInfo Extra, + 256: MetaInfo Meta, +} + +struct ExtraInfo { + 1: map F1 + 2: map F2, + 3: list F3 + 4: set F4, + 5: map F5 + 6: map F6 + 7: map> F7 + 8: map> F8 + 9: map>> F9 + 10: map F10 +} + +struct MetaInfo { + 1: map IntMap, + 2: map StrMap, + 3: list List, + 4: set Set, + 11: map> MapList + 12: list>> ListMapList + 255: Base Base, +} + +typedef Val Key + +struct Val { + 1: string id + 2: string name +} + +typedef double Float + +typedef i64 Int + +typedef string Str + +enum Ex { + A = 1, + B = 2, + C = 3 +} + +struct BaseResp { + 1: required string StatusMessage = "", + 2: required i32 StatusCode = 0, + 3: required bool R3, + 4: required byte R4, + 5: required i16 R5, + 6: required i64 R6, + 7: required double R7, + 8: required string R8, + 9: required Ex R9, + 10: required list R10, + 11: required set R11, + 12: required TrafficEnv R12, + 13: required map R13, + 0: required Key R0, + + 14: map F1 + 15: map F2, + 16: list F3 + 17: set F4, + 18: map F5 + 19: map F6 + 110: map F7 + 111: map> F8 + 112: list>> F9 + 113: map F10 +} + diff --git a/test/golang/fieldmask/gen_test.go b/test/golang/fieldmask/gen_test.go new file mode 100644 index 00000000..a4ee0544 --- /dev/null +++ b/test/golang/fieldmask/gen_test.go @@ -0,0 +1,44 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fieldmask + +// import ( +// "testing" + +// "github.com/cloudwego/thriftgo/internal/test_util" +// "github.com/cloudwego/thriftgo/plugin" +// ) + +// func TestGen(t *testing.T) { +// g, r := test_util.GenerateGolang("a.thrift", "gen-old/", nil, nil) +// if err := g.Persist(r); err != nil { +// panic(err) +// } +// g, r = test_util.GenerateGolang("a.thrift", "gen-new/", []plugin.Option{ +// {"with_field_mask", ""}, +// {"with_reflection", ""}, +// }, nil) +// if err := g.Persist(r); err != nil { +// panic(err) +// } +// g, r = test_util.GenerateGolang("b.thrift", "gen-halfway/", []plugin.Option{ +// {"with_field_mask", ""}, +// {"field_mask_halfway", ""}, +// {"with_reflection", ""}, +// }, nil) +// if err := g.Persist(r); err != nil { +// panic(err) +// } +// } diff --git a/test/golang/fieldmask/go.mod b/test/golang/fieldmask/go.mod index f5535937..846dc68a 100644 --- a/test/golang/fieldmask/go.mod +++ b/test/golang/fieldmask/go.mod @@ -14,8 +14,6 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/dlclark/regexp2 v1.10.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/text v0.6.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/test/golang/fieldmask/go.sum b/test/golang/fieldmask/go.sum index 72c5058f..0c318a0f 100644 --- a/test/golang/fieldmask/go.sum +++ b/test/golang/fieldmask/go.sum @@ -2,7 +2,6 @@ github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -27,7 +26,6 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= diff --git a/test/golang/fieldmask/main_test.go b/test/golang/fieldmask/main_test.go index 9132762f..03a1be79 100644 --- a/test/golang/fieldmask/main_test.go +++ b/test/golang/fieldmask/main_test.go @@ -15,29 +15,225 @@ package fieldmask import ( + "fmt" + "runtime" + "sync" "testing" "github.com/apache/thrift/lib/go/thrift" "github.com/cloudwego/thriftgo/fieldmask" - "github.com/cloudwego/thriftgo/internal/test_util" - "github.com/cloudwego/thriftgo/plugin" + + hbase "github.com/cloudwego/thriftgo/test/golang/fieldmask/gen-halfway/base" nbase "github.com/cloudwego/thriftgo/test/golang/fieldmask/gen-new/base" obase "github.com/cloudwego/thriftgo/test/golang/fieldmask/gen-old/base" "github.com/stretchr/testify/require" ) -func TestGen(t *testing.T) { - g, r := test_util.GenerateGolang("a.thrift", "gen-old/", nil, nil) - if err := g.Persist(r); err != nil { +var fieldmaskCache sync.Map + +func initFielMask() { + // new a obj to get its TypeDescriptor + obj := nbase.NewBase() + + // construct a fieldmask with TypeDescriptor and thrift pathes + fm, err := fieldmask.NewFieldMask(obj.GetTypeDescriptor(), + "$.LogID", "$.TrafficEnv.Code", "$.Meta.IntMap{1}", "$.Meta.StrMap{\"1234\"}", + "$.Meta.List[1]", "$.Meta.Set[0].id", "$.Meta.Set[1].name") + if err != nil { panic(err) } - g, r = test_util.GenerateGolang("a.thrift", "gen-new/", []plugin.Option{ - {"with_field_mask", ""}, - {"with_reflection", ""}, - }, nil) - if err := g.Persist(r); err != nil { - panic(err) + + // cache it for future usage of nbase.Base + fieldmaskCache.Store("Mask1ForBase", fm) +} + +func TestFieldMask_Write(t *testing.T) { + initFielMask() + // biz logic: handle and get final response object + obj := SampleNewBase() + + // Load fieldmask from cache + fm, _ := fieldmaskCache.Load("Mask1ForBase") + if fm != nil { + // load ok, set fieldmask onto the object using codegen API + obj.Set_FieldMask(fm.(*fieldmask.FieldMask)) + } + + // return obj + + // prepare buffer + buf := thrift.NewTMemoryBufferLen(1024) + prot := thrift.NewTBinaryProtocol(buf, true, true) + if err := obj.Write(prot); err != nil { + t.Fatal(err) + } + + // validate output + obj2 := nbase.NewBase() + err := obj2.Read(prot) + if err != nil { + t.Fatal(err) + } + require.Equal(t, "", obj2.Addr) + require.Equal(t, obj.LogID, obj2.LogID) + require.Equal(t, "", obj2.Caller) + require.Equal(t, "", obj2.TrafficEnv.Name) + require.Equal(t, false, obj2.TrafficEnv.Open) + require.Equal(t, "", obj2.TrafficEnv.Env) + require.Equal(t, obj.TrafficEnv.Code, obj2.TrafficEnv.Code) + require.Equal(t, obj.Meta.IntMap[1].ID, obj2.Meta.IntMap[1].ID) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.IntMap[0]) + require.Equal(t, obj.Meta.StrMap["1234"].ID, obj2.Meta.StrMap["1234"].ID) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.StrMap["abcd"]) + require.Equal(t, 1, len(obj2.Meta.List)) + require.Equal(t, "b", obj2.Meta.List[0].ID) + require.Equal(t, "b", obj2.Meta.List[0].Name) + require.Equal(t, 2, len(obj2.Meta.Set)) + require.Equal(t, "a", obj2.Meta.Set[0].ID) + require.Equal(t, "", obj2.Meta.Set[0].Name) + require.Equal(t, "", obj2.Meta.Set[1].ID) + require.Equal(t, "b", obj2.Meta.Set[1].Name) +} + +func TestFieldMask_Read(t *testing.T) { + initFielMask() + obj := SampleNewBase() + buf := thrift.NewTMemoryBufferLen(1024) + prot := thrift.NewTBinaryProtocol(buf, true, true) + + if err := obj.Write(prot); err != nil { + t.Fatal(err) + } + + obj2 := nbase.NewBase() + fm, _ := fieldmaskCache.Load("Mask1ForBase") + if fm != nil { + obj2.Set_FieldMask(fm.(*fieldmask.FieldMask)) + } + + err := obj2.Read(prot) + if err != nil { + t.Fatal(err) } + + require.Equal(t, "", obj2.Addr) + require.Equal(t, obj.LogID, obj2.LogID) + require.Equal(t, "", obj2.Caller) + require.Equal(t, "", obj2.TrafficEnv.Name) + require.Equal(t, false, obj2.TrafficEnv.Open) + require.Equal(t, "", obj2.TrafficEnv.Env) + require.Equal(t, obj.TrafficEnv.Code, obj2.TrafficEnv.Code) + require.Equal(t, obj.Meta.IntMap[1].ID, obj2.Meta.IntMap[1].ID) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.IntMap[0]) + require.Equal(t, obj.Meta.StrMap["1234"].ID, obj2.Meta.StrMap["1234"].ID) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.StrMap["abcd"]) + require.Equal(t, 1, len(obj2.Meta.List)) + require.Equal(t, "b", obj2.Meta.List[0].ID) + require.Equal(t, "b", obj2.Meta.List[0].Name) + require.Equal(t, 2, len(obj2.Meta.Set)) + require.Equal(t, "a", obj2.Meta.Set[0].ID) + require.Equal(t, "", obj2.Meta.Set[0].Name) + require.Equal(t, "", obj2.Meta.Set[1].ID) + require.Equal(t, "b", obj2.Meta.Set[1].Name) +} + +func TestMaskRequired(t *testing.T) { + fm, err := fieldmask.NewFieldMask(nbase.NewBaseResp().GetTypeDescriptor(), "$.F1", "$.F8") + if err != nil { + t.Fatal(err) + } + j, err := fm.MarshalJSON() + if err != nil { + t.Fatal(err) + } + println(string(j)) + nf, ex := fm.Field(111) + if !ex { + t.Fatal(nf) + } + + t.Run("read", func(t *testing.T) { + obj := nbase.NewBaseResp() + obj.F1 = map[nbase.Str]nbase.Str{"a": "b"} + obj.F8 = map[float64][]nbase.Str{1.0: []nbase.Str{"a"}} + buf := thrift.NewTMemoryBufferLen(1024) + prot := thrift.NewTBinaryProtocol(buf, true, true) + if err := obj.Write(prot); err != nil { + t.Fatal(err) + } + obj2 := nbase.NewBaseResp() + obj2.Set_FieldMask(fm) + if err := obj2.Read(prot); err != nil { + t.Fatal(err) + } + require.Equal(t, obj.F1, obj2.F1) + require.Equal(t, obj.F8, obj2.F8) + }) + + t.Run("write", func(t *testing.T) { + obj := nbase.NewBaseResp() + obj.F1 = map[nbase.Str]nbase.Str{"a": "b"} + obj.F8 = map[float64][]nbase.Str{1.0: []nbase.Str{"a"}} + obj.Set_FieldMask(fm) + buf := thrift.NewTMemoryBufferLen(1024) + prot := thrift.NewTBinaryProtocol(buf, true, true) + if err := obj.Write(prot); err != nil { + t.Fatal(err) + } + // data := []byte(buf.String()) + // v, err := dg.NewNode(dt.STRUCT, data).Interface(&dg.Options{}) + // if err != nil { + // t.Fatal(err) + // } + // spew.Dump(v) + + obj2 := nbase.NewBaseResp() + if err := obj2.Read(prot); err != nil { + t.Fatal(err) + } + fmt.Printf("%#v\n", obj2) + }) +} + +func TestSetMaskHalfway(t *testing.T) { + obj := hbase.NewBase() + obj.Extra = hbase.NewExtraInfo() + obj.Extra.F1 = map[string]string{"a": "b"} + obj.Extra.F8 = map[int64][]*hbase.Key{1: []*hbase.Key{hbase.NewKey()}} + + fm, err := fieldmask.NewFieldMask(obj.Extra.GetTypeDescriptor(), "$.F1") + if err != nil { + t.Fatal(err) + } + obj.Extra.Set_FieldMask(fm) + buf := thrift.NewTMemoryBufferLen(1024) + prot := thrift.NewTBinaryProtocol(buf, true, true) + if err := obj.Write(prot); err != nil { + t.Fatal(err) + } + + obj2 := hbase.NewBase() + if err := obj2.Read(prot); err != nil { + t.Fatal(err) + } + require.Equal(t, obj.Extra.F1, obj2.Extra.F1) + require.Equal(t, map[int64][]*hbase.Key(nil), obj2.Extra.F8) + + fm, err = fieldmask.NewFieldMask(obj.Extra.GetTypeDescriptor(), "$.F8") + if err != nil { + t.Fatal(err) + } + obj.Extra.Set_FieldMask(fm) + if err := obj.Write(prot); err != nil { + t.Fatal(err) + } + + obj2 = hbase.NewBase() + if err := obj2.Read(prot); err != nil { + t.Fatal(err) + } + require.Equal(t, map[string]string(nil), obj2.Extra.F1) + require.Equal(t, obj.Extra.F8, obj2.Extra.F8) } func SampleNewBase() *nbase.Base { @@ -56,16 +252,25 @@ func SampleNewBase() *nbase.Base { } v0 := nbase.NewVal() v0.ID = "a" + v0.Name = "a" v1 := nbase.NewVal() v1.ID = "b" + v1.Name = "b" obj.Meta.List = []*nbase.Val{v0, v1} + // v0 = nbase.NewVal() + // v0.ID = "a" + // v0.Name = "a" + // v1 = nbase.NewVal() + // v1.ID = "b" + // v1.Name = "b" obj.Meta.Set = []*nbase.Val{v0, v1} - obj.Extra = nbase.NewExtraInfo() + // obj.Extra = nbase.NewExtraInfo() obj.TrafficEnv = nbase.NewTrafficEnv() obj.TrafficEnv.Code = 1 obj.TrafficEnv.Env = "abcd" obj.TrafficEnv.Name = "abcd" obj.TrafficEnv.Open = true + obj.Meta.Base = nbase.NewBase() return obj } @@ -85,19 +290,52 @@ func SampleOldBase() *obase.Base { } v0 := obase.NewVal() v0.ID = "a" + v0.Name = "a" v1 := obase.NewVal() v1.ID = "b" + v1.Name = "b" obj.Meta.List = []*obase.Val{v0, v1} obj.Meta.Set = []*obase.Val{v0, v1} - obj.Extra = obase.NewExtraInfo() + // obj.Extra = obase.NewExtraInfo() obj.TrafficEnv = obase.NewTrafficEnv() obj.TrafficEnv.Code = 1 obj.TrafficEnv.Env = "abcd" obj.TrafficEnv.Name = "abcd" obj.TrafficEnv.Open = true + obj.Meta.Base = obase.NewBase() return obj } +// func SampleApacheBase() *abase.Base { +// obj := abase.NewBase() +// obj.Addr = "abcd" +// obj.Caller = "abcd" +// obj.LogID = "abcd" +// obj.Meta = abase.NewMetaInfo() +// obj.Meta.StrMap = map[string]*abase.Val{ +// "abcd": abase.NewVal(), +// "1234": abase.NewVal(), +// } +// obj.Meta.IntMap = map[int64]*abase.Val{ +// 1: abase.NewVal(), +// 2: abase.NewVal(), +// } +// v0 := abase.NewVal() +// v0.ID = "a" +// v1 := abase.NewVal() +// v1.ID = "b" +// obj.Meta.List = []*abase.Val{v0, v1} +// obj.Meta.Set = []*abase.Val{v0, v1} +// // obj.Extra = abase.NewExtraInfo() +// obj.TrafficEnv = abase.NewTrafficEnv() +// obj.TrafficEnv.Code = 1 +// obj.TrafficEnv.Env = "abcd" +// obj.TrafficEnv.Name = "abcd" +// obj.TrafficEnv.Open = true +// obj.Meta.Base = abase.NewBase() +// return obj +// } + func BenchmarkWriteWithFieldMask(b *testing.B) { b.Run("old", func(b *testing.B) { obj := SampleOldBase() @@ -130,22 +368,40 @@ func BenchmarkWriteWithFieldMask(b *testing.B) { buf := thrift.NewTMemoryBufferLen(1024) t := thrift.NewTBinaryProtocol(buf, true, true) - fm, err := fieldmask.GetFieldMask(obj.GetTypeDescriptor(), "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.Meta.IntMap{1}", "$.Meta.StrMap{\"1234\"}", "$.Meta.List[1]", "$.Meta.Set[1]") + fm, err := fieldmask.NewFieldMask(obj.GetTypeDescriptor(), "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.TrafficEnv.Name", "$.Meta.IntMap", "$.Meta.List[*]") if err != nil { b.Fatal(err) } for i := 0; i < b.N; i++ { - obj.SetFieldMask(fm) + obj.Set_FieldMask(fm) if err := obj.Write(t); err != nil { b.Fatal(err) } buf.Reset() } - fm.Recycle() }) } func BenchmarkReadWithFieldMask(b *testing.B) { + // b.Run("apache", func(b *testing.B) { + // obj := SampleApacheBase() + // buf := thrift.NewTMemoryBufferLen(1024) + // t := thrift.NewTBinaryProtocol(buf, true, true) + // if err := obj.Write(t); err != nil { + // b.Fatal(err) + // } + // data := []byte(buf.String()) + + // obj = abase.NewBase() + // b.ResetTimer() + // for i := 0; i < b.N; i++ { + // buf.Reset() + // buf.Write(data) + // if err := obj.Read(t); err != nil { + // b.Fatal(err) + // } + // } + // }) b.Run("old", func(b *testing.B) { obj := SampleOldBase() buf := thrift.NewTMemoryBufferLen(1024) @@ -154,8 +410,9 @@ func BenchmarkReadWithFieldMask(b *testing.B) { b.Fatal(err) } data := []byte(buf.String()) - obj = obase.NewBase() + obj = obase.NewBase() + b.ResetTimer() for i := 0; i < b.N; i++ { buf.Reset() buf.Write(data) @@ -164,7 +421,7 @@ func BenchmarkReadWithFieldMask(b *testing.B) { } } }) - + runtime.GC() b.Run("new", func(b *testing.B) { obj := SampleNewBase() buf := thrift.NewTMemoryBufferLen(1024) @@ -174,7 +431,7 @@ func BenchmarkReadWithFieldMask(b *testing.B) { } data := []byte(buf.String()) obj = nbase.NewBase() - + b.ResetTimer() for i := 0; i < b.N; i++ { buf.Reset() buf.Write(data) @@ -183,7 +440,7 @@ func BenchmarkReadWithFieldMask(b *testing.B) { } } }) - + runtime.GC() b.Run("new-mask-half", func(b *testing.B) { obj := SampleNewBase() buf := thrift.NewTMemoryBufferLen(1024) @@ -193,100 +450,18 @@ func BenchmarkReadWithFieldMask(b *testing.B) { } data := []byte(buf.String()) obj = nbase.NewBase() - - fm, err := fieldmask.GetFieldMask(obj.GetTypeDescriptor(), "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.Meta.IntMap{1}", "$.Meta.StrMap{\"1234\"}", "$.Meta.List[1]", "$.Meta.Set[1]") + fm, err := fieldmask.NewFieldMask(obj.GetTypeDescriptor(), "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.TrafficEnv.Name", "$.Meta.IntMap", "$.Meta.List[*]") if err != nil { b.Fatal(err) } - + b.ResetTimer() for i := 0; i < b.N; i++ { buf.Reset() buf.Write(data) - obj.SetFieldMask(fm) + obj.Set_FieldMask(fm) if err := obj.Read(t); err != nil { b.Fatal(err) } } - - fm.Recycle() }) } - -func TestFieldmaskWrite(t *testing.T) { - obj := SampleNewBase() - buf := thrift.NewTMemoryBufferLen(1024) - prot := thrift.NewTBinaryProtocol(buf, true, true) - - fm, err := fieldmask.GetFieldMask(obj.GetTypeDescriptor(), - "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.Meta.IntMap{1}", "$.Meta.StrMap{\"1234\"}", "$.Meta.List[1]", "$.Meta.Set[1]") - if err != nil { - t.Fatal(err) - } - obj.SetFieldMask(fm) - if err := obj.Write(prot); err != nil { - t.Fatal(err) - } - - obj2 := nbase.NewBase() - err = obj2.Read(prot) - if err != nil { - t.Fatal(err) - } - - require.Equal(t, obj.Addr, obj2.Addr) - require.Equal(t, obj.LogID, obj2.LogID) - require.Equal(t, "", obj2.Caller) - require.Equal(t, "", obj2.TrafficEnv.Name) - require.Equal(t, false, obj2.TrafficEnv.Open) - require.Equal(t, "", obj2.TrafficEnv.Env) - require.Equal(t, obj.TrafficEnv.Code, obj2.TrafficEnv.Code) - require.Equal(t, obj.Meta.IntMap[1].ID, obj2.Meta.IntMap[1].ID) - require.Equal(t, (*nbase.Val)(nil), obj2.Meta.IntMap[0]) - require.Equal(t, obj.Meta.StrMap["1234"].ID, obj2.Meta.StrMap["1234"].ID) - require.Equal(t, (*nbase.Val)(nil), obj2.Meta.StrMap["abcd"]) - require.Equal(t, "b", obj2.Meta.List[0].ID) - require.Equal(t, 1, len(obj2.Meta.List)) - require.Equal(t, "b", obj2.Meta.Set[0].ID) - require.Equal(t, 1, len(obj2.Meta.Set)) - fm.Recycle() -} - -func TestFieldmaskRead(t *testing.T) { - obj := SampleNewBase() - buf := thrift.NewTMemoryBufferLen(1024) - prot := thrift.NewTBinaryProtocol(buf, true, true) - - fm, err := fieldmask.GetFieldMask(obj.GetTypeDescriptor(), - "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.Meta.IntMap{1}", "$.Meta.StrMap{\"1234\"}", "$.Meta.List[1]", "$.Meta.Set[1]") - if err != nil { - t.Fatal(err) - } - - if err := obj.Write(prot); err != nil { - t.Fatal(err) - } - - obj2 := nbase.NewBase() - obj2.SetFieldMask(fm) - err = obj2.Read(prot) - if err != nil { - t.Fatal(err) - } - - require.Equal(t, obj.Addr, obj2.Addr) - require.Equal(t, obj.LogID, obj2.LogID) - require.Equal(t, "", obj2.Caller) - require.Equal(t, "", obj2.TrafficEnv.Name) - require.Equal(t, false, obj2.TrafficEnv.Open) - require.Equal(t, "", obj2.TrafficEnv.Env) - require.Equal(t, obj.TrafficEnv.Code, obj2.TrafficEnv.Code) - require.Equal(t, obj.Meta.IntMap[1].ID, obj2.Meta.IntMap[1].ID) - require.Equal(t, (*nbase.Val)(nil), obj2.Meta.IntMap[0]) - require.Equal(t, obj.Meta.StrMap["1234"].ID, obj2.Meta.StrMap["1234"].ID) - require.Equal(t, (*nbase.Val)(nil), obj2.Meta.StrMap["abcd"]) - require.Equal(t, "b", obj2.Meta.List[0].ID) - require.Equal(t, 1, len(obj2.Meta.List)) - require.Equal(t, "b", obj2.Meta.Set[0].ID) - require.Equal(t, 1, len(obj2.Meta.Set)) - fm.Recycle() -} diff --git a/test/golang/fieldmask/run_test.sh b/test/golang/fieldmask/run_test.sh index 5de97a56..f13cbd77 100755 --- a/test/golang/fieldmask/run_test.sh +++ b/test/golang/fieldmask/run_test.sh @@ -1,4 +1,4 @@ -#! /bin/bash -e +#! /bin/bash # Copyright 2022 CloudWeGo Authors # @@ -15,23 +15,26 @@ # limitations under the License. generate () { - xxx=$1 - out=gen-${xxx} - opt="go:package_prefix=example.com/test/${out}" - idl=a.thrift - if [ -d ${out} ]; then - rm -rf ${out} + out=gen-$1 + opt="go:package_prefix=example.com/test/$out" + idl=$2 + if [ -d $out ]; then + rm -rf $out fi mkdir -p $out - if [[ $xxx == new ]]; then - opt=$opt,with_field_mask,with_reflection + if [ "$1" = "new" ]; then + opt="$opt,with_field_mask,with_reflection" fi - echo "thriftgo -g $opt -o ${out} $idl" - thriftgo -g "$opt" -o ${out} $idl + if [ "$1" = "halfway" ]; then + opt="$opt,with_field_mask,field_mask_halfway,with_reflection" + fi + echo "thriftgo -g $opt -o $out $idl" + thriftgo -g "$opt" -o $out $idl } -generate old -generate new +generate old a.thrift +generate new a.thrift +generate halfway b.thrift go mod tidy go test -v ./...