Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
AsterDY committed Dec 1, 2023
1 parent 5527c20 commit 9f7a28b
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 64 deletions.
8 changes: 4 additions & 4 deletions fieldmask/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,17 +221,17 @@ func TestNewFieldMask(t *testing.T) {
retry := true
begin:

println("fieldmask:")
println(got.String(st))
// println("fieldmask:")
// println(got.String(st))
// spew.Dump(got)

// test marshal json
println("marshal:")
// println("marshal:")
out, err := got.MarshalJSON()
if err != nil {
t.Fatal(err)
}
println(string(out))
// println(string(out))
if !json.Valid(out) {
t.Fatal("not invalid json")
}
Expand Down
18 changes: 8 additions & 10 deletions fieldmask/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,9 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path
return false
}
styp := stok.Type()
println("stoken: ", stok.String())
j, _ := cur.MarshalJSON()
println("cur mask: ", string(j), cur.isAll, cur.all)
// println("stoken: ", stok.String())
// j, _ := cur.MarshalJSON()
// println("cur mask: ", string(j), cur.isAll, cur.all)

if styp == pathTypeRoot {
continue
Expand All @@ -341,7 +341,7 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path
if err != nil {
return false
}
println("struct: ", st.Name)
// println("struct: ", st.Name)
if cur.typ != FtStruct {
return false
}
Expand All @@ -351,7 +351,7 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path
return false
}
typ := tok.Type()
println("token", tok.String())
// println("token", tok.String())

var f *thrift_reflection.FieldDescriptor
if typ == pathTypeLitInt {
Expand All @@ -375,10 +375,10 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path

// println("all", all, "FieldInMask:", cur.FieldInMask(int32(f.GetID())))
// check if name set mask
println("field ", f.GetID())
// println("field ", f.GetID())
nextFm, exist := cur.Field(int16(f.GetID()))
if !exist {
println("return false")
// println("return false")
return false
}

Expand Down Expand Up @@ -468,14 +468,12 @@ func (cur *FieldMask) PathInMask(curDesc *thrift_reflection.TypeDescriptor, path
if tok.Err() != nil {
return false
}
println("token", tok.String())
// println("token", tok.String())

if typ == pathTypeMapR {
println("break")
break
}
if cur.All() || typ == pathTypeElem {
println("continue")
continue
}
if typ == pathTypeAny {
Expand Down
8 changes: 4 additions & 4 deletions fieldmask/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func unwrapDesc(desc *thrift_reflection.TypeDescriptor) *thrift_reflection.TypeD
}

func (cur *FieldMask) addPath(path string, curDesc *thrift_reflection.TypeDescriptor) error {
println("[SetPath]: ", path)
// println("[SetPath]: ", path)

curDesc = unwrapDesc(curDesc)
if curDesc == nil {
Expand All @@ -96,7 +96,7 @@ func (cur *FieldMask) addPath(path string, curDesc *thrift_reflection.TypeDescri
return errPath(stok, "")
}
styp := stok.Type()
println("stoken: ", stok.String())
// println("stoken: ", stok.String())

if styp == pathTypeRoot {
cur.typ = switchFt(curDesc)
Expand All @@ -112,15 +112,15 @@ func (cur *FieldMask) addPath(path string, curDesc *thrift_reflection.TypeDescri
if cur.typ != FtStruct {
return errDesc(curDesc, "expect STRUCT")
}
println("struct: ", st.Name)
// println("struct: ", st.Name)

// get field name or field id
tok := it.Next()
if tok.Err() != nil {
return errPath(tok, "isn't field-name or field-id")
}
typ := tok.Type()
println("token: ", tok.String())
// println("token: ", tok.String())

all := cur.All()
if all {
Expand Down
4 changes: 2 additions & 2 deletions generator/golang/templates/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ var StructLikeWriteField = `
{{- $FieldName := .GoName}}
{{- $IsSetName := .IsSetter}}
{{- $TypeID := .Type | GetTypeIDConstant }}
{{- $isBaseVal := .Type | IsBaseType -}}
{{- $isBaseVal := .Type | IsBaseType }}
func (p *{{$TypeName}}) {{.Writer}}(oprot thrift.TProtocol) (err error) {
{{- if .Requiredness.IsOptional}}
if p.{{$IsSetName}}() {
Expand All @@ -380,7 +380,7 @@ func (p *{{$TypeName}}) {{.Writer}}(oprot thrift.TProtocol) (err error) {
if err = oprot.WriteFieldBegin("{{.Name}}", thrift.{{$TypeID}}, {{.ID}}); err != nil {
goto WriteFieldBeginError
}
{{ ZeroWriter .Type "oprot" "WriteFieldBeginError" Features.EnumAsINT32 }}
{{ ZeroWriter .Type "oprot" "WriteFieldBeginError" }}
if err = oprot.WriteFieldEnd(); err != nil {
goto WriteFieldEndError
}
Expand Down
28 changes: 2 additions & 26 deletions generator/golang/thrift.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ func checkErrorTPL(assign string, err string) string {
}

// IsBaseType determines whether the given type is a base type.
func ZeroWriter(t *parser.Type, oprot string, err string, enumAsI32 bool) string {
func ZeroWriter(t *parser.Type, oprot string, err string) string {
switch t.GetCategory() {
case parser.Category_Bool:
return checkErrorTPL(oprot+".WriteBool(false)", err)
case parser.Category_Byte:
return checkErrorTPL(oprot+".WriteByte(0)", err)
case parser.Category_I16:
return checkErrorTPL(oprot+".WriteI16(0)", err)
case parser.Category_I32:
case parser.Category_Enum, parser.Category_I32:
return checkErrorTPL(oprot+".WriteI32(0)", err)
case parser.Category_I64:
return checkErrorTPL(oprot+".WriteI64(0)", err)
Expand All @@ -74,12 +74,6 @@ func ZeroWriter(t *parser.Type, oprot string, err string, enumAsI32 bool) string
return checkErrorTPL(oprot+".WriteString(\"\")", err)
case parser.Category_Binary:
return checkErrorTPL(oprot+".WriteBinary([]byte{})", err)
case parser.Category_Enum:
if enumAsI32 {
return checkErrorTPL(oprot+".WriteI32(0)", err)
} else {
return checkErrorTPL(oprot+".WriteI64(0)", err)
}
case parser.Category_Map:
return checkErrorTPL(oprot+".WriteMapBegin(thrift."+GetTypeIDConstant(t.GetKeyType())+
",thrift."+GetTypeIDConstant(t.GetValueType())+",0)", err) + checkErrorTPL(oprot+".WriteMapEnd()", err)
Expand All @@ -97,24 +91,6 @@ func ZeroWriter(t *parser.Type, oprot string, err string, enumAsI32 bool) string
}
}

// IsBaseType determines whether the given type is a base type.
func IsFieldMaskType(t *parser.Type) bool {
switch t.Category {
case parser.Category_Bool, parser.Category_Byte, parser.Category_I16, parser.Category_I32,
parser.Category_I64, parser.Category_Double, parser.Category_Enum, parser.Category_Binary, parser.Category_String:
return false
case parser.Category_List, parser.Category_Set, parser.Category_Struct:
return true
case parser.Category_Map:
if IsStrType(t.GetKeyType()) || IsIntType(t.GetKeyType()) {
return true
}
return false
default:
panic("unexpected type:" + t.GetName())
}
}

// IsIntType determines whether the given type is a Int type.
func IsIntType(t *parser.Type) bool {
switch t.Category {
Expand Down
1 change: 0 additions & 1 deletion generator/golang/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,6 @@ func (cu *CodeUtils) BuildFuncMap() template.FuncMap {
"GetTypeIDConstant": GetTypeIDConstant,
"IsIntType": IsIntType,
"IsStrType": IsStrType,
"IsFieldMaskType": IsFieldMaskType,
"UseStdLibrary": func(libs ...string) string {
cu.rootScope.imports.UseStdLibrary(libs...)
return ""
Expand Down
42 changes: 25 additions & 17 deletions test/golang/fieldmask/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,25 +158,33 @@ func TestMaskRequired(t *testing.T) {
if err := obj2.Read(prot); err != nil {
t.Fatal(err)
}
fmt.Printf("%#v\n", obj2)
require.Equal(t, obj.F1, obj2.F1)
require.Equal(t, obj.F8, obj2.F8)
})

// t.Run("write", func(t *testing.T) {
// obj := nbase.NewBaseResp()
// obj.F1 = map[nbase.Str]nbase.Str{"a": "b"}
// obj.F8 = map[float64][]nbase.Str{1.0: []nbase.Str{"a"}}
// obj.Set_FieldMask(fm)
// buf := thrift.NewTMemoryBufferLen(1024)
// prot := thrift.NewTBinaryProtocol(buf, true, true)
// if err := obj.Write(prot); err != nil {
// t.Fatal(err)
// }
// obj2 := nbase.NewBaseResp()
// if err := obj2.Read(prot); err != nil {
// t.Fatal(err)
// }
// fmt.Printf("%#v\n", obj2)
// })
t.Run("write", func(t *testing.T) {
obj := nbase.NewBaseResp()
obj.F1 = map[nbase.Str]nbase.Str{"a": "b"}
obj.F8 = map[float64][]nbase.Str{1.0: []nbase.Str{"a"}}
obj.Set_FieldMask(fm)
buf := thrift.NewTMemoryBufferLen(1024)
prot := thrift.NewTBinaryProtocol(buf, true, true)
if err := obj.Write(prot); err != nil {
t.Fatal(err)
}
// data := []byte(buf.String())
// v, err := dg.NewNode(dt.STRUCT, data).Interface(&dg.Options{})
// if err != nil {
// t.Fatal(err)
// }
// spew.Dump(v)

obj2 := nbase.NewBaseResp()
if err := obj2.Read(prot); err != nil {
t.Fatal(err)
}
fmt.Printf("%#v\n", obj2)
})

}

Expand Down

0 comments on commit 9f7a28b

Please sign in to comment.