Skip to content

Commit

Permalink
fix: update GetMsgTypes and cleanTypesAndMsgValue functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonberg1997 authored and forcodedancing committed May 17, 2023
1 parent 4f32f44 commit d6ce237
Showing 1 changed file with 77 additions and 90 deletions.
167 changes: 77 additions & 90 deletions x/auth/tx/eip712.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,38 +114,62 @@ func GetMsgTypes(signerData signing.SignerData, tx sdk.Tx, typedChainID *big.Int
}

// extract the msg types
msgTypes := apitypes.Types{}
msgTypes := apitypes.Types{
"EIP712Domain": {
{
Name: "name",
Type: "string",
},
{
Name: "version",
Type: "string",
},
{
Name: "chainId",
Type: "uint256",
},
{
Name: "verifyingContract",
Type: "string",
},
{
Name: "salt",
Type: "string",
},
},
"Tx": {
{Name: "account_number", Type: "uint256"},
{Name: "chain_id", Type: "uint256"},
{Name: "fee", Type: "Fee"},
{Name: "memo", Type: "string"},
{Name: "sequence", Type: "uint256"},
{Name: "timeout_height", Type: "uint256"},
},
"Fee": {
{Name: "amount", Type: "Coin[]"},
{Name: "gas_limit", Type: "uint256"},
{Name: "payer", Type: "string"},
{Name: "granter", Type: "string"},
},
"Coin": {
{Name: "denom", Type: "string"},
{Name: "amount", Type: "uint256"},
},
}
for i, msg := range protoTx.GetMsgs() {
tmpMsgTypes, err := extractMsgTypes(msg)
tmpMsgTypes, err := extractMsgTypes(msg, i+1)
if err != nil {
return nil, nil, err
}

if len(msgTypes) == 0 {
msgTypes = tmpMsgTypes
}
msgTypes["Tx"] = append(msgTypes["Tx"], apitypes.Type{
Name: fmt.Sprintf("msg%d", i+1),
Type: fmt.Sprintf("Msg%d", i+1),
})

for _, t := range tmpMsgTypes["Tx"] {
if t.Name == "msg" {
if i == 0 {
for idx, _ := range msgTypes["Tx"] {
if msgTypes["Tx"][idx].Name == "msg" {
msgTypes["Tx"][idx] = apitypes.Type{
Name: fmt.Sprintf("msg%d", i+1),
Type: fmt.Sprintf("Msg%d", i+1),
}
}
}
} else {
msgTypes["Tx"] = append(msgTypes["Tx"], apitypes.Type{
Name: fmt.Sprintf("msg%d", i+1),
Type: fmt.Sprintf("Msg%d", i+1),
})
}
}
for key, field := range tmpMsgTypes {
msgTypes[key] = field
}
msgTypes[fmt.Sprintf("Msg%d", i+1)] = tmpMsgTypes["Msg"]
delete(msgTypes, "Msg")
}

// patch the msg types to include `Tip` if it's not empty
Expand Down Expand Up @@ -201,9 +225,9 @@ func WrapTxToTypedData(
delete(txData, "tip")
}

// filling nil value and do other clean up
// filling nil value and do the other clean up
msgData := txData["msg"].([]interface{})
for i, _ := range signDoc.GetMsg() {
for i := range signDoc.GetMsg() {
txData[fmt.Sprintf("msg%d", i+1)] = msgData[i]
cleanTypesAndMsgValue(msgTypes, fmt.Sprintf("Msg%d", i+1), msgData[i].(map[string]interface{}))
}
Expand All @@ -221,55 +245,10 @@ func WrapTxToTypedData(
return typedData, nil
}

func extractMsgTypes(msg sdk.Msg) (apitypes.Types, error) {
rootTypes := apitypes.Types{
"EIP712Domain": {
{
Name: "name",
Type: "string",
},
{
Name: "version",
Type: "string",
},
{
Name: "chainId",
Type: "uint256",
},
{
Name: "verifyingContract",
Type: "string",
},
{
Name: "salt",
Type: "string",
},
},
"Tx": {
{Name: "account_number", Type: "uint256"},
{Name: "chain_id", Type: "uint256"},
{Name: "fee", Type: "Fee"},
{Name: "memo", Type: "string"},
{Name: "msg", Type: "Msg"},
{Name: "sequence", Type: "uint256"},
{Name: "timeout_height", Type: "uint256"},
},
"Fee": {
{Name: "amount", Type: "Coin[]"},
{Name: "gas_limit", Type: "uint256"},
{Name: "payer", Type: "string"},
{Name: "granter", Type: "string"},
},
"Coin": {
{Name: "denom", Type: "string"},
{Name: "amount", Type: "uint256"},
},
"Msg": {
{Name: "type", Type: "string"},
},
}
func extractMsgTypes(msg sdk.Msg, index int) (apitypes.Types, error) {
rootTypes := apitypes.Types{}

if err := walkFields(rootTypes, msg); err != nil {
if err := walkFields(rootTypes, msg, index); err != nil {
return nil, err
}

Expand All @@ -278,7 +257,7 @@ func extractMsgTypes(msg sdk.Msg) (apitypes.Types, error) {

const typeDefPrefix = "_"

func walkFields(typeMap apitypes.Types, in interface{}) (err error) {
func walkFields(typeMap apitypes.Types, in interface{}, index int) (err error) {
defer doRecover(&err)

t := reflect.TypeOf(in)
Expand All @@ -296,7 +275,7 @@ func walkFields(typeMap apitypes.Types, in interface{}) (err error) {
break
}

return traverseFields(typeMap, typeDefPrefix, t, v)
return traverseFields(typeMap, typeDefPrefix, index, t, v)
}

type anyWrapper struct {
Expand All @@ -307,6 +286,7 @@ type anyWrapper struct {
func traverseFields(
typeMap apitypes.Types,
prefix string,
index int,
t reflect.Type,
v reflect.Value,
) error {
Expand Down Expand Up @@ -378,12 +358,13 @@ func traverseFields(
}

if prefix == typeDefPrefix {
typeMap["Msg"] = append(typeMap["Msg"], apitypes.Type{
tag := fmt.Sprintf("Msg%d", index)
typeMap[tag] = append(typeMap[tag], apitypes.Type{
Name: fieldName,
Type: ethTyp,
})
} else {
typeDef := sanitizeTypedef(prefix)
typeDef := sanitizeTypedef(prefix, index)
typeMap[typeDef] = append(typeMap[typeDef], apitypes.Type{
Name: fieldName,
Type: ethTyp,
Expand All @@ -397,25 +378,26 @@ func traverseFields(
var fieldTypedef string

if isCollection {
fieldTypedef = sanitizeTypedef(fieldPrefix) + "[]"
fieldTypedef = sanitizeTypedef(fieldPrefix, index) + "[]"
} else {
fieldTypedef = sanitizeTypedef(fieldPrefix)
fieldTypedef = sanitizeTypedef(fieldPrefix, index)
}

if prefix == typeDefPrefix {
typeMap["Msg"] = append(typeMap["Msg"], apitypes.Type{
tag := fmt.Sprintf("Msg%d", index)
typeMap[tag] = append(typeMap[tag], apitypes.Type{
Name: fieldName,
Type: fieldTypedef,
})
} else {
typeDef := sanitizeTypedef(prefix)
typeDef := sanitizeTypedef(prefix, index)
typeMap[typeDef] = append(typeMap[typeDef], apitypes.Type{
Name: fieldName,
Type: fieldTypedef,
})
}

if err := traverseFields(typeMap, fieldPrefix, fieldType, field); err != nil {
if err := traverseFields(typeMap, fieldPrefix, index, fieldType, field); err != nil {
return err
}
continue
Expand Down Expand Up @@ -472,17 +454,22 @@ func cleanTypesAndMsgValue(typedData apitypes.Types, primaryType string, msgValu
newValue["value"] = bz
newAnySet[i] = newValue
}
msgValue[encName] = newAnySet
msgValue[encName[:len(encName)-3]] = newAnySet
typedData[primaryType][i].Name = encName[:len(encName)-3]
typedData[primaryType][i].Type = "TypeAny[]"
delete(typedData, encType[:len(encType)-2])
} else {
anyValue := msgValue[encName[:len(encName)-3]].(map[string]interface{})
newValue := make(map[string]interface{})
bz, _ := json.Marshal(anyValue)
newValue["type"] = anyValue["@type"]
newValue["value"] = bz
msgValue[encName] = newValue
msgValue[encName[:len(encName)-3]] = newValue
typedData[primaryType][i].Name = encName[:len(encName)-3]
typedData[primaryType][i].Type = "TypeAny"
delete(typedData, encType)
}
typedData[encType] = anyApiTypes
delete(msgValue, encName[:len(encName)-3])
typedData["TypeAny"] = anyApiTypes
continue
}
encValue := msgValue[encName]
Expand Down Expand Up @@ -593,14 +580,14 @@ func unwrapField(fieldType reflect.Type, field reflect.Value, fieldName string)
//
// this is needed for Geth's own signing code which doesn't
// tolerate complex type names
func sanitizeTypedef(str string) string {
func sanitizeTypedef(str string, index int) string {
buf := new(bytes.Buffer)
parts := strings.Split(str, ".")
caser := cases.Title(language.English, cases.NoLower)

for _, part := range parts {
if part == "_" {
buf.WriteString("Type")
buf.WriteString(fmt.Sprintf("TypeMsg%d", index))
continue
}

Expand Down

0 comments on commit d6ce237

Please sign in to comment.