Skip to content

Commit

Permalink
datastore: DRY up loading entity code
Browse files Browse the repository at this point in the history
Change-Id: I0057a5b5a0dabb4f0f85a39de1fd234fee9655ea
Reviewed-on: https://code-review.googlesource.com/11352
Reviewed-by: kokoro <[email protected]>
Reviewed-by: Jonathan Amsterdam <[email protected]>
  • Loading branch information
Sarah Adams committed Mar 9, 2017
1 parent df9740f commit 7bcba8a
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 40 deletions.
2 changes: 1 addition & 1 deletion datastore/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ func (c *Client) get(ctx context.Context, keys []*Key, dst interface{}, opts *pb
if multiArgType == multiArgTypeStructPtr && elem.IsNil() {
elem.Set(reflect.New(elem.Type().Elem()))
}
if err := loadEntity(elem.Interface(), e.Entity); err != nil {
if err := loadEntityProto(elem.Interface(), e.Entity); err != nil {
multiErr[index] = err
any = true
}
Expand Down
4 changes: 2 additions & 2 deletions datastore/datastore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1932,7 +1932,7 @@ func TestRoundTrip(t *testing.T) {
} else {
got = reflect.New(reflect.TypeOf(tc.want).Elem()).Interface()
}
err = loadEntity(got, p)
err = loadEntityProto(got, p)
if s := checkErr(tc.getErr, err); s != "" {
t.Errorf("%s: load: %s", tc.desc, s)
continue
Expand Down Expand Up @@ -2111,7 +2111,7 @@ func TestLoadSaveNestedStructPLS(t *testing.T) {
}

gota := reflect.New(reflect.TypeOf(tc.wantLoad).Elem()).Interface()
err = loadEntity(gota, e)
err = loadEntityProto(gota, e)
switch tc.loadErr {
case "":
if err != nil {
Expand Down
41 changes: 12 additions & 29 deletions datastore/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,19 +240,11 @@ func setVal(v reflect.Value, p Property) string {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
// Check if v implements PropertyLoadSaver.
if pls, ok := v.Interface().(PropertyLoadSaver); ok {
err := pls.Load(x.Properties)
if err != nil {
return err.Error()
}
return ""
}
if v.Type().Elem().Kind() != reflect.Struct {
return typeMismatchReason(p, v)
err := loadEntity(v.Interface(), x)
if err != nil {
return err.Error()
}

return setVal(v.Elem(), p)
default:
return typeMismatchReason(p, v)
}
Expand Down Expand Up @@ -281,20 +273,7 @@ func setVal(v reflect.Value, p Property) string {
return fmt.Sprintf("datastore: PropertyLoadSaver methods must be implemented on a pointer to %T.", v.Interface())
}

// Recursively load nested struct.
pls, err := newStructPLS(v.Addr().Interface())
if err != nil {
return err.Error()
}

// if ent has a Key value and our struct has a Key field,
// load the Entity's Key value into the Key field on the struct.
keyField := pls.codec.Match(keyFieldName)
if keyField != nil && ent.Key != nil {
pls.v.FieldByIndex(keyField.Index).Set(reflect.ValueOf(ent.Key))
}

err = pls.Load(ent.Properties)
err := loadEntity(v.Addr().Interface(), ent)
if err != nil {
return err.Error()
}
Expand Down Expand Up @@ -330,14 +309,18 @@ func initField(val reflect.Value, index []int) reflect.Value {
return val.Field(index[len(index)-1])
}

// loadEntity loads an EntityProto into PropertyLoadSaver or struct pointer.
func loadEntity(dst interface{}, src *pb.Entity) (err error) {
// loadEntityProto loads an EntityProto into PropertyLoadSaver or struct pointer.
func loadEntityProto(dst interface{}, src *pb.Entity) error {
ent, err := protoToEntity(src)
if err != nil {
return err
}
if e, ok := dst.(PropertyLoadSaver); ok {
return e.Load(ent.Properties)
return loadEntity(dst, ent)
}

func loadEntity(dst interface{}, ent *Entity) error {
if pls, ok := dst.(PropertyLoadSaver); ok {
return pls.Load(ent.Properties)
}
return loadEntityToStruct(dst, ent)
}
Expand Down
12 changes: 6 additions & 6 deletions datastore/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ func TestLoadEntityNestedLegacy(t *testing.T) {

for _, tc := range testCases {
dst := reflect.New(reflect.TypeOf(tc.want).Elem()).Interface()
err := loadEntity(dst, tc.src)
err := loadEntityProto(dst, tc.src)
if err != nil {
t.Errorf("loadEntity: %s: %v", tc.desc, err)
t.Errorf("loadEntityProto: %s: %v", tc.desc, err)
continue
}

Expand Down Expand Up @@ -401,9 +401,9 @@ func TestLoadEntityNested(t *testing.T) {

for _, tc := range testCases {
dst := reflect.New(reflect.TypeOf(tc.want).Elem()).Interface()
err := loadEntity(dst, tc.src)
err := loadEntityProto(dst, tc.src)
if err != nil {
t.Errorf("loadEntity: %s: %v", tc.desc, err)
t.Errorf("loadEntityProto: %s: %v", tc.desc, err)
continue
}

Expand Down Expand Up @@ -497,9 +497,9 @@ func TestAlreadyPopulatedDst(t *testing.T) {
}

for _, tc := range testCases {
err := loadEntity(tc.dst, tc.src)
err := loadEntityProto(tc.dst, tc.src)
if err != nil {
t.Errorf("loadEntity: %s: %v", tc.desc, err)
t.Errorf("loadEntityProto: %s: %v", tc.desc, err)
continue
}

Expand Down
4 changes: 2 additions & 2 deletions datastore/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ func (c *Client) GetAll(ctx context.Context, q *Query, dst interface{}) ([]*Key,
x := reflect.MakeMap(elemType)
ev.Elem().Set(x)
}
if err = loadEntity(ev.Interface(), e); err != nil {
if err = loadEntityProto(ev.Interface(), e); err != nil {
if _, ok := err.(*ErrFieldMismatch); ok {
// We continue loading entities even in the face of field mismatch errors.
// If we encounter any other error, that other error is returned. Otherwise,
Expand Down Expand Up @@ -628,7 +628,7 @@ func (t *Iterator) Next(dst interface{}) (*Key, error) {
return nil, err
}
if dst != nil && !t.keysOnly {
err = loadEntity(dst, e)
err = loadEntityProto(dst, e)
}
return k, err
}
Expand Down

0 comments on commit 7bcba8a

Please sign in to comment.