diff --git a/crypto/hd/utils_test.go b/crypto/hd/utils_test.go new file mode 100644 index 0000000000..563028b525 --- /dev/null +++ b/crypto/hd/utils_test.go @@ -0,0 +1,182 @@ +// NOTE: This code is being used as test helper functions. +package hd + +import ( + "crypto/ecdsa" + "errors" + "os" + "sync" + + "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + + "github.com/btcsuite/btcd/btcutil/hdkeychain" + "github.com/btcsuite/btcd/chaincfg" + bip39 "github.com/tyler-smith/go-bip39" +) + +const issue179FixEnvar = "GO_ETHEREUM_HDWALLET_FIX_ISSUE_179" + +// Wallet is the underlying wallet struct. +type Wallet struct { + mnemonic string + masterKey *hdkeychain.ExtendedKey + seed []byte + paths map[common.Address]accounts.DerivationPath + accounts []accounts.Account + stateLock sync.RWMutex + fixIssue172 bool +} + +// NewFromMnemonic returns a new wallet from a BIP-39 mnemonic. +func NewFromMnemonic(mnemonic string) (*Wallet, error) { + if mnemonic == "" { + return nil, errors.New("mnemonic is required") + } + + if !bip39.IsMnemonicValid(mnemonic) { + return nil, errors.New("mnemonic is invalid") + } + + seed, err := NewSeedFromMnemonic(mnemonic) + if err != nil { + return nil, err + } + + wallet, err := newWallet(seed) + if err != nil { + return nil, err + } + wallet.mnemonic = mnemonic + + return wallet, nil +} + +// NewSeedFromMnemonic returns a BIP-39 seed based on a BIP-39 mnemonic. +func NewSeedFromMnemonic(mnemonic string) ([]byte, error) { + if mnemonic == "" { + return nil, errors.New("mnemonic is required") + } + + return bip39.NewSeedWithErrorChecking(mnemonic, "") +} + +func newWallet(seed []byte) (*Wallet, error) { + masterKey, err := hdkeychain.NewMaster(seed, &chaincfg.MainNetParams) + if err != nil { + return nil, err + } + + return &Wallet{ + masterKey: masterKey, + seed: seed, + accounts: []accounts.Account{}, + paths: map[common.Address]accounts.DerivationPath{}, + fixIssue172: false || len(os.Getenv(issue179FixEnvar)) > 0, + }, nil +} + +// Derive implements accounts.Wallet, deriving a new account at the specific +// derivation path. If pin is set to true, the account will be added to the list +// of tracked accounts. +func (w *Wallet) Derive(path accounts.DerivationPath, pin bool) (accounts.Account, error) { + // Try to derive the actual account and update its URL if successful + w.stateLock.RLock() // Avoid device disappearing during derivation + + address, err := w.deriveAddress(path) + + w.stateLock.RUnlock() + + // If an error occurred or no pinning was requested, return + if err != nil { + return accounts.Account{}, err + } + + account := accounts.Account{ + Address: address, + URL: accounts.URL{ + Scheme: "", + Path: path.String(), + }, + } + + if !pin { + return account, nil + } + + // Pinning needs to modify the state + w.stateLock.Lock() + defer w.stateLock.Unlock() + + if _, ok := w.paths[address]; !ok { + w.accounts = append(w.accounts, account) + w.paths[address] = path + } + + return account, nil +} + +// MustParseDerivationPath parses the derivation path in string format into +// []uint32 but will panic if it can't parse it. +func MustParseDerivationPath(path string) accounts.DerivationPath { + parsed, err := accounts.ParseDerivationPath(path) + if err != nil { + panic(err) + } + + return parsed +} + +// DerivePrivateKey derives the private key of the derivation path. +func (w *Wallet) derivePrivateKey(path accounts.DerivationPath) (*ecdsa.PrivateKey, error) { + var err error + key := w.masterKey + for _, n := range path { + if w.fixIssue172 && key.IsAffectedByIssue172() { + key, err = key.Derive(n) + } else { + //lint:ignore SA1019 this is used for testing only + //nolint:all + key, err = key.DeriveNonStandard(n) + } + if err != nil { + return nil, err + } + } + + privateKey, err := key.ECPrivKey() + privateKeyECDSA := privateKey.ToECDSA() + if err != nil { + return nil, err + } + + return privateKeyECDSA, nil +} + +// derivePublicKey derives the public key of the derivation path. +func (w *Wallet) derivePublicKey(path accounts.DerivationPath) (*ecdsa.PublicKey, error) { + privateKeyECDSA, err := w.derivePrivateKey(path) + if err != nil { + return nil, err + } + + publicKey := privateKeyECDSA.Public() + publicKeyECDSA, ok := publicKey.(*ecdsa.PublicKey) + if !ok { + return nil, errors.New("failed to get public key") + } + + return publicKeyECDSA, nil +} + +// DeriveAddress derives the account address of the derivation path. +func (w *Wallet) deriveAddress(path accounts.DerivationPath) (common.Address, error) { + publicKeyECDSA, err := w.derivePublicKey(path) + if err != nil { + return common.Address{}, err + } + + address := crypto.PubkeyToAddress(*publicKeyECDSA) + return address, nil +}