Skip to content

Commit

Permalink
Merge dashpay#6441: fix: hold wallet shared pointer in CJ Manager/Ses…
Browse files Browse the repository at this point in the history
…sions to prevent concurrent unload

2d7c7f8 fix: do not transfer wallet ownership to CTransactionBuilder{Output} (UdjinM6)
0aeeb85 fix: add missing `AddWallet` call in `TestLoadWallet` (UdjinM6)
e800d9d fix: hold wallet shared pointer in CJ Manager/Sessions to prevent concurrent unload (UdjinM6)

Pull request description:

  ## Issue being fixed or feature implemented
  dashpay#6440 (comment)

  ## What was done?

  ## How Has This Been Tested?

  ## Breaking Changes

  ## Checklist:
  - [ ] I have performed a self-review of my own code
  - [ ] I have commented my code, particularly in hard-to-understand areas
  - [ ] I have added or updated relevant unit/integration/functional/e2e tests
  - [ ] I have made corresponding changes to the documentation
  - [ ] I have assigned this pull request to a milestone _(for repository code-owners and collaborators only)_

ACKs for top commit:
  PastaPastaPasta:
    utACK 2d7c7f8

Tree-SHA512: 308e3bed077baa2167b7f9d81b87e5a61a113e4d465706548f303dfc499bc072d4e823e85772e591a879986b0fb0413d5afe0e3995e1f939fa772b29adc0300d
  • Loading branch information
PastaPastaPasta committed Dec 3, 2024
1 parent c074e09 commit c7b0d80
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 127 deletions.
132 changes: 58 additions & 74 deletions src/coinjoin/client.cpp

Large diffs are not rendered by default.

32 changes: 20 additions & 12 deletions src/coinjoin/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class CoinJoinWalletManager {
}
}

void Add(CWallet& wallet);
void Add(const std::shared_ptr<CWallet>& wallet);
void DoMaintenance();

void Remove(const std::string& name);
Expand Down Expand Up @@ -138,7 +138,7 @@ class CoinJoinWalletManager {
class CCoinJoinClientSession : public CCoinJoinBaseSession
{
private:
CWallet& m_wallet;
const std::shared_ptr<CWallet> m_wallet;
CoinJoinWalletManager& m_walletman;
CCoinJoinClientManager& m_clientman;
CDeterministicMNManager& m_dmnman;
Expand All @@ -163,15 +163,15 @@ class CCoinJoinClientSession : public CCoinJoinBaseSession
/// Create denominations
bool CreateDenominated(CAmount nBalanceToDenominate);
bool CreateDenominated(CAmount nBalanceToDenominate, const CompactTallyItem& tallyItem, bool fCreateMixingCollaterals)
EXCLUSIVE_LOCKS_REQUIRED(m_wallet.cs_wallet);
EXCLUSIVE_LOCKS_REQUIRED(m_wallet->cs_wallet);

/// Split up large inputs or make fee sized inputs
bool MakeCollateralAmounts();
bool MakeCollateralAmounts(const CompactTallyItem& tallyItem, bool fTryDenominated)
EXCLUSIVE_LOCKS_REQUIRED(m_wallet.cs_wallet);
EXCLUSIVE_LOCKS_REQUIRED(m_wallet->cs_wallet);

bool CreateCollateralTransaction(CMutableTransaction& txCollateral, std::string& strReason)
EXCLUSIVE_LOCKS_REQUIRED(m_wallet.cs_wallet);
EXCLUSIVE_LOCKS_REQUIRED(m_wallet->cs_wallet);

bool JoinExistingQueue(CAmount nBalanceNeedsAnonymized, CConnman& connman);
bool StartNewQueue(CAmount nBalanceNeedsAnonymized, CConnman& connman);
Expand All @@ -181,7 +181,7 @@ class CCoinJoinClientSession : public CCoinJoinBaseSession
/// step 1: prepare denominated inputs and outputs
bool PrepareDenominate(int nMinRounds, int nMaxRounds, std::string& strErrorRet, const std::vector<CTxDSIn>& vecTxDSIn,
std::vector<std::pair<CTxDSIn, CTxOut>>& vecPSInOutPairsRet, bool fDryRun = false)
EXCLUSIVE_LOCKS_REQUIRED(m_wallet.cs_wallet);
EXCLUSIVE_LOCKS_REQUIRED(m_wallet->cs_wallet);
/// step 2: send denominated inputs and outputs prepared in step 1
bool SendDenominate(const std::vector<std::pair<CTxDSIn, CTxOut> >& vecPSInOutPairsIn, CConnman& connman) EXCLUSIVE_LOCKS_REQUIRED(!cs_coinjoin);

Expand All @@ -200,7 +200,7 @@ class CCoinJoinClientSession : public CCoinJoinBaseSession
void SetNull() override EXCLUSIVE_LOCKS_REQUIRED(cs_coinjoin);

public:
explicit CCoinJoinClientSession(CWallet& wallet, CoinJoinWalletManager& walletman,
explicit CCoinJoinClientSession(const std::shared_ptr<CWallet>& wallet, CoinJoinWalletManager& walletman,
CCoinJoinClientManager& clientman, CDeterministicMNManager& dmnman,
CMasternodeMetaMan& mn_metaman, const CMasternodeSync& mn_sync,
const std::unique_ptr<CCoinJoinClientQueueManager>& queueman, bool is_masternode);
Expand Down Expand Up @@ -267,7 +267,7 @@ class CCoinJoinClientQueueManager : public CCoinJoinBaseManager
class CCoinJoinClientManager
{
private:
CWallet& m_wallet;
const std::shared_ptr<CWallet> m_wallet;
CoinJoinWalletManager& m_walletman;
CDeterministicMNManager& m_dmnman;
CMasternodeMetaMan& m_mn_metaman;
Expand Down Expand Up @@ -306,11 +306,19 @@ class CCoinJoinClientManager
CCoinJoinClientManager(CCoinJoinClientManager const&) = delete;
CCoinJoinClientManager& operator=(CCoinJoinClientManager const&) = delete;

explicit CCoinJoinClientManager(CWallet& wallet, CoinJoinWalletManager& walletman, CDeterministicMNManager& dmnman,
CMasternodeMetaMan& mn_metaman, const CMasternodeSync& mn_sync,
explicit CCoinJoinClientManager(const std::shared_ptr<CWallet>& wallet, CoinJoinWalletManager& walletman,
CDeterministicMNManager& dmnman, CMasternodeMetaMan& mn_metaman,
const CMasternodeSync& mn_sync,
const std::unique_ptr<CCoinJoinClientQueueManager>& queueman, bool is_masternode) :
m_wallet(wallet), m_walletman(walletman), m_dmnman(dmnman), m_mn_metaman(mn_metaman), m_mn_sync(mn_sync), m_queueman(queueman),
m_is_masternode{is_masternode} {}
m_wallet(wallet),
m_walletman(walletman),
m_dmnman(dmnman),
m_mn_metaman(mn_metaman),
m_mn_sync(mn_sync),
m_queueman(queueman),
m_is_masternode{is_masternode}
{
}

void ProcessMessage(CNode& peer, CChainState& active_chainstate, CConnman& connman, const CTxMemPool& mempool, std::string_view msg_type, CDataStream& vRecv) EXCLUSIVE_LOCKS_REQUIRED(!cs_deqsessions);

Expand Down
5 changes: 1 addition & 4 deletions src/coinjoin/interfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ class CoinJoinLoaderImpl : public interfaces::CoinJoin::Loader
explicit CoinJoinLoaderImpl(CoinJoinWalletManager& walletman)
: m_walletman(walletman) {}

void AddWallet(CWallet& wallet) override
{
m_walletman.Add(wallet);
}
void AddWallet(const std::shared_ptr<CWallet>& wallet) override { m_walletman.Add(wallet); }
void RemoveWallet(const std::string& name) override
{
m_walletman.Remove(name);
Expand Down
41 changes: 21 additions & 20 deletions src/coinjoin/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,15 @@ void CKeyHolderStorage::ReturnAll()
}
}

CTransactionBuilderOutput::CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, std::shared_ptr<CWallet> pwalletIn, CAmount nAmountIn) :
CTransactionBuilderOutput::CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn,
const std::shared_ptr<CWallet>& wallet, CAmount nAmountIn) :
pTxBuilder(pTxBuilderIn),
dest(pwalletIn.get()),
dest(wallet.get()),
nAmount(nAmountIn)
{
assert(pTxBuilder);
CTxDestination txdest;
LOCK(pwalletIn->cs_wallet);
LOCK(wallet->cs_wallet);
dest.GetReservedDestination(txdest, false);
script = ::GetScriptForDestination(txdest);
}
Expand All @@ -108,15 +109,15 @@ bool CTransactionBuilderOutput::UpdateAmount(const CAmount nNewAmount)
return true;
}

CTransactionBuilder::CTransactionBuilder(std::shared_ptr<CWallet> pwalletIn, const CompactTallyItem& tallyItemIn) :
pwallet(pwalletIn),
dummyReserveDestination(pwalletIn.get()),
CTransactionBuilder::CTransactionBuilder(const std::shared_ptr<CWallet>& wallet, const CompactTallyItem& tallyItemIn) :
m_wallet(wallet),
dummyReserveDestination(wallet.get()),
tallyItem(tallyItemIn)
{
// Generate a feerate which will be used to consider if the remainder is dust and will go into fees or not
coinControl.m_discard_feerate = ::GetDiscardRate(*pwallet);
coinControl.m_discard_feerate = ::GetDiscardRate(*m_wallet);
// Generate a feerate which will be used by calculations of this class and also by CWallet::CreateTransaction
coinControl.m_feerate = std::max(GetRequiredFeeRate(*pwallet), pwallet->m_pay_tx_fee);
coinControl.m_feerate = std::max(GetRequiredFeeRate(*m_wallet), m_wallet->m_pay_tx_fee);
// Change always goes back to origin
coinControl.destChange = tallyItemIn.txdest;
// Only allow tallyItems inputs for tx creation
Expand All @@ -131,16 +132,16 @@ CTransactionBuilder::CTransactionBuilder(std::shared_ptr<CWallet> pwalletIn, con
// Get a comparable dummy scriptPubKey, avoid writing/flushing to the actual wallet db
CScript dummyScript;
{
LOCK(pwallet->cs_wallet);
WalletBatch dummyBatch(pwallet->GetDatabase(), false);
LOCK(m_wallet->cs_wallet);
WalletBatch dummyBatch(m_wallet->GetDatabase(), false);
dummyBatch.TxnBegin();
CKey secret;
secret.MakeNewKey(pwallet->CanSupportFeature(FEATURE_COMPRPUBKEY));
secret.MakeNewKey(m_wallet->CanSupportFeature(FEATURE_COMPRPUBKEY));
CPubKey dummyPubkey = secret.GetPubKey();
dummyBatch.TxnAbort();
dummyScript = ::GetScriptForDestination(PKHash(dummyPubkey));
// Calculate required bytes for the dummy signed tx with tallyItem's inputs only
nBytesBase = CalculateMaximumSignedTxSize(CTransaction(dummyTx), pwallet.get(), false);
nBytesBase = CalculateMaximumSignedTxSize(CTransaction(dummyTx), m_wallet.get(), false);
}
// Calculate the output size
nBytesOutput = ::GetSerializeSize(CTxOut(0, dummyScript), PROTOCOL_VERSION);
Expand Down Expand Up @@ -204,7 +205,7 @@ CTransactionBuilderOutput* CTransactionBuilder::AddOutput(CAmount nAmountOutput)
{
if (CouldAddOutput(nAmountOutput)) {
LOCK(cs_outputs);
vecOutputs.push_back(std::make_unique<CTransactionBuilderOutput>(this, pwallet, nAmountOutput));
vecOutputs.push_back(std::make_unique<CTransactionBuilderOutput>(this, m_wallet, nAmountOutput));
return vecOutputs.back().get();
}
return nullptr;
Expand Down Expand Up @@ -233,12 +234,12 @@ CAmount CTransactionBuilder::GetAmountUsed() const
CAmount CTransactionBuilder::GetFee(unsigned int nBytes) const
{
CAmount nFeeCalc = coinControl.m_feerate->GetFee(nBytes);
CAmount nRequiredFee = GetRequiredFee(*pwallet, nBytes);
CAmount nRequiredFee = GetRequiredFee(*m_wallet, nBytes);
if (nRequiredFee > nFeeCalc) {
nFeeCalc = nRequiredFee;
}
if (nFeeCalc > pwallet->m_default_max_tx_fee) {
nFeeCalc = pwallet->m_default_max_tx_fee;
if (nFeeCalc > m_wallet->m_default_max_tx_fee) {
nFeeCalc = m_wallet->m_default_max_tx_fee;
}
return nFeeCalc;
}
Expand Down Expand Up @@ -273,9 +274,9 @@ bool CTransactionBuilder::Commit(bilingual_str& strResult)

CTransactionRef tx;
{
LOCK2(pwallet->cs_wallet, cs_main);
LOCK2(m_wallet->cs_wallet, cs_main);
FeeCalculation fee_calc_out;
if (!pwallet->CreateTransaction(vecSend, tx, nFeeRet, nChangePosRet, strResult, coinControl, fee_calc_out)) {
if (!m_wallet->CreateTransaction(vecSend, tx, nFeeRet, nChangePosRet, strResult, coinControl, fee_calc_out)) {
return false;
}
}
Expand Down Expand Up @@ -312,8 +313,8 @@ bool CTransactionBuilder::Commit(bilingual_str& strResult)
}

{
LOCK2(pwallet->cs_wallet, cs_main);
pwallet->CommitTransaction(tx, {}, {});
LOCK2(m_wallet->cs_wallet, cs_main);
m_wallet->CommitTransaction(tx, {}, {});
}

fKeepKeys = true;
Expand Down
6 changes: 3 additions & 3 deletions src/coinjoin/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class CTransactionBuilderOutput
CScript script;

public:
CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, std::shared_ptr<CWallet> pwalletIn, CAmount nAmountIn);
CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, const std::shared_ptr<CWallet>& wallet, CAmount nAmountIn);
CTransactionBuilderOutput(CTransactionBuilderOutput&&) = delete;
CTransactionBuilderOutput& operator=(CTransactionBuilderOutput&&) = delete;
/// Get the scriptPubKey of this output
Expand All @@ -77,7 +77,7 @@ class CTransactionBuilderOutput
class CTransactionBuilder
{
/// Wallet the transaction will be build for
std::shared_ptr<CWallet> pwallet;
const std::shared_ptr<CWallet>& m_wallet;
/// See CTransactionBuilder() for initialization
CCoinControl coinControl;
/// Dummy since we anyway use tallyItem's destination as change destination in coincontrol.
Expand All @@ -100,7 +100,7 @@ class CTransactionBuilder
friend class CTransactionBuilderOutput;

public:
CTransactionBuilder(std::shared_ptr<CWallet> pwalletIn, const CompactTallyItem& tallyItemIn);
CTransactionBuilder(const std::shared_ptr<CWallet>& wallet, const CompactTallyItem& tallyItemIn);
~CTransactionBuilder();
/// Check it would be possible to add a single output with the amount nAmount. Returns true if its possible and false if not.
bool CouldAddOutput(CAmount nAmountOutput) const EXCLUSIVE_LOCKS_REQUIRED(!cs_outputs);
Expand Down
2 changes: 1 addition & 1 deletion src/interfaces/coinjoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Loader
public:
virtual ~Loader() {}
//! Add new wallet to CoinJoin client manager
virtual void AddWallet(CWallet&) = 0;
virtual void AddWallet(const std::shared_ptr<CWallet>&) = 0;
//! Remove wallet from CoinJoin client manager
virtual void RemoveWallet(const std::string&) = 0;
virtual void FlushWallet(const std::string&) = 0;
Expand Down
4 changes: 4 additions & 0 deletions src/wallet/test/wallet_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ static std::shared_ptr<CWallet> TestLoadWallet(interfaces::Chain* chain, interfa
std::vector<bilingual_str> warnings;
auto database = MakeWalletDatabase("", options, status, error);
auto wallet = CWallet::Create(chain, coinjoin_loader, "", std::move(database), options.create_flags, error, warnings);
if (coinjoin_loader) {
// TODO: see CreateWalletWithoutChain
AddWallet(wallet);
}
if (chain) {
wallet->postInitProcess();
}
Expand Down
24 changes: 12 additions & 12 deletions src/wallet/wallet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ bool AddWallet(const std::shared_ptr<CWallet>& wallet)
}
wallet->ConnectScriptPubKeyManNotifiers();
wallet->AutoLockMasternodeCollaterals();
wallet->coinjoin_loader().AddWallet(*wallet);
wallet->coinjoin_loader().AddWallet(wallet);
wallet->NotifyCanGetAddressesChanged();
return true;
}
Expand Down Expand Up @@ -1432,46 +1432,46 @@ int CWallet::GetRealOutpointCoinJoinRounds(const COutPoint& outpoint, int nRound
if (wtx == nullptr || wtx->tx == nullptr) {
// no such tx in this wallet
*nRoundsRef = -1;
WalletCJLogPrint((*this), "%s FAILED %-70s %3d\n", __func__, outpoint.ToStringShort(), -1);
WalletCJLogPrint(this, "%s FAILED %-70s %3d\n", __func__, outpoint.ToStringShort(), -1);
return *nRoundsRef;
}

// bounds check
if (outpoint.n >= wtx->tx->vout.size()) {
// should never actually hit this
*nRoundsRef = -4;
WalletCJLogPrint((*this), "%s FAILED %-70s %3d\n", __func__, outpoint.ToStringShort(), -4);
WalletCJLogPrint(this, "%s FAILED %-70s %3d\n", __func__, outpoint.ToStringShort(), -4);
return *nRoundsRef;
}

auto txOutRef = &wtx->tx->vout[outpoint.n];

if (CoinJoin::IsCollateralAmount(txOutRef->nValue)) {
*nRoundsRef = -3;
WalletCJLogPrint((*this), "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
WalletCJLogPrint(this, "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
return *nRoundsRef;
}

// make sure the final output is non-denominate
if (!CoinJoin::IsDenominatedAmount(txOutRef->nValue)) { //NOT DENOM
*nRoundsRef = -2;
WalletCJLogPrint((*this), "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
WalletCJLogPrint(this, "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
return *nRoundsRef;
}

for (const auto& out : wtx->tx->vout) {
if (!CoinJoin::IsDenominatedAmount(out.nValue)) {
// this one is denominated but there is another non-denominated output found in the same tx
*nRoundsRef = 0;
WalletCJLogPrint((*this), "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
WalletCJLogPrint(this, "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
return *nRoundsRef;
}
}

// make sure we spent all of it with 0 fee, reset to 0 rounds otherwise
if (wtx->GetDebit(ISMINE_SPENDABLE) != wtx->GetCredit(ISMINE_SPENDABLE)) {
*nRoundsRef = 0;
WalletCJLogPrint((*this), "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
WalletCJLogPrint(this, "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
return *nRoundsRef;
}

Expand All @@ -1491,7 +1491,7 @@ int CWallet::GetRealOutpointCoinJoinRounds(const COutPoint& outpoint, int nRound
*nRoundsRef = fDenomFound
? (nShortest >= nRoundsMax - 1 ? nRoundsMax : nShortest + 1) // good, we a +1 to the shortest one but only nRoundsMax rounds max allowed
: 0; // too bad, we are the fist one in that chain
WalletCJLogPrint((*this), "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
WalletCJLogPrint(this, "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
return *nRoundsRef;
}

Expand Down Expand Up @@ -3253,7 +3253,7 @@ bool CWallet::SelectTxDSInsByDenomination(int nDenom, CAmount nValueMax, std::ve
CCoinControl coin_control;
coin_control.nCoinType = CoinType::ONLY_READY_TO_MIX;
AvailableCoins(vCoins, &coin_control);
WalletCJLogPrint((*this), "CWallet::%s -- vCoins.size(): %d\n", __func__, vCoins.size());
WalletCJLogPrint(this, "CWallet::%s -- vCoins.size(): %d\n", __func__, vCoins.size());

Shuffle(vCoins.rbegin(), vCoins.rend(), FastRandomContext());

Expand All @@ -3271,11 +3271,11 @@ bool CWallet::SelectTxDSInsByDenomination(int nDenom, CAmount nValueMax, std::ve
nValueTotal += nValue;
vecTxDSInRet.emplace_back(CTxDSIn(txin, scriptPubKey, nRounds));
setRecentTxIds.emplace(txHash);
WalletCJLogPrint((*this), "CWallet::%s -- hash: %s, nValue: %d.%08d\n",
WalletCJLogPrint(this, "CWallet::%s -- hash: %s, nValue: %d.%08d\n",
__func__, txHash.ToString(), nValue / COIN, nValue % COIN);
}

WalletCJLogPrint((*this), "CWallet::%s -- setRecentTxIds.size(): %d\n", __func__, setRecentTxIds.size());
WalletCJLogPrint(this, "CWallet::%s -- setRecentTxIds.size(): %d\n", __func__, setRecentTxIds.size());

return nValueTotal > 0;
}
Expand Down Expand Up @@ -4980,7 +4980,7 @@ std::shared_ptr<CWallet> CWallet::Create(interfaces::Chain* chain, interfaces::C
}

if (coinjoin_loader) {
coinjoin_loader->AddWallet(*walletInstance);
coinjoin_loader->AddWallet(walletInstance);
}

{
Expand Down
2 changes: 1 addition & 1 deletion src/wallet/wallet.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ extern const std::map<uint64_t,std::string> WALLET_FLAG_CAVEATS;
#define WalletCJLogPrint(wallet, ...) \
do { \
if (LogAcceptCategory(BCLog::COINJOIN)) { \
wallet.WalletLogPrintf(__VA_ARGS__); \
wallet->WalletLogPrintf(__VA_ARGS__); \
} \
} while (0)

Expand Down

0 comments on commit c7b0d80

Please sign in to comment.