Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jsii-go): use reflect to resolve overriden methods #2780

Merged
merged 9 commits into from
Apr 9, 2021
31 changes: 15 additions & 16 deletions packages/@jsii/go-runtime-test/project/compliance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func (suite *ComplianceSuite) TestEqualsIsResistantToPropertyShadowingResultVari
}

type overridableProtectedMemberDerived struct {
calc.OverridableProtectedMember `overrides:"OverrideReadOnly,OverrideReadWrite"`
calc.OverridableProtectedMember
}

func newOverridableProtectedMemberDerived() *overridableProtectedMemberDerived {
Expand Down Expand Up @@ -381,7 +381,7 @@ func newImplementsAdditionalInterface(s calc.StructB) *implementsAdditionalInter
return &n
}

func (x implementsAdditionalInterface) ReturnStruct() *calc.StructB {
func (x *implementsAdditionalInterface) ReturnStruct() *calc.StructB {
return &x._struct
}

Expand Down Expand Up @@ -500,7 +500,7 @@ func (suite *ComplianceSuite) TestCallbacksCorrectlyDeserializeArguments() {
}

type testCallbacksCorrectlyDeserializeArgumentsDataRenderer struct {
calc.DataRenderer `overrides:"RenderMap"`
calc.DataRenderer
}

func NewTestCallbacksCorrectlyDeserializeArgumentsDataRenderer() *testCallbacksCorrectlyDeserializeArgumentsDataRenderer {
Expand Down Expand Up @@ -610,7 +610,7 @@ func (suite *ComplianceSuite) TestNullShouldBeTreatedAsUndefined() {
}

type myOverridableProtectedMember struct {
calc.OverridableProtectedMember `overrides:"OverrideMe"`
calc.OverridableProtectedMember
}

func newMyOverridableProtectedMember() *myOverridableProtectedMember {
Expand All @@ -619,7 +619,7 @@ func newMyOverridableProtectedMember() *myOverridableProtectedMember {
return &m
}

func (x myOverridableProtectedMember) OverrideMe() *string {
func (x *myOverridableProtectedMember) OverrideMe() *string {
return jsii.String("Cthulhu Fhtagn!")
}

Expand Down Expand Up @@ -685,7 +685,7 @@ func (suite *ComplianceSuite) TestMapReturnedByMethodCanBeRead() {
}

type myAbstractSuite struct {
calc.AbstractSuite `overrides:"SomeMethod,Property"`
calc.AbstractSuite

_property *string
}
Expand Down Expand Up @@ -724,10 +724,10 @@ func (suite *ComplianceSuite) TestCanOverrideProtectedSetter() {
}

type TestCanOverrideProtectedSetterOverridableProtectedMember struct {
calc.OverridableProtectedMember `overrides:"OverrideReadWrite"`
calc.OverridableProtectedMember
}

func (x TestCanOverrideProtectedSetterOverridableProtectedMember) SetOverrideReadWrite(val *string) {
func (x *TestCanOverrideProtectedSetterOverridableProtectedMember) SetOverrideReadWrite(val *string) {
x.OverridableProtectedMember.SetOverrideReadWrite(jsii.String(fmt.Sprintf("zzzzzzzzz%s", *val)))
}

Expand All @@ -737,7 +737,6 @@ func newTestCanOverrideProtectedSetterOverridableProtectedMember() *TestCanOverr
return &m
}


func (suite *ComplianceSuite) TestObjRefsAreLabelledUsingWithTheMostCorrectType() {
require := suite.Require()

Expand Down Expand Up @@ -952,7 +951,7 @@ func (suite *ComplianceSuite) TestPropertyOverrides_Get_Calls_Super() {
}

type testPropertyOverridesGetCallsSuper struct {
calc.SyncVirtualMethods `overrides:"TheProperty"`
calc.SyncVirtualMethods
}

func (t *testPropertyOverridesGetCallsSuper) TheProperty() *string {
Expand Down Expand Up @@ -1063,7 +1062,7 @@ func (suite *ComplianceSuite) TestPropertyOverrides_Get_Throws() {
}

type testPropertyOverridesGetThrows struct {
calc.SyncVirtualMethods `overrides:"TheProperty"`
calc.SyncVirtualMethods
}

func (t *testPropertyOverridesGetThrows) TheProperty() *string {
Expand Down Expand Up @@ -1143,7 +1142,7 @@ func (suite *ComplianceSuite) TestPropertyOverrides_Set_Calls_Super() {
}

type testPropertyOverridesSetCallsSuper struct {
calc.SyncVirtualMethods `overrides:"TheProperty"`
calc.SyncVirtualMethods
}

func (t *testPropertyOverridesSetCallsSuper) SetTheProperty(value *string) {
Expand Down Expand Up @@ -1327,8 +1326,8 @@ func (suite *ComplianceSuite) TestObjectIdDoesNotGetReallocatedWhenTheConstructo
}

type partiallyInitializedThisConsumerImpl struct {
calc.PartiallyInitializedThisConsumer `overrides:"ConsumePartiallyInitializedThis"`
require *require.Assertions
calc.PartiallyInitializedThisConsumer
require *require.Assertions
}

func NewPartiallyInitializedThisConsumerImpl(assert *require.Assertions) *partiallyInitializedThisConsumerImpl {
Expand Down Expand Up @@ -1478,7 +1477,7 @@ func (suite *ComplianceSuite) TestPropertyOverrides_Set_Throws() {
}

type testPropertyOverrides_Set_ThrowsSyncVirtualMethods struct {
calc.SyncVirtualMethods `overrides:"TheProperty"`
calc.SyncVirtualMethods
}

func NewTestPropertyOverrides_Set_ThrowsSyncVirtualMethods() *testPropertyOverrides_Set_ThrowsSyncVirtualMethods {
Expand Down Expand Up @@ -1582,7 +1581,7 @@ func (suite *ComplianceSuite) TestAsyncOverrides_OverrideCallsSuper() {
}

type OverrideCallsSuper struct {
calc.AsyncVirtualMethods `overrides:"OverrideMe"`
calc.AsyncVirtualMethods
}

func (o *OverrideCallsSuper) OverrideMe(mult *float64) *float64 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

type OverrideAsyncMethods struct {
jsiicalc.AsyncVirtualMethods `overrides:"OverrideMe"`
jsiicalc.AsyncVirtualMethods
}

func New() *OverrideAsyncMethods {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import (
)

type SyncOverrides struct {
jsiicalc.SyncVirtualMethods `overrides:"VirtualMethod,TheProperty"`
AnotherTheProperty *string
Multiplier int
ReturnSuper bool
CallAsync bool
jsiicalc.SyncVirtualMethods
AnotherTheProperty *string
Multiplier int
ReturnSuper bool
CallAsync bool
}

func New() *SyncOverrides {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

type TwoOverrides struct {
jsiicalc.AsyncVirtualMethods `overrides:"OverrideMe,OverrideMeToo"`
jsiicalc.AsyncVirtualMethods
}

func New() *TwoOverrides {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (w *WallClock) Iso8601Now() *string {
}

type Entropy struct {
jsiicalc.Entropy `overrides:"Repeat"`
jsiicalc.Entropy
}

func NewEntropy(clock jsiicalc.IWallClock) *Entropy {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package jsii

import (
"reflect"
"sort"
"testing"
)

type IFace interface {
M1()
M2()
M3()
}

type Base struct {
X, Y int
}

func (b *Base) M1() {}
func (b *Base) M2() {}
func (b *Base) M3() {}

type D1 struct {
*Base
}

func (d *D1) M1() {}

func (d *D1) X1() {}

type D2 struct {
Name string
IFace
}

func (d *D2) M2() {}

func (d *D2) X2() {}

func TestOverrideReflection(t *testing.T) {
testCases := [...]struct {
//overriding instance
val interface{}
//instance overriding methods
methods []string
}{
{&Base{}, []string(nil)},
{&D1{&Base{}}, []string{"M1", "X1"}},
{&D2{Name: "abc", IFace: &D1{&Base{}}}, []string{"M1", "X1", "M2", "X2"}},
}
for _, tc := range testCases {
sort.Slice(tc.methods, func(i, j int) bool {
return tc.methods[i] < tc.methods[j]
})
}
for _, tc := range testCases {
methods := getMethodOverrides(tc.val, "Base")
sort.Slice(methods, func(i, j int) bool {
return methods[i] < methods[j]
})
if !reflect.DeepEqual(methods, tc.methods) {
t.Errorf("expect: %v, got: %v", tc.methods, methods)
}
}
}
107 changes: 91 additions & 16 deletions packages/@jsii/go-runtime/jsii-runtime-go/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ func InitJsiiProxy(ptr interface{}) {
func Create(fqn FQN, args []interface{}, inst interface{}) {
client := kernel.GetClient()

var overrides []api.Override

instVal := reflect.ValueOf(inst)
structVal := instVal.Elem()
instType := structVal.Type()
Expand All @@ -134,20 +132,6 @@ func Create(fqn FQN, args []interface{}, inst interface{}) {
panic(err)
}

if tag := field.Tag.Get("overrides"); tag != "" {
if names := strings.Split(tag, ","); len(names) > 0 {
overrides = make([]api.Override, len(names))
registry := client.Types()
for i, name := range names {
if override, ok := registry.GetOverride(api.FQN(fqn), name); !ok {
panic(fmt.Errorf("unable to identify method or property for override %s declared by %v", name, field))
} else {
overrides[i] = override
}
}
}
}

case reflect.Struct:
fieldVal := structVal.Field(i)
if !fieldVal.IsZero() {
Expand All @@ -159,6 +143,36 @@ func Create(fqn FQN, args []interface{}, inst interface{}) {
}
}

//find method overrides thru reflection
mOverrides := getMethodOverrides(inst, "jsiiProxy_")
//if overriding struct has no overriding methods, could happen if
//overriding methods are not defined with pointer receiver.
if len(mOverrides) == 0 && !strings.HasPrefix(instType.Name(), "jsiiProxy_") {
panic(fmt.Errorf("%s has no overriding methods. Please verify overriding methods are defined with pointer receiver", instType.Name()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
panic(fmt.Errorf("%s has no overriding methods. Please verify overriding methods are defined with pointer receiver", instType.Name()))
panic(fmt.Errorf("%s has no overriding methods. Overriding methods must be defined with a pointer receiver", instType.Name()))

}
var overrides []api.Override
registry := client.Types()
added := make(map[string]bool)
for _, name := range mOverrides {
//use getter's name even if setter is overriden
if strings.HasPrefix(name, "Set") {
propName := name[3:]
if override, ok := registry.GetOverride(api.FQN(fqn), propName); ok {
if !added[propName] {
added[propName] = true
overrides = append(overrides, override)
}
continue
}
}
if override, ok := registry.GetOverride(api.FQN(fqn), name); ok {
if !added[name] {
added[name] = true
overrides = append(overrides, override)
}
}
}

interfaces, newOverrides := client.Types().DiscoverImplementation(instType)
overrides = append(overrides, newOverrides...)

Expand Down Expand Up @@ -361,6 +375,67 @@ func convertArguments(args []interface{}) []interface{} {
return result
}

//Get ptr's methods names which override "base" struct methods.
//The "base" struct is identified by name prefix "basePrefix".
func getMethodOverrides(ptr interface{}, basePrefix string) (methods []string) {
//methods override cache: [methodName]structName
mCache := make(map[string]string)
getMethodOverridesRec(ptr, basePrefix, mCache)
//return overriden methods names in embedding hierarchy
for m, _ := range mCache {
methods = append(methods, m)
}
return
}

func getMethodOverridesRec(ptr interface{}, basePrefix string, cache map[string]string) {
ptrType := reflect.TypeOf(ptr)
if ptrType.Kind() != reflect.Ptr {
return
}
structType := ptrType.Elem()
if structType.Kind() != reflect.Struct {
return
}
if strings.HasPrefix(structType.Name(), basePrefix) {
//skip base class
return
}

ptrVal := reflect.ValueOf(ptr)
structVal := ptrVal.Elem()

//add embedded/super overrides first
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
if !field.Anonymous {
continue
}
if field.Type.Kind() == reflect.Ptr ||
field.Type.Kind() == reflect.Interface {
p := structVal.Field(i)
if !p.IsNil() {
getMethodOverridesRec(p.Interface(), basePrefix, cache)
}
}
}
//add overrides in current struct
//current struct's value-type method-set
valMethods := make(map[string]bool)
for i := 0; i < structType.NumMethod(); i++ {
valMethods[structType.Method(i).Name] = true
}
//compare current struct's pointer-type method-set to
//its value-type method-set
structName := structType.Name()
for i := 0; i < ptrType.NumMethod(); i++ {
mn := ptrType.Method(i).Name
if !valMethods[mn] {
cache[mn] = structName
}
}
}

// Close finalizes the runtime process, signalling the end of the execution to
// the jsii kernel process, and waiting for graceful termination. The best
// practice is to defer call thins at the beginning of the "main" function.
Expand Down