diff --git a/errors.go b/errors.go index bab9ef2..5e80c50 100644 --- a/errors.go +++ b/errors.go @@ -18,4 +18,6 @@ var ( // ErrFieldRequireCopying returned when a field is required to be copied // but no copying is done for it. ErrFieldRequireCopying = errors.New("ErrFieldRequireCopying") + // ErrMethodInvalid returned when copying method of a struct is not valid + ErrMethodInvalid = errors.New("ErrMethodInvalid") ) diff --git a/struct_copier.go b/struct_copier.go index f07ff88..3e37373 100644 --- a/struct_copier.go +++ b/struct_copier.go @@ -27,92 +27,85 @@ func (c *structCopier) Copy(dst, src reflect.Value) error { return nil } -// nolint: gocognit, gocyclo +//nolint:gocognit,gocyclo func (c *structCopier) init(dstType, srcType reflect.Type) (err error) { cacheKey := c.ctx.createCacheKey(dstType, srcType) c.ctx.mu.RLock() cp, ok := c.ctx.copierCacheMap[*cacheKey] c.ctx.mu.RUnlock() if ok { - c.fieldCopiers = cp.(*structCopier).fieldCopiers // nolint: forcetypeassert + c.fieldCopiers = cp.(*structCopier).fieldCopiers //nolint:forcetypeassert return nil } - srcFields := srcType.NumField() - dstFields := dstType.NumField() - mapDstField := make(map[string]*fieldDetail, dstFields) - for i := 0; i < dstFields; i++ { - df := dstType.Field(i) - dfDetail := &fieldDetail{field: &df} - parseTag(dfDetail) - if dfDetail.ignored { - continue - } - mapDstField[dfDetail.key] = dfDetail + var dstCopyingMethods map[string]*reflect.Method + if c.ctx.CopyBetweenStructFieldAndMethod { + dstCopyingMethods = c.parseCopyingMethods(dstType) } - c.fieldCopiers = make([]copier, 0, dstFields) - for i := 0; i < srcFields; i++ { - sf := srcType.Field(i) - sfDetail := &fieldDetail{field: &sf} - parseTag(sfDetail) - if sfDetail.ignored { + dstDirectFields, mapDstDirectFields, dstInheritedFields, mapDstInheritedFields := c.parseAllFields(dstType) + srcDirectFields, mapSrcDirectFields, srcInheritedFields, mapSrcInheritedFields := c.parseAllFields(srcType) + c.fieldCopiers = make([]copier, 0, len(dstDirectFields)+len(dstInheritedFields)) + + for _, key := range append(srcDirectFields, srcInheritedFields...) { + // Find field details from `src` having the key + sfDetail := mapSrcDirectFields[key] + if sfDetail == nil { + sfDetail = mapSrcInheritedFields[key] + } + if sfDetail == nil || sfDetail.ignored || sfDetail.done { continue } - dfDetail := mapDstField[sfDetail.key] - if dfDetail == nil { - // Find method with same name as source field - dMethod := c.findDstMethod(dstType, sfDetail) - if dMethod == nil { - if sfDetail.required { - return fmt.Errorf("%w: struct field %s[%s] requires copying", - ErrFieldRequireCopying, srcType.Name(), sfDetail.field.Name) - } + // Copying methods have higher priority, so if a method defined in the dst struct, use it + if dstCopyingMethods != nil { + methodName := "Copy" + strings.ToUpper(key[:1]) + key[1:] + dstCpMethod, exists := dstCopyingMethods[methodName] + if exists && !dstCpMethod.Type.In(1).AssignableTo(sfDetail.field.Type) { + return fmt.Errorf("%w: struct method '%v.%s' does not accept argument type '%v' from '%v[%s]'", + ErrMethodInvalid, dstType, dstCpMethod.Name, sfDetail.field.Type, srcType, sfDetail.field.Name) + } + if exists { + c.fieldCopiers = append(c.fieldCopiers, c.createField2MethodCopier(dstCpMethod, sfDetail)) + sfDetail.markDone() continue } - c.fieldCopiers = append(c.fieldCopiers, c.createMethodCopier(dMethod, &sf)) - continue } - // Destination field found and matched - df := dfDetail.field - delete(mapDstField, sfDetail.key) - - // OPTIMIZATION: buildCopier() can handle this nicely, but it will add another wrapping layer - if simpleKindMask&(1< 0 { - if sf.Type == df.Type { - c.fieldCopiers = append(c.fieldCopiers, c.createDirectCopier(df, &sf)) - continue - } - if sf.Type.ConvertibleTo(df.Type) { - c.fieldCopiers = append(c.fieldCopiers, c.createConvCopier(df, &sf)) - continue + // Find field details from `dst` having the key + dfDetail := mapDstDirectFields[key] + if dfDetail == nil { + dfDetail = mapDstInheritedFields[key] + } + if dfDetail == nil || dfDetail.ignored || dfDetail.done { + // Found no corresponding dest field to copy to, raise an error in case this is required + if sfDetail.required { + return fmt.Errorf("%w: struct field '%v[%s]' requires copying", + ErrFieldRequireCopying, srcType, sfDetail.field.Name) } + continue } - cp, err := buildCopier(c.ctx, df.Type, sf.Type) + copier, err := c.buildCopier(dstType, srcType, dfDetail, sfDetail) if err != nil { return err } - if c.ctx.IgnoreNonCopyableTypes && (sfDetail.required || dfDetail.required) { - _, isNopCopier := cp.(*nopCopier) - if isNopCopier && dfDetail.required { - return fmt.Errorf("%w: struct field %s[%s] requires copying", - ErrFieldRequireCopying, dstType.Name(), dfDetail.field.Name) - } - if isNopCopier && sfDetail.required { - return fmt.Errorf("%w: struct field %s[%s] requires copying", - ErrFieldRequireCopying, srcType.Name(), sfDetail.field.Name) - } - } - c.fieldCopiers = append(c.fieldCopiers, c.createCustomCopier(df, &sf, cp)) + c.fieldCopiers = append(c.fieldCopiers, copier) + dfDetail.markDone() + sfDetail.markDone() } - for _, dfDetail := range mapDstField { - if dfDetail.required { - return fmt.Errorf("%w: struct field %s[%s] requires copying", - ErrFieldRequireCopying, dstType.Name(), dfDetail.field.Name) + // Remaining dst fields can't be copied + for _, dfDetail := range mapDstDirectFields { + if !dfDetail.done && dfDetail.required { + return fmt.Errorf("%w: struct field '%v[%s]' requires copying", + ErrFieldRequireCopying, dstType, dfDetail.field.Name) + } + } + for _, dfDetail := range mapDstInheritedFields { + if !dfDetail.done && dfDetail.required { + return fmt.Errorf("%w: struct field '%v[%s]' requires copying", + ErrFieldRequireCopying, dstType, dfDetail.field.Name) } } @@ -122,137 +115,243 @@ func (c *structCopier) init(dstType, srcType reflect.Type) (err error) { return nil } -func (c *structCopier) findDstMethod(dstType reflect.Type, sfDetail *fieldDetail) *reflect.Method { - if !c.ctx.CopyBetweenStructFieldAndMethod { - return nil - } - // Find method with name is 'Copy' + source field - // (e.g. src field is 'Amount', dst method should be CopyAmount) - methodName := "Copy" + strings.ToUpper(sfDetail.key[:1]) + sfDetail.key[1:] - dMethod, found := reflect.PointerTo(dstType).MethodByName(methodName) - if !found { - return nil +// parseCopyingMethods collects all copying methods from the given struct type +func (c *structCopier) parseCopyingMethods(structType reflect.Type) map[string]*reflect.Method { + ptrType := reflect.PointerTo(structType) + numMethods := ptrType.NumMethod() + result := make(map[string]*reflect.Method, numMethods) + for i := 0; i < numMethods; i++ { + method := ptrType.Method(i) + // Method name must be something like `Copy` + if !strings.HasPrefix(method.Name, "Copy") { + continue + } + // Method must accept an arg and return error type (1st arg is the struct itself) + if method.Type.NumIn() != 2 || method.Type.NumOut() != 1 { + continue + } + if method.Type.Out(0) != errType { + continue + } + result[method.Name] = &method } - if dMethod.Type.NumIn() != 2 || dMethod.Type.NumOut() != 1 { - return nil + return result +} + +// parseAllFields parses all fields of a struct including direct fields and fields inherited from embedded structs +func (c *structCopier) parseAllFields(typ reflect.Type) ( + directFieldKeys []string, + mapDirectFields map[string]*fieldDetail, + inheritedFieldKeys []string, + mapInheritedFields map[string]*fieldDetail, +) { + numFields := typ.NumField() + directFieldKeys = make([]string, 0, numFields) + mapDirectFields = make(map[string]*fieldDetail, numFields) + inheritedFieldKeys = make([]string, 0, numFields) + mapInheritedFields = make(map[string]*fieldDetail, numFields) + + for i := 0; i < numFields; i++ { + sf := typ.Field(i) + fDetail := &fieldDetail{field: &sf, index: []int{i}} + parseTag(fDetail) + if fDetail.ignored { + continue + } + directFieldKeys = append(directFieldKeys, fDetail.key) + mapDirectFields[fDetail.key] = fDetail + + // Parse embedded struct to get its fields + if sf.Anonymous { + for key, detail := range c.parseAllNestedFields(sf.Type, fDetail.index) { + inheritedFieldKeys = append(inheritedFieldKeys, key) + mapInheritedFields[key] = detail + fDetail.nestedFields = append(fDetail.nestedFields, detail) + } + } } - if !dMethod.Type.In(1).AssignableTo(sfDetail.field.Type) { - return nil + return directFieldKeys, mapDirectFields, inheritedFieldKeys, mapInheritedFields +} + +// parseAllNestedFields parses all fields with initial index of starting field +func (c *structCopier) parseAllNestedFields(typ reflect.Type, index []int) map[string]*fieldDetail { + if typ.Kind() == reflect.Pointer { + typ = typ.Elem() } - if dMethod.Type.Out(0) != errType { + if typ.Kind() != reflect.Struct { return nil } - return &dMethod -} + numFields := typ.NumField() + result := make(map[string]*fieldDetail, numFields) -func (c *structCopier) createDirectCopier(df, sf *reflect.StructField) copier { - if df.IsExported() && sf.IsExported() { - return &structFieldDirectCopier{ - dstField: df.Index[0], - srcField: sf.Index[0], + for i := 0; i < numFields; i++ { + sf := typ.Field(i) + fDetail := &fieldDetail{field: &sf, index: append(index, i)} + parseTag(fDetail) + if fDetail.ignored { + continue + } + result[fDetail.key] = fDetail + // Parse embedded struct recursively to get its fields + if sf.Anonymous { + for key, detail := range c.parseAllNestedFields(sf.Type, fDetail.index) { + result[key] = detail + fDetail.nestedFields = append(fDetail.nestedFields, detail) + } } } - return &structUnexportedFieldCopier{ - copier: &directCopier{}, - dstField: df.Index[0], - dstFieldUnexported: !df.IsExported(), - srcField: sf.Index[0], - srcFieldUnexported: !sf.IsExported(), - } + return result } -func (c *structCopier) createConvCopier(df, sf *reflect.StructField) copier { - if df.IsExported() && sf.IsExported() { - return &structFieldConvCopier{ - dstField: df.Index[0], - srcField: sf.Index[0], +func (c *structCopier) buildCopier(dstType, srcType reflect.Type, dstDetail, srcDetail *fieldDetail) (copier, error) { + df, sf := dstDetail.field, srcDetail.field + + // OPTIMIZATION: buildCopier() can handle this nicely, but it will add another wrapping layer + if simpleKindMask&(1< 0 { + if sf.Type == df.Type { + // NOTE: pass nil to unset custom copier and trigger direct copying. + // We can pass `&directCopier{}` for the same result (but it's a bit slower). + return c.createField2FieldCopier(dstDetail, srcDetail, nil), nil } + if sf.Type.ConvertibleTo(df.Type) { + return c.createField2FieldCopier(dstDetail, srcDetail, &convCopier{}), nil + } + } + + cp, err := buildCopier(c.ctx, df.Type, sf.Type) + if err != nil { + return nil, err } - return &structUnexportedFieldCopier{ - copier: &convCopier{}, - dstField: df.Index[0], - dstFieldUnexported: !df.IsExported(), - srcField: sf.Index[0], - srcFieldUnexported: !sf.IsExported(), + if c.ctx.IgnoreNonCopyableTypes && (srcDetail.required || dstDetail.required) { + _, isNopCopier := cp.(*nopCopier) + if isNopCopier && dstDetail.required { + return nil, fmt.Errorf("%w: struct field '%v[%s]' requires copying", + ErrFieldRequireCopying, dstType, dstDetail.field.Name) + } + if isNopCopier && srcDetail.required { + return nil, fmt.Errorf("%w: struct field '%v[%s]' requires copying", + ErrFieldRequireCopying, srcType, srcDetail.field.Name) + } } + return c.createField2FieldCopier(dstDetail, srcDetail, cp), nil } -func (c *structCopier) createMethodCopier(dM *reflect.Method, sf *reflect.StructField) copier { - return &structFieldMethodCopier{ +func (c *structCopier) createField2MethodCopier(dM *reflect.Method, sfDetail *fieldDetail) copier { + return &structField2MethodCopier{ dstMethod: dM.Index, dstMethodUnexported: !dM.IsExported(), - srcField: sf.Index[0], - srcFieldUnexported: !sf.IsExported(), + srcFieldIndex: sfDetail.index, + srcFieldUnexported: !sfDetail.field.IsExported(), } } -func (c *structCopier) createCustomCopier(df, sf *reflect.StructField, cp copier) copier { - if df.IsExported() && sf.IsExported() { - return &structFieldCopier{ - copier: cp, - dstField: df.Index[0], - srcField: sf.Index[0], - } - } - return &structUnexportedFieldCopier{ +func (c *structCopier) createField2FieldCopier(df, sf *fieldDetail, cp copier) copier { + return &structField2FieldCopier{ copier: cp, - dstField: df.Index[0], - dstFieldUnexported: !df.IsExported(), - srcField: sf.Index[0], - srcFieldUnexported: !sf.IsExported(), + dstFieldIndex: df.index, + dstFieldUnexported: !df.field.IsExported(), + srcFieldIndex: sf.index, + srcFieldUnexported: !sf.field.IsExported(), } } // structFieldDirectCopier data structure of copier that copies from // a src field to a dst field directly -type structFieldDirectCopier struct { - dstField int - srcField int +type structField2FieldCopier struct { + copier copier + dstFieldIndex []int + dstFieldUnexported bool + srcFieldIndex []int + srcFieldUnexported bool } // Copy implementation of Copy function for struct field copier direct -func (c *structFieldDirectCopier) Copy(dst, src reflect.Value) error { - dst.Field(c.dstField).Set(src.Field(c.srcField)) - return nil -} +func (c *structField2FieldCopier) Copy(dst, src reflect.Value) (err error) { + if len(c.srcFieldIndex) == 1 { + src = src.Field(c.srcFieldIndex[0]) + } else { + // NOTE: When a struct pointer is embedded (e.g. type StructX struct { *BaseStruct }), + // this retrieval can fail if the embedded struct pointer is nil. Just skip copying when fails. + src, err = src.FieldByIndexErr(c.srcFieldIndex) + if err != nil { + // There's no src field to copy from, reset the dst field to zero + c.setFieldZero(dst, c.dstFieldIndex) + return nil //nolint:nilerr + } + } + if c.srcFieldUnexported { + if !src.CanAddr() { + return fmt.Errorf("%w: accessing unexported field requires it to be addressable", + ErrValueUnaddressable) + } + src = reflect.NewAt(src.Type(), unsafe.Pointer(src.UnsafeAddr())).Elem() //nolint:gosec + } -// structFieldConvCopier data structure of copier that copies from -// a src field to a dst field with type conversion -type structFieldConvCopier struct { - dstField int - srcField int -} + if len(c.dstFieldIndex) == 1 { + dst = dst.Field(c.dstFieldIndex[0]) + } else { + // Get dst field with making sure it's settable + dst = c.getFieldWithInit(dst, c.dstFieldIndex) + } + if c.dstFieldUnexported { + if !dst.CanAddr() { + return fmt.Errorf("%w: accessing unexported field requires it to be addressable", + ErrValueUnaddressable) + } + dst = reflect.NewAt(dst.Type(), unsafe.Pointer(dst.UnsafeAddr())).Elem() //nolint:gosec + } -// Copy implementation of Copy function for struct field copier with type conversion -func (c *structFieldConvCopier) Copy(dst, src reflect.Value) error { - dstVal := dst.Field(c.dstField) - dstVal.Set(src.Field(c.srcField).Convert(dstVal.Type())) + // Use custom copier if set + if c.copier != nil { + return c.copier.Copy(dst, src) + } + // Otherwise, just perform simple direct copying + dst.Set(src) return nil } -// structFieldCopier wrapping copier for copying struct field -type structFieldCopier struct { - copier copier - dstField int - srcField int +// getFieldWithInit gets deep nested field with init value for pointer ones +func (c *structField2FieldCopier) getFieldWithInit(field reflect.Value, index []int) reflect.Value { + for _, idx := range index { + if field.Kind() == reflect.Pointer { + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() + } + field = field.Field(idx) + } + return field } -// Copy implementation of Copy function for struct field copier -func (c *structFieldCopier) Copy(dst, src reflect.Value) error { - return c.copier.Copy(dst.Field(c.dstField), src.Field(c.srcField)) +// setFieldZero sets zero to a deep nested field +func (c *structField2FieldCopier) setFieldZero(field reflect.Value, index []int) { + field, err := field.FieldByIndexErr(index) + if err == nil && field.IsValid() { + field.Set(reflect.Zero(field.Type())) // NOTE: Go1.18 has no SetZero + } } -// structFieldMethodCopier data structure of copier that copies between `fields` and `methods` -type structFieldMethodCopier struct { +// structField2MethodCopier data structure of copier that copies between `fields` and `methods` +type structField2MethodCopier struct { dstMethod int dstMethodUnexported bool - srcField int + srcFieldIndex []int srcFieldUnexported bool } // Copy implementation of Copy function for struct field copier between `fields` and `methods` -func (c *structFieldMethodCopier) Copy(dst, src reflect.Value) error { - src = src.Field(c.srcField) +func (c *structField2MethodCopier) Copy(dst, src reflect.Value) (err error) { + if len(c.srcFieldIndex) == 1 { + src = src.Field(c.srcFieldIndex[0]) + } else { + // NOTE: When a struct pointer is embedded (e.g. type StructX struct { *BaseStruct }), + // this retrieval can fail if the embedded struct pointer is nil. Just skip copying when fails. + src, err = src.FieldByIndexErr(c.srcFieldIndex) + if err != nil { + return nil //nolint:nilerr + } + } if c.srcFieldUnexported { if !src.CanAddr() { return fmt.Errorf("%w: accessing unexported field requires it to be addressable", @@ -260,10 +359,12 @@ func (c *structFieldMethodCopier) Copy(dst, src reflect.Value) error { } src = reflect.NewAt(src.Type(), unsafe.Pointer(src.UnsafeAddr())).Elem() //nolint:gosec } + dst = dst.Addr().Method(c.dstMethod) if c.dstMethodUnexported { dst = reflect.NewAt(dst.Type(), unsafe.Pointer(dst.UnsafeAddr())).Elem() //nolint:gosec } + errVal := dst.Call([]reflect.Value{src})[0] if errVal.IsNil() { return nil @@ -274,29 +375,3 @@ func (c *structFieldMethodCopier) Copy(dst, src reflect.Value) error { } return err } - -// structUnexportedFieldCopier data structure of copier that copies between unexported fields of struct -type structUnexportedFieldCopier struct { - copier copier - dstField int - dstFieldUnexported bool - srcField int - srcFieldUnexported bool -} - -// Copy implementation of Copy function for struct unexported field copier -func (c *structUnexportedFieldCopier) Copy(dst, src reflect.Value) error { - src = src.Field(c.srcField) - if c.srcFieldUnexported { - if !src.CanAddr() { - return fmt.Errorf("%w: accessing unexported field requires it to be addressable", - ErrValueUnaddressable) - } - src = reflect.NewAt(src.Type(), unsafe.Pointer(src.UnsafeAddr())).Elem() //nolint:gosec - } - dst = dst.Field(c.dstField) - if c.dstFieldUnexported { - dst = reflect.NewAt(dst.Type(), unsafe.Pointer(dst.UnsafeAddr())).Elem() //nolint:gosec - } - return c.copier.Copy(dst, src) -} diff --git a/struct_copier_test.go b/struct_copier_test.go index ca11d9e..7b2a5f0 100644 --- a/struct_copier_test.go +++ b/struct_copier_test.go @@ -404,6 +404,9 @@ func (d *testD1) CopyI5(i5 int) string { // incorrect method prototype (not retu func (d *testD1) CopyI6(i6 int) error { // incorrect method prototype (unmatched input type) return errTest } +func (d *testD1) NotCopy(i6 int) error { // not a copying method + return errTest +} func Test_Copy_struct_method(t *testing.T) { t.Run("#1: field -> dst method", func(t *testing.T) { @@ -457,6 +460,38 @@ func Test_Copy_struct_method(t *testing.T) { assert.Nil(t, err) assert.Equal(t, testD1{U: 2}, d) }) + + t.Run("#5: copy from src embedded field", func(t *testing.T) { + type SBase struct { + I1 int + } + type SS struct { + SBase + U uint + } + + var s SS = SS{U: 2, SBase: SBase{I1: 123}} + var d testD1 + err := Copy(&d, s) + assert.Nil(t, err) + assert.Equal(t, testD1{U: 2, x1: 246}, d) + }) + + t.Run("#6: copy from src embedded field, but field value can't be retrieved due to nil ptr", func(t *testing.T) { + type SBase struct { + I1 int + } + type SS struct { + *SBase + U uint + } + + var s SS = SS{U: 2} + var d testD1 + err := Copy(&d, s) + assert.Nil(t, err) + assert.Equal(t, testD1{U: 2}, d) + }) } func Test_Copy_struct_method_error(t *testing.T) { @@ -505,7 +540,7 @@ func Test_Copy_struct_method_error(t *testing.T) { var s SS = SS{I4: 1, U: 2} var d testD1 err := Copy(&d, s) - assert.ErrorIs(t, err, ErrFieldRequireCopying) + assert.ErrorIs(t, err, ErrMethodInvalid) }) t.Run("#5: incorrect method prototype (CopyI5())", func(t *testing.T) { @@ -532,3 +567,264 @@ func Test_Copy_struct_method_error(t *testing.T) { assert.ErrorIs(t, err, errTest) }) } + +func Test_Copy_struct_with_embedded_struct(t *testing.T) { + type SBase1 struct { + I int + } + type SBase2 struct { + SBase1 + S string + } + + type DBase1 struct { + I int + } + type DBase2 struct { + DBase1 + S string + } + + t.Run("#1: both src and dst have equivalent embedded fields", func(t *testing.T) { + type SS struct { + SBase2 + U uint `copy:",required"` + } + type DD struct { + DBase2 + U uint `copy:",required"` + } + + s := SS{U: 100, SBase2: SBase2{S: "abc", SBase1: SBase1{I: 11}}} + d := DD{} + err := Copy(&d, s) + assert.Nil(t, err) + assert.Equal(t, DD{U: 100, DBase2: DBase2{S: "abc", DBase1: DBase1{I: 11}}}, d) + + // With some tags + type SS2 struct { + SBase2 `copy:"base,required"` + U uint `copy:",required"` + } + type DD2 struct { + DBase2 `copy:"base,required"` + U uint `copy:",required"` + } + + s2 := SS2{U: 100, SBase2: SBase2{S: "abc", SBase1: SBase1{I: 11}}} + d2 := DD2{} + err = Copy(&d2, s2) + assert.Nil(t, err) + assert.Equal(t, DD2{U: 100, DBase2: DBase2{S: "abc", DBase1: DBase1{I: 11}}}, d2) + }) + + t.Run("#2: both src and dst have same embedded struct", func(t *testing.T) { + type SS struct { + SBase2 + U uint `copy:",required"` + } + type DD struct { + SBase2 + U uint `copy:",required"` + } + + s := SS{U: 100, SBase2: SBase2{S: "abc", SBase1: SBase1{I: 11}}} + d := DD{U: 123, SBase2: SBase2{S: "xyz", SBase1: SBase1{I: 111}}} + err := Copy(&d, s) + assert.Nil(t, err) + assert.Equal(t, DD{U: 100, SBase2: SBase2{S: "abc", SBase1: SBase1{I: 11}}}, d) + }) + + t.Run("#3: both src and dst have equivalent embedded fields, but src embeds ptr of struct", func(t *testing.T) { + type SS struct { + *SBase2 + U uint `copy:",required"` + } + type DD struct { + DBase2 + U uint `copy:",required"` + } + + // Ptr has value set + s := SS{U: 100, SBase2: &SBase2{S: "abc", SBase1: SBase1{I: 11}}} + d := DD{} + err := Copy(&d, s) + assert.Nil(t, err) + assert.Equal(t, DD{U: 100, DBase2: DBase2{S: "abc", DBase1: DBase1{I: 11}}}, d) + + // Ptr is nil + s = SS{U: 100} + d = DD{} + err = Copy(&d, &s) + assert.Nil(t, err) + assert.Equal(t, DD{U: 100}, d) + }) + + t.Run("#4: both src and dst have equivalent embedded fields, but dst embeds ptr of struct", func(t *testing.T) { + type SS struct { + SBase2 + U uint + } + type DD struct { + *DBase2 + U uint + } + + s := SS{U: 100, SBase2: SBase2{S: "abc", SBase1: SBase1{I: 11}}} + d := DD{} + err := Copy(&d, s) + assert.Nil(t, err) + assert.Equal(t, DD{U: 100, DBase2: &DBase2{S: "abc", DBase1: DBase1{I: 11}}}, d) + }) + + t.Run("#5: src has embedded struct, dst doesn't (flattening the copy)", func(t *testing.T) { + type SS struct { + SBase2 + U uint + } + type DD struct { + I int `copy:",required"` + S string `copy:",required"` + U uint `copy:",required"` + } + + s := SS{U: 100, SBase2: SBase2{S: "abc", SBase1: SBase1{I: 11}}} + d := DD{S: "xyz"} + err := Copy(&d, s) + assert.Nil(t, err) + assert.Equal(t, DD{U: 100, S: "abc", I: 11}, d) + + // With ignoring a field + type DD2 struct { + I int `copy:",required"` + S string `copy:"-"` + U uint `copy:",required"` + } + + d2 := DD2{} + err = Copy(&d2, s) + assert.Nil(t, err) + assert.Equal(t, DD2{U: 100, I: 11}, d2) + }) + + t.Run("#6: src has embedded struct ptr, dst doesn't (flattening the copy)", func(t *testing.T) { + type SS struct { + *SBase2 + U uint + } + type DD struct { + I int `copy:",required"` + S string `copy:",required"` + U uint `copy:",required"` + } + + // Ptr has a value set + s := SS{U: 100, SBase2: &SBase2{S: "abc", SBase1: SBase1{I: 11}}} + d := DD{S: "xyz"} + err := Copy(&d, s) + assert.Nil(t, err) + assert.Equal(t, DD{U: 100, S: "abc", I: 11}, d) + + // Ptr is nil + s = SS{U: 100} + d = DD{S: "xyz"} + err = Copy(&d, s) + assert.Nil(t, err) + assert.Equal(t, DD{U: 100}, d) + + // With ignoring a field + type DD2 struct { + I int `copy:",required"` + S string `copy:"-"` + U uint `copy:",required"` + } + + s = SS{U: 100, SBase2: &SBase2{S: "abc", SBase1: SBase1{I: 11}}} + d2 := DD2{S: "xyz"} + err = Copy(&d2, s) + assert.Nil(t, err) + assert.Equal(t, DD2{U: 100, S: "xyz", I: 11}, d2) + }) + + t.Run("#7: dst has embedded struct, src doesn't (flattening the copy)", func(t *testing.T) { + type SS struct { + I int `copy:",required"` + S string `copy:",required"` + U uint `copy:",required"` + } + type DD struct { + DBase2 + U uint + } + + s := SS{U: 100, S: "abc", I: 11} + d := DD{U: 123, DBase2: DBase2{S: "xyz"}} + err := Copy(&d, s) + assert.Nil(t, err) + assert.Equal(t, DD{U: 100, DBase2: DBase2{S: "abc", DBase1: DBase1{I: 11}}}, d) + }) + + t.Run("#8: dst has embedded struct ptr, src doesn't (flattening the copy)", func(t *testing.T) { + type SS struct { + I int `copy:",required"` + S string `copy:",required"` + U uint `copy:",required"` + } + type DD struct { + *DBase2 + U uint + } + + // Ptr has value set + s := SS{U: 100, S: "abc", I: 11} + d := DD{U: 123, DBase2: &DBase2{S: "xyz"}} + err := Copy(&d, s) + assert.Nil(t, err) + assert.Equal(t, DD{U: 100, DBase2: &DBase2{S: "abc", DBase1: DBase1{I: 11}}}, d) + + // Ptr is nil initially + s = SS{U: 100, S: "abc", I: 11} + d = DD{U: 123} + err = Copy(&d, s) + assert.Nil(t, err) + assert.Equal(t, DD{U: 100, DBase2: &DBase2{S: "abc", DBase1: DBase1{I: 11}}}, d) + }) +} + +func Test_Copy_struct_with_embedded_struct_error(t *testing.T) { + t.Run("#1: src inherited field requires copying", func(t *testing.T) { + type SBase struct { + I int `copy:",required"` + } + type SS struct { + SBase + U uint + } + type DD struct { + U uint + } + + s := SS{U: 100, SBase: SBase{I: 11}} + d := DD{} + err := Copy(&d, s) + assert.ErrorIs(t, err, ErrFieldRequireCopying) + }) + + t.Run("#2: dst inherited field requires copying", func(t *testing.T) { + type SS struct { + U uint + } + type DBase struct { + I int `copy:",required"` + } + type DD struct { + DBase + U uint + } + + s := SS{U: 100} + d := DD{} + err := Copy(&d, s) + assert.ErrorIs(t, err, ErrFieldRequireCopying) + }) +} diff --git a/tag.go b/struct_tag.go similarity index 66% rename from tag.go rename to struct_tag.go index 46d5ffa..48efaec 100644 --- a/tag.go +++ b/struct_tag.go @@ -11,6 +11,18 @@ type fieldDetail struct { key string ignored bool required bool + + done bool + index []int + nestedFields []*fieldDetail +} + +// markDone sets the `done` flag of a field detail and all of its nested fields recursively +func (detail *fieldDetail) markDone() { + detail.done = true + for _, f := range detail.nestedFields { + f.markDone() + } } // parseTag parses struct tag for getting copying detail and configuration @@ -30,7 +42,7 @@ func parseTag(detail *fieldDetail) { } for _, tagOpt := range tags[1:] { - if tagOpt == "required" { + if tagOpt == "required" && !detail.ignored { detail.required = true } } diff --git a/tag_test.go b/struct_tag_test.go similarity index 100% rename from tag_test.go rename to struct_tag_test.go