-
Notifications
You must be signed in to change notification settings - Fork 0
/
deepcopy.go
213 lines (196 loc) · 6.38 KB
/
deepcopy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
// Package deepcopy makes deep copies of Go values. A shallow copy duplicates
// pointers, so the copy still shares the data they point to; a deep copy
// duplicates the pointed-to values as well. Unexported struct fields are not
// copied.
package deepcopy
import (
"fmt"
"reflect"
"time"
"unsafe"
)
const (
	// startDetectingCyclesAfter is the recursion depth of copyRecursive beyond
	// which visited pointers, slices and maps are recorded to detect circular
	// references. Below this depth no bookkeeping is done, which keeps the
	// common (shallow) case cheap.
	startDetectingCyclesAfter = 1000
	// maxReferenceChainLength is the maximum recursion depth copyRecursive
	// allows; deeper chains return an error instead of risking a fatal stack
	// overflow. A cycle spanning more than roughly
	// maxReferenceChainLength - startDetectingCyclesAfter recursive calls is
	// therefore reported as an excessive chain rather than as a circular
	// reference.
	maxReferenceChainLength = 1500
)
// Interface is implemented by types that perform their own deep copy.
//
// When a value reached during Copy implements Interface, the result of its
// DeepCopy method is stored as the copy of that value instead of the
// reflection-based copy, so DeepCopy should return a value whose type is
// assignable to the receiver's type. DeepCopy may be invoked on a nil pointer
// receiver, since the check happens before nil pointers are handled.
type Interface interface {
	DeepCopy() interface{}
}
// Copy returns a deep copy value of the original value.
// All the values are expected to be deep copied except some cases:
// 1. all unexposed fields in a struct won't be deep copied
// 2. channels are shared when deep copy
// 3. functions are shared when deep copy, beware of variables captured by your function when you do deep copy.
// 4. Copy will return errors for circular references and too long reference chains.
func Copy(src interface{}) (interface{}, error) {
if src == nil {
return nil, nil
}
// Make the interface a reflect.Value
original := reflect.ValueOf(src)
// Make a copy of the same type as the original.
cpy := reflect.New(original.Type()).Elem()
// Recursively copy the original.
err := copyRecursive(original, cpy, &callState{
ptrLevel: 0,
ptrSeen: make(map[interface{}]struct{})})
if err != nil {
return nil, err
}
// Return the copy as an interface.
return cpy.Interface(), nil
}
// copyRecursive deep copies original into cpy according to original's Kind.
//
// cpy must be settable, hold the zero value, and have the same type as
// original; every caller in this file passes a freshly allocated value
// (reflect.New(...).Elem(), or an element/field of one).
//
// state.ptrLevel counts the current recursion depth and is incremented for
// every value visited, not only for pointers. Once it exceeds
// startDetectingCyclesAfter, the pointers, slices and maps on the current path
// are recorded in state.ptrSeen so that revisiting one reports a circular
// reference. Exceeding maxReferenceChainLength reports an excessive chain.
//
// A value implementing Interface is copied by calling its DeepCopy method. A
// nil result leaves cpy at its zero value; a result whose type is not
// assignable to cpy's type is reported as an error instead of panicking inside
// reflect.Value.Set.
func copyRecursive(original, cpy reflect.Value, state *callState) error {
	state.ptrLevel++
	defer func() {
		state.ptrLevel--
	}()
	if int(state.ptrLevel) > maxReferenceChainLength {
		return fmt.Errorf("excessive reference chain happened via %s", original.Type().String())
	}

	// Types implementing Interface decide how they are copied.
	if original.CanInterface() {
		if copier, ok := original.Interface().(Interface); ok {
			copied := copier.DeepCopy()
			if copied == nil {
				// reflect.ValueOf(nil) is the zero Value, which Set rejects
				// with a panic. cpy is still zero, so leave it that way; this
				// supports the common nil-receiver DeepCopy returning nil.
				return nil
			}
			copiedValue := reflect.ValueOf(copied)
			if !copiedValue.Type().AssignableTo(cpy.Type()) {
				return fmt.Errorf("DeepCopy of %s returned %s, which is not assignable to %s",
					original.Type().String(), copiedValue.Type().String(), cpy.Type().String())
			}
			cpy.Set(copiedValue)
			return nil
		}
	}

	switch original.Kind() {
	case reflect.Ptr:
		ptr := original.Interface()
		// Cycle bookkeeping is skipped at shallow depths to keep the common
		// case cheap. A circular reference drives the depth up quickly, so
		// detection still kicks in, and only entries deeper than the
		// threshold are ever stored.
		if state.ptrLevel > uint(startDetectingCyclesAfter) {
			if _, ok := state.ptrSeen[ptr]; ok {
				return fmt.Errorf("encountered a circular reference via %s", original.Type().String())
			}
			state.ptrSeen[ptr] = struct{}{}
			defer delete(state.ptrSeen, ptr)
		}
		originalValue := original.Elem()
		// A nil pointer leaves cpy as a nil pointer.
		if !originalValue.IsValid() {
			return nil
		}
		cpy.Set(reflect.New(originalValue.Type()))
		if err := copyRecursive(originalValue, cpy.Elem(), state); err != nil {
			return err
		}

	case reflect.Interface:
		if original.IsNil() {
			return nil
		}
		// Copy the dynamic value held by the interface, then store the copy
		// back into the interface-typed destination.
		originalValue := original.Elem()
		copyValue := reflect.New(originalValue.Type()).Elem()
		if err := copyRecursive(originalValue, copyValue, state); err != nil {
			return err
		}
		cpy.Set(copyValue)

	case reflect.Struct:
		// time.Time keeps its state in unexported fields, which the generic
		// path below would skip; copy it by value instead.
		if t, ok := original.Interface().(time.Time); ok {
			cpy.Set(reflect.ValueOf(t))
			return nil
		}
		typ := original.Type()
		for i := 0; i < original.NumField(); i++ {
			// A non-empty PkgPath marks an unexported field, which cannot be
			// set through reflection. CanSet is not usable as the test because
			// it also reports false for exported fields of a non-addressable
			// struct, such as one passed to Copy by value.
			if typ.Field(i).PkgPath != "" {
				continue
			}
			if err := copyRecursive(original.Field(i), cpy.Field(i), state); err != nil {
				return err
			}
		}

	case reflect.Slice:
		if state.ptrLevel > uint(startDetectingCyclesAfter) {
			// Key on unsafe.Pointer rather than uintptr: a uintptr is a plain
			// integer that the garbage collector neither updates when the
			// object moves nor treats as keeping it alive. The length is part
			// of the key so that distinct slices starting at the same address
			// (e.g. s and s[:len(s)-1]) are not mistaken for one another.
			ptr := struct {
				ptr    unsafe.Pointer
				length int
			}{unsafe.Pointer(original.Pointer()), original.Len()}
			if _, ok := state.ptrSeen[ptr]; ok {
				return fmt.Errorf("encountered a circular reference via %s", original.Type().String())
			}
			state.ptrSeen[ptr] = struct{}{}
			defer delete(state.ptrSeen, ptr)
		}
		if original.IsNil() {
			return nil
		}
		// Preserve both length and capacity of the original slice.
		cpy.Set(reflect.MakeSlice(original.Type(), original.Len(), original.Cap()))
		for i := 0; i < original.Len(); i++ {
			if err := copyRecursive(original.Index(i), cpy.Index(i), state); err != nil {
				return err
			}
		}

	case reflect.Array:
		// An array's length is part of its type, so a fresh zero array of the
		// same type has the right shape.
		cpy.Set(reflect.New(original.Type()).Elem())
		for i := 0; i < original.Len(); i++ {
			if err := copyRecursive(original.Index(i), cpy.Index(i), state); err != nil {
				return err
			}
		}

	case reflect.Map:
		ptr := unsafe.Pointer(original.Pointer())
		if state.ptrLevel > uint(startDetectingCyclesAfter) {
			if _, ok := state.ptrSeen[ptr]; ok {
				return fmt.Errorf("encountered a circular reference via %s", original.Type().String())
			}
			state.ptrSeen[ptr] = struct{}{}
			defer delete(state.ptrSeen, ptr)
		}
		if original.IsNil() {
			return nil
		}
		// The final size is known, so allocate the buckets once.
		cpy.Set(reflect.MakeMapWithSize(original.Type(), original.Len()))
		for _, key := range original.MapKeys() {
			originalValue := original.MapIndex(key)
			copyValue := reflect.New(originalValue.Type()).Elem()
			if err := copyRecursive(originalValue, copyValue, state); err != nil {
				return err
			}
			// Keys are deep copied too, so pointer keys get fresh pointers.
			copiedKey := reflect.New(key.Type()).Elem()
			if err := copyRecursive(key, copiedKey, state); err != nil {
				return err
			}
			cpy.SetMapIndex(copiedKey, copyValue)
		}

	default:
		// Basic kinds, channels, functions and unsafe pointers are copied by
		// value; channels and functions are therefore shared with original.
		cpy.Set(original)
	}
	return nil
}
// callState carries the bookkeeping shared by all copyRecursive calls made for
// a single Copy invocation.
type callState struct {
	// ptrLevel is the current recursion depth of copyRecursive; despite its
	// name it is incremented for every value visited, not only for pointers.
	ptrLevel uint
	// ptrSeen holds the pointers, slices (address plus length) and maps on the
	// current recursion path. Entries are recorded only once ptrLevel exceeds
	// startDetectingCyclesAfter and are removed when their frame returns.
	ptrSeen map[interface{}]struct{}
}