Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: corrected mistake when retrieving the reward address #848

Merged
merged 2 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,33 +353,37 @@ func StartNode(workingDir string, passwordFetcher func(*wallet.Wallet) (string,
if err != nil {
return nil, nil, err
}
allValAddrs := walletInstance.AllValidatorAddresses()

if len(allValAddrs) < 1 || len(allValAddrs) > 32 {
return nil, nil, fmt.Errorf("number of validators must be between 1 and 32, but it's %d",
len(allValAddrs))
valAddrsInfo := walletInstance.AllValidatorAddresses()
if len(valAddrsInfo) == 0 {
return nil, nil, fmt.Errorf("no validator addresses found in the wallet")
}

if len(valAddrsInfo) > 32 {
PrintWarnMsgf("wallet has more than 32 validator addresses, only the first 32 will be used")
valAddrsInfo = valAddrsInfo[:32]
}

if len(conf.Node.RewardAddresses) > 0 &&
len(conf.Node.RewardAddresses) != len(allValAddrs) {
return nil, nil, fmt.Errorf("reward addresses should be %v", len(allValAddrs))
len(conf.Node.RewardAddresses) != len(valAddrsInfo) {
return nil, nil, fmt.Errorf("reward addresses should be %v", len(valAddrsInfo))
}

validatorAddrs := make([]string, len(allValAddrs))
for i := 0; i < len(validatorAddrs); i++ {
valAddr, _ := crypto.AddressFromString(allValAddrs[i].Address)
valAddrs := make([]string, len(valAddrsInfo))
for i := 0; i < len(valAddrs); i++ {
valAddr, _ := crypto.AddressFromString(valAddrsInfo[i].Address)
if !valAddr.IsValidatorAddress() {
return nil, nil, fmt.Errorf("invalid validator address: %s", allValAddrs[i].Address)
return nil, nil, fmt.Errorf("invalid validator address: %s", valAddrsInfo[i].Address)
}
validatorAddrs[i] = valAddr.String()
valAddrs[i] = valAddr.String()
}

valKeys := make([]*bls.ValidatorKey, len(allValAddrs))
valKeys := make([]*bls.ValidatorKey, len(valAddrsInfo))
password, ok := passwordFetcher(walletInstance)
if !ok {
return nil, nil, fmt.Errorf("aborted")
}
prvKeys, err := walletInstance.PrivateKeys(password, validatorAddrs)
prvKeys, err := walletInstance.PrivateKeys(password, valAddrs)
if err != nil {
return nil, nil, err
}
Expand All @@ -388,28 +392,39 @@ func StartNode(workingDir string, passwordFetcher func(*wallet.Wallet) (string,
}

// Create reward addresses
rewardAddrs := make([]crypto.Address, 0, len(allValAddrs))
rewardAddrs := make([]crypto.Address, 0, len(valAddrsInfo))
if len(conf.Node.RewardAddresses) != 0 {
for _, addrStr := range conf.Node.RewardAddresses {
addr, _ := crypto.AddressFromString(addrStr)
rewardAddrs = append(rewardAddrs, addr)
}
} else {
for i := 0; i < len(allValAddrs); i++ {
valAddrPath, _ := addresspath.NewPathFromString(allValAddrs[i].Path)
accAddrPath := addresspath.NewPath(valAddrPath.Purpose(), valAddrPath.CoinType(),
uint32(crypto.AddressTypeValidator)+hdkeychain.HardenedKeyStart, valAddrPath.AddressIndex())
for i := 0; i < len(valAddrsInfo); i++ {
valAddrPath, _ := addresspath.NewPathFromString(valAddrsInfo[i].Path)
accAddrPath := addresspath.NewPath(
valAddrPath.Purpose(),
valAddrPath.CoinType(),
uint32(crypto.AddressTypeBLSAccount)+hdkeychain.HardenedKeyStart,
amirvalhalla marked this conversation as resolved.
Show resolved Hide resolved
valAddrPath.AddressIndex())
b00f marked this conversation as resolved.
Show resolved Hide resolved

addrInfo := walletInstance.AddressFromPath(accAddrPath.String())
if addrInfo == nil {
return nil, nil, fmt.Errorf("unable to find reward address for: %s", allValAddrs[i].Address)
return nil, nil, fmt.Errorf("unable to find reward address for: %s [%s]",
valAddrsInfo[i].Address, accAddrPath)
}

addr, _ := crypto.AddressFromString(addrInfo.Address)
rewardAddrs = append(rewardAddrs, addr)
}
}

// Check if reward addresses are account address
for _, addr := range rewardAddrs {
if !addr.IsAccountAddress() {
return nil, nil, fmt.Errorf("reward address is not an account address: %s", addr)
}
}

nodeInstance, err := node.NewNode(gen, conf, valKeys, rewardAddrs)
if err != nil {
return nil, nil, err
Expand Down
3 changes: 3 additions & 0 deletions wallet/addresspath/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func NewPath(indexes ...uint32) Path {
return p
}

// TODO: check the path should exactly 4 levels.
func NewPathFromString(str string) (Path, error) {
sub := strings.Split(str, "/")
if sub[0] != "m" {
Expand Down Expand Up @@ -52,6 +53,8 @@ func (p Path) String() string {
return builder.String()
}

// TODO: we can add IsBLSPurpose or IsImportedPurpose functions

func (p Path) Purpose() uint32 {
return p[0]
}
Expand Down
7 changes: 7 additions & 0 deletions wallet/vault/utils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package vault

import (
"github.com/pactus-project/pactus/crypto/bls/hdkeychain"
"github.com/tyler-smith/go-bip39"
"golang.org/x/exp/constraints"
)

// GenerateMnemonic generates a new mnemonic (seed phrase) based on BIP-39
Expand All @@ -22,3 +24,8 @@ func CheckMnemonic(mnemonic string) error {
_, err := bip39.EntropyFromMnemonic(mnemonic)
return err
}

// H hardens the value 'i' by adding it to 0x80000000 (2^31).
func H[T constraints.Integer](i T) uint32 {
return uint32(i) + hdkeychain.HardenedKeyStart
}
72 changes: 27 additions & 45 deletions wallet/vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,8 @@ type AddressInfo struct {
}

const (
PurposeBLS12381 = uint32(12381)
HardenedPurposeBLS12381 = PurposeBLS12381 + hdkeychain.HardenedKeyStart
PurposeImportPrivateKey = uint32(65535)
HardenedPurposeImportPrivateKey = PurposeImportPrivateKey + hdkeychain.HardenedKeyStart
HardenedAddressTypeBLSAccount = uint32(crypto.AddressTypeBLSAccount) + hdkeychain.HardenedKeyStart
HardenedAddressTypeValidator = uint32(crypto.AddressTypeValidator) + hdkeychain.HardenedKeyStart
PurposeBLS12381 = uint32(12381)
PurposeImportPrivateKey = uint32(65535)
)

type Vault struct {
Expand Down Expand Up @@ -97,18 +93,18 @@ func CreateVaultFromMnemonic(mnemonic string, coinType uint32) (*Vault, error) {
enc := encrypter.NopeEncrypter()

xPubValidator, err := masterKey.DerivePath([]uint32{
12381 + hdkeychain.HardenedKeyStart,
coinType + hdkeychain.HardenedKeyStart,
uint32(crypto.AddressTypeValidator) + hdkeychain.HardenedKeyStart,
H(PurposeBLS12381),
H(coinType),
H(crypto.AddressTypeValidator),
})
if err != nil {
return nil, err
}

xPubAccount, err := masterKey.DerivePath([]uint32{
12381 + hdkeychain.HardenedKeyStart,
coinType + hdkeychain.HardenedKeyStart,
uint32(crypto.AddressTypeBLSAccount) + hdkeychain.HardenedKeyStart,
H(PurposeBLS12381),
H(coinType),
H(crypto.AddressTypeBLSAccount),
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -222,22 +218,7 @@ func (v *Vault) AllValidatorAddresses() []AddressInfo {
addrs := make([]AddressInfo, 0, v.AddressCount()/2)
for _, addrInfo := range v.Addresses {
addrPath, _ := addresspath.NewPathFromString(addrInfo.Path)
if addrPath.AddressType()-hdkeychain.HardenedKeyStart == uint32(crypto.AddressTypeValidator) {
addrs = append(addrs, addrInfo)
}
}

v.SortAddressesByAddressIndex(addrs...)
v.SortAddressesByPurpose(addrs...)

return addrs
}

func (v *Vault) AllBLSAccountAddresses() []AddressInfo {
addrs := make([]AddressInfo, 0, v.AddressCount()/2)
for _, addrInfo := range v.Addresses {
addrPath, _ := addresspath.NewPathFromString(addrInfo.Path)
if addrPath.AddressType()-hdkeychain.HardenedKeyStart == uint32(crypto.AddressTypeBLSAccount) {
if addrPath.AddressType() == H(crypto.AddressTypeValidator) {
addrs = append(addrs, addrInfo)
}
}
Expand All @@ -252,7 +233,7 @@ func (v *Vault) AllImportedPrivateKeysAddresses() []AddressInfo {
addrs := make([]AddressInfo, 0, v.AddressCount()/2)
for _, addrInfo := range v.Addresses {
addrPath, _ := addresspath.NewPathFromString(addrInfo.Path)
if addrPath.Purpose() == HardenedPurposeImportPrivateKey {
if addrPath.Purpose() == H(PurposeImportPrivateKey) {
addrs = append(addrs, addrInfo)
}
}
Expand Down Expand Up @@ -331,17 +312,17 @@ func (v *Vault) ImportPrivateKey(password string, prv *bls.PrivateKey) error {
return ErrAddressExists
}

blsAccPathStr := addresspath.NewPath(HardenedPurposeImportPrivateKey,
v.CoinType+hdkeychain.HardenedKeyStart,
HardenedAddressTypeBLSAccount,
uint32(addressIndex)+hdkeychain.HardenedKeyStart).
String()
blsAccPathStr := addresspath.NewPath(
H(PurposeImportPrivateKey),
H(v.CoinType),
H(crypto.AddressTypeBLSAccount),
H(addressIndex)).String()

blsValidatorPathStr := addresspath.NewPath(HardenedPurposeImportPrivateKey,
v.CoinType+addresspath.HardenedKeyStart,
HardenedAddressTypeValidator,
uint32(addressIndex)+hdkeychain.HardenedKeyStart).
String()
blsValidatorPathStr := addresspath.NewPath(
H(PurposeImportPrivateKey),
H(v.CoinType),
H(crypto.AddressTypeValidator),
H(addressIndex)).String()

importedPrvLabelCounter := (len(v.AllImportedPrivateKeysAddresses()) / 2) + 1
v.Addresses[accAddr.String()] = AddressInfo{
Expand Down Expand Up @@ -390,12 +371,12 @@ func (v *Vault) PrivateKeys(password string, addrs []string) ([]crypto.PrivateKe
return nil, err
}

if path.CoinType()-hdkeychain.HardenedKeyStart != v.CoinType {
if path.CoinType() != H(v.CoinType) {
return nil, ErrInvalidCoinType
}

switch path.Purpose() {
case HardenedPurposeBLS12381:
case H(PurposeBLS12381):
seed, err := bip39.NewSeedWithErrorChecking(keyStore.MasterNode.Mnemonic, "")
if err != nil {
return nil, err
Expand All @@ -419,8 +400,9 @@ func (v *Vault) PrivateKeys(password string, addrs []string) ([]crypto.PrivateKe
}

keys[i] = prvKey
case HardenedPurposeImportPrivateKey:
case H(PurposeImportPrivateKey):
index := path.AddressIndex() - hdkeychain.HardenedKeyStart
// TODO: index out of range check
str := keyStore.ImportedKeys[index]
prv, err := bls.PrivateKeyFromString(str)
if err != nil {
Expand Down Expand Up @@ -504,12 +486,12 @@ func (v *Vault) AddressInfo(addr string) *AddressInfo {
}

// TODO it would be better to return error in future
if path.CoinType()-hdkeychain.HardenedKeyStart != v.CoinType {
if path.CoinType() != H(v.CoinType) {
return nil
}

switch path.Purpose() {
case HardenedPurposeBLS12381:
case H(PurposeBLS12381):
addr, err := crypto.AddressFromString(info.Address)
if err != nil {
return nil
Expand Down Expand Up @@ -543,7 +525,7 @@ func (v *Vault) AddressInfo(addr string) *AddressInfo {
}

info.PublicKey = blsPubKey.String()
case HardenedPurposeImportPrivateKey:
case H(PurposeImportPrivateKey):
default:
return nil
}
Expand Down
42 changes: 4 additions & 38 deletions wallet/vault/vault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func TestAddressInfo(t *testing.T) {
path, _ := addresspath.NewPathFromString(info.Path)

switch path.Purpose() {
case HardenedPurposeBLS12381:
case H(PurposeBLS12381):
if addr.IsValidatorAddress() {
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/1'/%d",
PurposeBLS12381, td.vault.CoinType, path.AddressIndex()))
Expand All @@ -94,7 +94,7 @@ func TestAddressInfo(t *testing.T) {
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/2'/%d",
PurposeBLS12381, td.vault.CoinType, path.AddressIndex()))
}
case HardenedPurposeImportPrivateKey:
case H(PurposeImportPrivateKey):
if addr.IsValidatorAddress() {
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/1'/%d'",
PurposeImportPrivateKey, td.vault.CoinType, path.AddressIndex()-hdkeychain.HardenedKeyStart))
Expand Down Expand Up @@ -135,10 +135,10 @@ func TestAllValidatorAddresses(t *testing.T) {
path, _ := addresspath.NewPathFromString(info.Path)

switch path.Purpose() {
case HardenedPurposeBLS12381:
case H(PurposeBLS12381):
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/1'/%d",
PurposeBLS12381, td.vault.CoinType, path.AddressIndex()))
case HardenedPurposeImportPrivateKey:
case H(PurposeImportPrivateKey):
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/1'/%d'",
PurposeImportPrivateKey, td.vault.CoinType, path.AddressIndex()-hdkeychain.HardenedKeyStart))
}
Expand All @@ -155,40 +155,6 @@ func TestSortAllValidatorAddresses(t *testing.T) {
assert.Equal(t, "m/65535'/21888'/1'/0'", validatorAddrs[len(validatorAddrs)-1].Path)
}

func TestAllBLSAccountAddresses(t *testing.T) {
td := setup(t)

assert.Equal(t, td.vault.AddressCount(), 6)

blsAccountAddrs := td.vault.AllBLSAccountAddresses()
for _, i := range blsAccountAddrs {
info := td.vault.AddressInfo(i.Address)
assert.Equal(t, i.Address, info.Address)

path, _ := addresspath.NewPathFromString(info.Path)

switch path.Purpose() {
case HardenedPurposeBLS12381:
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/2'/%d",
PurposeBLS12381, td.vault.CoinType, path.AddressIndex()))
case HardenedPurposeImportPrivateKey:
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/2'/%d'",
PurposeImportPrivateKey, td.vault.CoinType, path.AddressIndex()-hdkeychain.HardenedKeyStart))
}
}
}

func TestSortAllBLSAccountAddresses(t *testing.T) {
td := setup(t)

assert.Equal(t, td.vault.AddressCount(), 6)

blsAccountAddrs := td.vault.AllBLSAccountAddresses()

assert.Equal(t, "m/12381'/21888'/2'/0", blsAccountAddrs[0].Path)
assert.Equal(t, "m/65535'/21888'/2'/0'", blsAccountAddrs[len(blsAccountAddrs)-1].Path)
}

func TestAddressFromPath(t *testing.T) {
td := setup(t)
assert.Equal(t, td.vault.AddressCount(), 6)
Expand Down
25 changes: 17 additions & 8 deletions wallet/wallet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,25 @@ func TestInvalidAddress(t *testing.T) {
assert.Error(t, err)
}

// TODO: Fix later.
// func TestImportPrivateKey(t *testing.T) {
// td := setup(t)
func TestImportPrivateKey(t *testing.T) {
td := setup(t)

_, prv := td.RandBLSKeyPair()
assert.NoError(t, td.wallet.ImportPrivateKey(td.password, prv))

pub := prv.PublicKeyNative()
accAddr := pub.AccountAddress().String()
valAddr := pub.AccountAddress().String()

// _, prv := td.RandBLSKeyPair()
// assert.NoError(t, td.wallet.ImportPrivateKey(td.password, prv))
assert.True(t, td.wallet.Contains(accAddr))
assert.True(t, td.wallet.Contains(valAddr))

// addr := prv.PublicKeyNative().AccountAddress().String()
// assert.True(t, td.wallet.Contains(addr))
// }
accAddrInfo := td.wallet.AddressInfo(accAddr)
valAddrInfo := td.wallet.AddressInfo(accAddr)

assert.Equal(t, pub.String(), accAddrInfo.PublicKey)
assert.Equal(t, pub.String(), valAddrInfo.PublicKey)
}

func TestKeyInfo(t *testing.T) {
td := setup(t)
Expand Down
Loading