From b1e5fa33c4b5967cf3624c057a28d47cf3043985 Mon Sep 17 00:00:00 2001 From: Mostafa Date: Mon, 11 Dec 2023 19:09:21 +0800 Subject: [PATCH] fix: corrected mistake when retrieving the reward address --- cmd/cmd.go | 53 ++++++++++++++++++---------- wallet/addresspath/path.go | 3 ++ wallet/vault/utils.go | 7 ++++ wallet/vault/vault.go | 72 ++++++++++++++------------------------ wallet/vault/vault_test.go | 42 +++------------------- wallet/wallet_test.go | 25 ++++++++----- 6 files changed, 92 insertions(+), 110 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index c213d5ca2..f8e9e0205 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -350,33 +350,37 @@ func StartNode(workingDir string, passwordFetcher func(*wallet.Wallet) (string, if err != nil { return nil, nil, err } - allValAddrs := walletInstance.AllValidatorAddresses() - if len(allValAddrs) < 1 || len(allValAddrs) > 32 { - return nil, nil, fmt.Errorf("number of validators must be between 1 and 32, but it's %d", - len(allValAddrs)) + valAddrsInfo := walletInstance.AllValidatorAddresses() + if len(valAddrsInfo) == 0 { + return nil, nil, fmt.Errorf("no validator addresses found in the wallet") + } + + if len(valAddrsInfo) > 32 { + PrintWarnMsgf("wallet has more than 32 validator addresses, only the first 32 will be used") + valAddrsInfo = valAddrsInfo[:32] } if len(conf.Node.RewardAddresses) > 0 && - len(conf.Node.RewardAddresses) != len(allValAddrs) { - return nil, nil, fmt.Errorf("reward addresses should be %v", len(allValAddrs)) + len(conf.Node.RewardAddresses) != len(valAddrsInfo) { + return nil, nil, fmt.Errorf("reward addresses should be %v", len(valAddrsInfo)) } - validatorAddrs := make([]string, len(allValAddrs)) - for i := 0; i < len(validatorAddrs); i++ { - valAddr, _ := crypto.AddressFromString(allValAddrs[i].Address) + valAddrs := make([]string, len(valAddrsInfo)) + for i := 0; i < len(valAddrs); i++ { + valAddr, _ := crypto.AddressFromString(valAddrsInfo[i].Address) if !valAddr.IsValidatorAddress() { - return nil, nil, fmt.Errorf("invalid validator address: %s", allValAddrs[i].Address) + return nil, nil, fmt.Errorf("invalid validator address: %s", valAddrsInfo[i].Address) } - validatorAddrs[i] = valAddr.String() + valAddrs[i] = valAddr.String() } - valKeys := make([]*bls.ValidatorKey, len(allValAddrs)) + valKeys := make([]*bls.ValidatorKey, len(valAddrsInfo)) password, ok := passwordFetcher(walletInstance) if !ok { return nil, nil, fmt.Errorf("aborted") } - prvKeys, err := walletInstance.PrivateKeys(password, validatorAddrs) + prvKeys, err := walletInstance.PrivateKeys(password, valAddrs) if err != nil { return nil, nil, err } @@ -385,21 +389,25 @@ func StartNode(workingDir string, passwordFetcher func(*wallet.Wallet) (string, } // Create reward addresses - rewardAddrs := make([]crypto.Address, 0, len(allValAddrs)) + rewardAddrs := make([]crypto.Address, 0, len(valAddrsInfo)) if len(conf.Node.RewardAddresses) != 0 { for _, addrStr := range conf.Node.RewardAddresses { addr, _ := crypto.AddressFromString(addrStr) rewardAddrs = append(rewardAddrs, addr) } } else { - for i := 0; i < len(allValAddrs); i++ { - valAddrPath, _ := addresspath.NewPathFromString(allValAddrs[i].Path) - accAddrPath := addresspath.NewPath(valAddrPath.Purpose(), valAddrPath.CoinType(), - uint32(crypto.AddressTypeValidator)+hdkeychain.HardenedKeyStart, valAddrPath.AddressIndex()) + for i := 0; i < len(valAddrsInfo); i++ { + valAddrPath, _ := addresspath.NewPathFromString(valAddrsInfo[i].Path) + accAddrPath := addresspath.NewPath( + valAddrPath.Purpose(), + valAddrPath.CoinType(), + uint32(crypto.AddressTypeBLSAccount)+hdkeychain.HardenedKeyStart, + valAddrPath.AddressIndex()) addrInfo := walletInstance.AddressFromPath(accAddrPath.String()) if addrInfo == nil { - return nil, nil, fmt.Errorf("unable to find reward address for: %s", allValAddrs[i].Address) + return nil, nil, fmt.Errorf("unable to find reward address for: %s [%s]", + valAddrsInfo[i].Address, accAddrPath) } addr, _ := crypto.AddressFromString(addrInfo.Address) @@ -407,6 +415,13 @@ func StartNode(workingDir string, passwordFetcher func(*wallet.Wallet) (string, } } + // Check if reward addresses are account address + for _, addr := range rewardAddrs { + if !addr.IsAccountAddress() { + return nil, nil, fmt.Errorf("reward address is not an account address: %s", addr) + } + } + nodeInstance, err := node.NewNode(gen, conf, valKeys, rewardAddrs) if err != nil { return nil, nil, err diff --git a/wallet/addresspath/path.go b/wallet/addresspath/path.go index 0e5d159bb..9e9457f40 100644 --- a/wallet/addresspath/path.go +++ b/wallet/addresspath/path.go @@ -17,6 +17,7 @@ func NewPath(indexes ...uint32) Path { return p } +// TODO: check the path should exactly 4 levels. func NewPathFromString(str string) (Path, error) { sub := strings.Split(str, "/") if sub[0] != "m" { @@ -52,6 +53,8 @@ func (p Path) String() string { return builder.String() } +// TODO: we can add IsBLSPurpose or IsImportedPurpose functions + func (p Path) Purpose() uint32 { return p[0] } diff --git a/wallet/vault/utils.go b/wallet/vault/utils.go index 3d97c5fc4..82394df17 100644 --- a/wallet/vault/utils.go +++ b/wallet/vault/utils.go @@ -1,7 +1,9 @@ package vault import ( + "github.com/pactus-project/pactus/crypto/bls/hdkeychain" "github.com/tyler-smith/go-bip39" + "golang.org/x/exp/constraints" ) // GenerateMnemonic generates a new mnemonic (seed phrase) based on BIP-39 @@ -22,3 +24,8 @@ func CheckMnemonic(mnemonic string) error { _, err := bip39.EntropyFromMnemonic(mnemonic) return err } + +// H hardens the value 'i' by adding it to 0x80000000 (2^31). +func H[T constraints.Integer](i T) uint32 { + return uint32(i) + hdkeychain.HardenedKeyStart +} diff --git a/wallet/vault/vault.go b/wallet/vault/vault.go index 436b354b7..9483faa70 100644 --- a/wallet/vault/vault.go +++ b/wallet/vault/vault.go @@ -48,12 +48,8 @@ type AddressInfo struct { } const ( - PurposeBLS12381 = uint32(12381) - HardenedPurposeBLS12381 = PurposeBLS12381 + hdkeychain.HardenedKeyStart - PurposeImportPrivateKey = uint32(65535) - HardenedPurposeImportPrivateKey = PurposeImportPrivateKey + hdkeychain.HardenedKeyStart - HardenedAddressTypeBLSAccount = uint32(crypto.AddressTypeBLSAccount) + hdkeychain.HardenedKeyStart - HardenedAddressTypeValidator = uint32(crypto.AddressTypeValidator) + hdkeychain.HardenedKeyStart + PurposeBLS12381 = uint32(12381) + PurposeImportPrivateKey = uint32(65535) ) type Vault struct { @@ -97,18 +93,18 @@ func CreateVaultFromMnemonic(mnemonic string, coinType uint32) (*Vault, error) { enc := encrypter.NopeEncrypter() xPubValidator, err := masterKey.DerivePath([]uint32{ - 12381 + hdkeychain.HardenedKeyStart, - coinType + hdkeychain.HardenedKeyStart, - uint32(crypto.AddressTypeValidator) + hdkeychain.HardenedKeyStart, + H(PurposeBLS12381), + H(coinType), + H(crypto.AddressTypeValidator), }) if err != nil { return nil, err } xPubAccount, err := masterKey.DerivePath([]uint32{ - 12381 + hdkeychain.HardenedKeyStart, - coinType + hdkeychain.HardenedKeyStart, - uint32(crypto.AddressTypeBLSAccount) + hdkeychain.HardenedKeyStart, + H(PurposeBLS12381), + H(coinType), + H(crypto.AddressTypeBLSAccount), }) if err != nil { return nil, err @@ -222,22 +218,7 @@ func (v *Vault) AllValidatorAddresses() []AddressInfo { addrs := make([]AddressInfo, 0, v.AddressCount()/2) for _, addrInfo := range v.Addresses { addrPath, _ := addresspath.NewPathFromString(addrInfo.Path) - if addrPath.AddressType()-hdkeychain.HardenedKeyStart == uint32(crypto.AddressTypeValidator) { - addrs = append(addrs, addrInfo) - } - } - - v.SortAddressesByAddressIndex(addrs...) - v.SortAddressesByPurpose(addrs...) - - return addrs -} - -func (v *Vault) AllBLSAccountAddresses() []AddressInfo { - addrs := make([]AddressInfo, 0, v.AddressCount()/2) - for _, addrInfo := range v.Addresses { - addrPath, _ := addresspath.NewPathFromString(addrInfo.Path) - if addrPath.AddressType()-hdkeychain.HardenedKeyStart == uint32(crypto.AddressTypeBLSAccount) { + if addrPath.AddressType() == H(crypto.AddressTypeValidator) { addrs = append(addrs, addrInfo) } } @@ -252,7 +233,7 @@ func (v *Vault) AllImportedPrivateKeysAddresses() []AddressInfo { addrs := make([]AddressInfo, 0, v.AddressCount()/2) for _, addrInfo := range v.Addresses { addrPath, _ := addresspath.NewPathFromString(addrInfo.Path) - if addrPath.Purpose() == HardenedPurposeImportPrivateKey { + if addrPath.Purpose() == H(PurposeImportPrivateKey) { addrs = append(addrs, addrInfo) } } @@ -331,17 +312,17 @@ func (v *Vault) ImportPrivateKey(password string, prv *bls.PrivateKey) error { return ErrAddressExists } - blsAccPathStr := addresspath.NewPath(HardenedPurposeImportPrivateKey, - v.CoinType+hdkeychain.HardenedKeyStart, - HardenedAddressTypeBLSAccount, - uint32(addressIndex)+hdkeychain.HardenedKeyStart). - String() + blsAccPathStr := addresspath.NewPath( + H(PurposeImportPrivateKey), + H(v.CoinType), + H(crypto.AddressTypeBLSAccount), + H(addressIndex)).String() - blsValidatorPathStr := addresspath.NewPath(HardenedPurposeImportPrivateKey, - v.CoinType+addresspath.HardenedKeyStart, - HardenedAddressTypeValidator, - uint32(addressIndex)+hdkeychain.HardenedKeyStart). - String() + blsValidatorPathStr := addresspath.NewPath( + H(PurposeImportPrivateKey), + H(v.CoinType), + H(crypto.AddressTypeValidator), + H(addressIndex)).String() importedPrvLabelCounter := (len(v.AllImportedPrivateKeysAddresses()) / 2) + 1 v.Addresses[accAddr.String()] = AddressInfo{ @@ -390,12 +371,12 @@ func (v *Vault) PrivateKeys(password string, addrs []string) ([]crypto.PrivateKe return nil, err } - if path.CoinType()-hdkeychain.HardenedKeyStart != v.CoinType { + if path.CoinType() != H(v.CoinType) { return nil, ErrInvalidCoinType } switch path.Purpose() { - case HardenedPurposeBLS12381: + case H(PurposeBLS12381): seed, err := bip39.NewSeedWithErrorChecking(keyStore.MasterNode.Mnemonic, "") if err != nil { return nil, err @@ -419,8 +400,9 @@ func (v *Vault) PrivateKeys(password string, addrs []string) ([]crypto.PrivateKe } keys[i] = prvKey - case HardenedPurposeImportPrivateKey: + case H(PurposeImportPrivateKey): index := path.AddressIndex() - hdkeychain.HardenedKeyStart + // TODO: index out of range check str := keyStore.ImportedKeys[index] prv, err := bls.PrivateKeyFromString(str) if err != nil { @@ -504,12 +486,12 @@ func (v *Vault) AddressInfo(addr string) *AddressInfo { } // TODO it would be better to return error in future - if path.CoinType()-hdkeychain.HardenedKeyStart != v.CoinType { + if path.CoinType() != H(v.CoinType) { return nil } switch path.Purpose() { - case HardenedPurposeBLS12381: + case H(PurposeBLS12381): addr, err := crypto.AddressFromString(info.Address) if err != nil { return nil @@ -543,7 +525,7 @@ func (v *Vault) AddressInfo(addr string) *AddressInfo { } info.PublicKey = blsPubKey.String() - case HardenedPurposeImportPrivateKey: + case H(PurposeImportPrivateKey): default: return nil } diff --git a/wallet/vault/vault_test.go b/wallet/vault/vault_test.go index 232edaef5..c82776341 100644 --- a/wallet/vault/vault_test.go +++ b/wallet/vault/vault_test.go @@ -84,7 +84,7 @@ func TestAddressInfo(t *testing.T) { path, _ := addresspath.NewPathFromString(info.Path) switch path.Purpose() { - case HardenedPurposeBLS12381: + case H(PurposeBLS12381): if addr.IsValidatorAddress() { assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/1'/%d", PurposeBLS12381, td.vault.CoinType, path.AddressIndex())) @@ -94,7 +94,7 @@ func TestAddressInfo(t *testing.T) { assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/2'/%d", PurposeBLS12381, td.vault.CoinType, path.AddressIndex())) } - case HardenedPurposeImportPrivateKey: + case H(PurposeImportPrivateKey): if addr.IsValidatorAddress() { assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/1'/%d'", PurposeImportPrivateKey, td.vault.CoinType, path.AddressIndex()-hdkeychain.HardenedKeyStart)) @@ -135,10 +135,10 @@ func TestAllValidatorAddresses(t *testing.T) { path, _ := addresspath.NewPathFromString(info.Path) switch path.Purpose() { - case HardenedPurposeBLS12381: + case H(PurposeBLS12381): assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/1'/%d", PurposeBLS12381, td.vault.CoinType, path.AddressIndex())) - case HardenedPurposeImportPrivateKey: + case H(PurposeImportPrivateKey): assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/1'/%d'", PurposeImportPrivateKey, td.vault.CoinType, path.AddressIndex()-hdkeychain.HardenedKeyStart)) } @@ -155,40 +155,6 @@ func TestSortAllValidatorAddresses(t *testing.T) { assert.Equal(t, "m/65535'/21888'/1'/0'", validatorAddrs[len(validatorAddrs)-1].Path) } -func TestAllBLSAccountAddresses(t *testing.T) { - td := setup(t) - - assert.Equal(t, td.vault.AddressCount(), 6) - - blsAccountAddrs := td.vault.AllBLSAccountAddresses() - for _, i := range blsAccountAddrs { - info := td.vault.AddressInfo(i.Address) - assert.Equal(t, i.Address, info.Address) - - path, _ := addresspath.NewPathFromString(info.Path) - - switch path.Purpose() { - case HardenedPurposeBLS12381: - assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/2'/%d", - PurposeBLS12381, td.vault.CoinType, path.AddressIndex())) - case HardenedPurposeImportPrivateKey: - assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/2'/%d'", - PurposeImportPrivateKey, td.vault.CoinType, path.AddressIndex()-hdkeychain.HardenedKeyStart)) - } - } -} - -func TestSortAllBLSAccountAddresses(t *testing.T) { - td := setup(t) - - assert.Equal(t, td.vault.AddressCount(), 6) - - blsAccountAddrs := td.vault.AllBLSAccountAddresses() - - assert.Equal(t, "m/12381'/21888'/2'/0", blsAccountAddrs[0].Path) - assert.Equal(t, "m/65535'/21888'/2'/0'", blsAccountAddrs[len(blsAccountAddrs)-1].Path) -} - func TestAddressFromPath(t *testing.T) { td := setup(t) assert.Equal(t, td.vault.AddressCount(), 6) diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go index 873acfd36..d40ca173d 100644 --- a/wallet/wallet_test.go +++ b/wallet/wallet_test.go @@ -182,16 +182,25 @@ func TestInvalidAddress(t *testing.T) { assert.Error(t, err) } -// TODO: Fix later. -// func TestImportPrivateKey(t *testing.T) { -// td := setup(t) +func TestImportPrivateKey(t *testing.T) { + td := setup(t) + + _, prv := td.RandBLSKeyPair() + assert.NoError(t, td.wallet.ImportPrivateKey(td.password, prv)) + + pub := prv.PublicKeyNative() + accAddr := pub.AccountAddress().String() + valAddr := pub.AccountAddress().String() -// _, prv := td.RandBLSKeyPair() -// assert.NoError(t, td.wallet.ImportPrivateKey(td.password, prv)) + assert.True(t, td.wallet.Contains(accAddr)) + assert.True(t, td.wallet.Contains(valAddr)) -// addr := prv.PublicKeyNative().AccountAddress().String() -// assert.True(t, td.wallet.Contains(addr)) -// } + accAddrInfo := td.wallet.AddressInfo(accAddr) + valAddrInfo := td.wallet.AddressInfo(accAddr) + + assert.Equal(t, pub.String(), accAddrInfo.PublicKey) + assert.Equal(t, pub.String(), valAddrInfo.PublicKey) +} func TestKeyInfo(t *testing.T) { td := setup(t)