Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: hold wallet shared pointer in CJ Manager/Sessions to prevent concurrent unload #6441

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 58 additions & 74 deletions src/coinjoin/client.cpp

Large diffs are not rendered by default.

32 changes: 20 additions & 12 deletions src/coinjoin/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class CoinJoinWalletManager {
}
}

void Add(CWallet& wallet);
void Add(const std::shared_ptr<CWallet>& wallet);
void DoMaintenance();

void Remove(const std::string& name);
Expand Down Expand Up @@ -138,7 +138,7 @@ class CoinJoinWalletManager {
class CCoinJoinClientSession : public CCoinJoinBaseSession
{
private:
CWallet& m_wallet;
const std::shared_ptr<CWallet> m_wallet;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively,
if CCoinJoinClientSession and CCoinJoinClientManager is now an owner of CWallet, smart pointer still can be removed from CTransactionBuilder.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so, with this PR we actually have guarantee that CWallet won't be freed during any CCoinJoinClientSession and CCoinJoinClientManager operation. Let's do a final touch then: 4f73191 to simplify code a bit further

CoinJoinWalletManager& m_walletman;
CCoinJoinClientManager& m_clientman;
CDeterministicMNManager& m_dmnman;
Expand All @@ -163,15 +163,15 @@ class CCoinJoinClientSession : public CCoinJoinBaseSession
/// Create denominations
bool CreateDenominated(CAmount nBalanceToDenominate);
bool CreateDenominated(CAmount nBalanceToDenominate, const CompactTallyItem& tallyItem, bool fCreateMixingCollaterals)
EXCLUSIVE_LOCKS_REQUIRED(m_wallet.cs_wallet);
EXCLUSIVE_LOCKS_REQUIRED(m_wallet->cs_wallet);

/// Split up large inputs or make fee sized inputs
bool MakeCollateralAmounts();
bool MakeCollateralAmounts(const CompactTallyItem& tallyItem, bool fTryDenominated)
EXCLUSIVE_LOCKS_REQUIRED(m_wallet.cs_wallet);
EXCLUSIVE_LOCKS_REQUIRED(m_wallet->cs_wallet);

bool CreateCollateralTransaction(CMutableTransaction& txCollateral, std::string& strReason)
EXCLUSIVE_LOCKS_REQUIRED(m_wallet.cs_wallet);
EXCLUSIVE_LOCKS_REQUIRED(m_wallet->cs_wallet);

bool JoinExistingQueue(CAmount nBalanceNeedsAnonymized, CConnman& connman);
bool StartNewQueue(CAmount nBalanceNeedsAnonymized, CConnman& connman);
Expand All @@ -181,7 +181,7 @@ class CCoinJoinClientSession : public CCoinJoinBaseSession
/// step 1: prepare denominated inputs and outputs
bool PrepareDenominate(int nMinRounds, int nMaxRounds, std::string& strErrorRet, const std::vector<CTxDSIn>& vecTxDSIn,
std::vector<std::pair<CTxDSIn, CTxOut>>& vecPSInOutPairsRet, bool fDryRun = false)
EXCLUSIVE_LOCKS_REQUIRED(m_wallet.cs_wallet);
EXCLUSIVE_LOCKS_REQUIRED(m_wallet->cs_wallet);
/// step 2: send denominated inputs and outputs prepared in step 1
bool SendDenominate(const std::vector<std::pair<CTxDSIn, CTxOut> >& vecPSInOutPairsIn, CConnman& connman) EXCLUSIVE_LOCKS_REQUIRED(!cs_coinjoin);

Expand All @@ -200,7 +200,7 @@ class CCoinJoinClientSession : public CCoinJoinBaseSession
void SetNull() override EXCLUSIVE_LOCKS_REQUIRED(cs_coinjoin);

public:
explicit CCoinJoinClientSession(CWallet& wallet, CoinJoinWalletManager& walletman,
explicit CCoinJoinClientSession(const std::shared_ptr<CWallet>& wallet, CoinJoinWalletManager& walletman,
CCoinJoinClientManager& clientman, CDeterministicMNManager& dmnman,
CMasternodeMetaMan& mn_metaman, const CMasternodeSync& mn_sync,
const std::unique_ptr<CCoinJoinClientQueueManager>& queueman, bool is_masternode);
Expand Down Expand Up @@ -267,7 +267,7 @@ class CCoinJoinClientQueueManager : public CCoinJoinBaseManager
class CCoinJoinClientManager
{
private:
CWallet& m_wallet;
const std::shared_ptr<CWallet> m_wallet;
CoinJoinWalletManager& m_walletman;
CDeterministicMNManager& m_dmnman;
CMasternodeMetaMan& m_mn_metaman;
Expand Down Expand Up @@ -306,11 +306,19 @@ class CCoinJoinClientManager
CCoinJoinClientManager(CCoinJoinClientManager const&) = delete;
CCoinJoinClientManager& operator=(CCoinJoinClientManager const&) = delete;

explicit CCoinJoinClientManager(CWallet& wallet, CoinJoinWalletManager& walletman, CDeterministicMNManager& dmnman,
CMasternodeMetaMan& mn_metaman, const CMasternodeSync& mn_sync,
explicit CCoinJoinClientManager(const std::shared_ptr<CWallet>& wallet, CoinJoinWalletManager& walletman,
CDeterministicMNManager& dmnman, CMasternodeMetaMan& mn_metaman,
const CMasternodeSync& mn_sync,
const std::unique_ptr<CCoinJoinClientQueueManager>& queueman, bool is_masternode) :
m_wallet(wallet), m_walletman(walletman), m_dmnman(dmnman), m_mn_metaman(mn_metaman), m_mn_sync(mn_sync), m_queueman(queueman),
m_is_masternode{is_masternode} {}
m_wallet(wallet),
m_walletman(walletman),
m_dmnman(dmnman),
m_mn_metaman(mn_metaman),
m_mn_sync(mn_sync),
m_queueman(queueman),
m_is_masternode{is_masternode}
{
}

void ProcessMessage(CNode& peer, CChainState& active_chainstate, CConnman& connman, const CTxMemPool& mempool, std::string_view msg_type, CDataStream& vRecv) EXCLUSIVE_LOCKS_REQUIRED(!cs_deqsessions);

Expand Down
5 changes: 1 addition & 4 deletions src/coinjoin/interfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ class CoinJoinLoaderImpl : public interfaces::CoinJoin::Loader
explicit CoinJoinLoaderImpl(CoinJoinWalletManager& walletman)
: m_walletman(walletman) {}

void AddWallet(CWallet& wallet) override
{
m_walletman.Add(wallet);
}
void AddWallet(const std::shared_ptr<CWallet>& wallet) override { m_walletman.Add(wallet); }
void RemoveWallet(const std::string& name) override
{
m_walletman.Remove(name);
Expand Down
41 changes: 21 additions & 20 deletions src/coinjoin/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,15 @@ void CKeyHolderStorage::ReturnAll()
}
}

CTransactionBuilderOutput::CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, std::shared_ptr<CWallet> pwalletIn, CAmount nAmountIn) :
CTransactionBuilderOutput::CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn,
const std::shared_ptr<CWallet>& wallet, CAmount nAmountIn) :
pTxBuilder(pTxBuilderIn),
dest(pwalletIn.get()),
dest(wallet.get()),
nAmount(nAmountIn)
{
assert(pTxBuilder);
CTxDestination txdest;
LOCK(pwalletIn->cs_wallet);
LOCK(wallet->cs_wallet);
dest.GetReservedDestination(txdest, false);
script = ::GetScriptForDestination(txdest);
}
Expand All @@ -108,15 +109,15 @@ bool CTransactionBuilderOutput::UpdateAmount(const CAmount nNewAmount)
return true;
}

CTransactionBuilder::CTransactionBuilder(std::shared_ptr<CWallet> pwalletIn, const CompactTallyItem& tallyItemIn) :
pwallet(pwalletIn),
dummyReserveDestination(pwalletIn.get()),
CTransactionBuilder::CTransactionBuilder(const std::shared_ptr<CWallet>& wallet, const CompactTallyItem& tallyItemIn) :
m_wallet(wallet),
dummyReserveDestination(wallet.get()),
tallyItem(tallyItemIn)
{
// Generate a feerate which will be used to consider if the remainder is dust and will go into fees or not
coinControl.m_discard_feerate = ::GetDiscardRate(*pwallet);
coinControl.m_discard_feerate = ::GetDiscardRate(*m_wallet);
// Generate a feerate which will be used by calculations of this class and also by CWallet::CreateTransaction
coinControl.m_feerate = std::max(GetRequiredFeeRate(*pwallet), pwallet->m_pay_tx_fee);
coinControl.m_feerate = std::max(GetRequiredFeeRate(*m_wallet), m_wallet->m_pay_tx_fee);
// Change always goes back to origin
coinControl.destChange = tallyItemIn.txdest;
// Only allow tallyItems inputs for tx creation
Expand All @@ -131,16 +132,16 @@ CTransactionBuilder::CTransactionBuilder(std::shared_ptr<CWallet> pwalletIn, con
// Get a comparable dummy scriptPubKey, avoid writing/flushing to the actual wallet db
CScript dummyScript;
{
LOCK(pwallet->cs_wallet);
WalletBatch dummyBatch(pwallet->GetDatabase(), false);
LOCK(m_wallet->cs_wallet);
WalletBatch dummyBatch(m_wallet->GetDatabase(), false);
dummyBatch.TxnBegin();
CKey secret;
secret.MakeNewKey(pwallet->CanSupportFeature(FEATURE_COMPRPUBKEY));
secret.MakeNewKey(m_wallet->CanSupportFeature(FEATURE_COMPRPUBKEY));
CPubKey dummyPubkey = secret.GetPubKey();
dummyBatch.TxnAbort();
dummyScript = ::GetScriptForDestination(PKHash(dummyPubkey));
// Calculate required bytes for the dummy signed tx with tallyItem's inputs only
nBytesBase = CalculateMaximumSignedTxSize(CTransaction(dummyTx), pwallet.get(), false);
nBytesBase = CalculateMaximumSignedTxSize(CTransaction(dummyTx), m_wallet.get(), false);
}
// Calculate the output size
nBytesOutput = ::GetSerializeSize(CTxOut(0, dummyScript), PROTOCOL_VERSION);
Expand Down Expand Up @@ -204,7 +205,7 @@ CTransactionBuilderOutput* CTransactionBuilder::AddOutput(CAmount nAmountOutput)
{
if (CouldAddOutput(nAmountOutput)) {
LOCK(cs_outputs);
vecOutputs.push_back(std::make_unique<CTransactionBuilderOutput>(this, pwallet, nAmountOutput));
vecOutputs.push_back(std::make_unique<CTransactionBuilderOutput>(this, m_wallet, nAmountOutput));
return vecOutputs.back().get();
}
return nullptr;
Expand Down Expand Up @@ -233,12 +234,12 @@ CAmount CTransactionBuilder::GetAmountUsed() const
CAmount CTransactionBuilder::GetFee(unsigned int nBytes) const
{
CAmount nFeeCalc = coinControl.m_feerate->GetFee(nBytes);
CAmount nRequiredFee = GetRequiredFee(*pwallet, nBytes);
CAmount nRequiredFee = GetRequiredFee(*m_wallet, nBytes);
if (nRequiredFee > nFeeCalc) {
nFeeCalc = nRequiredFee;
}
if (nFeeCalc > pwallet->m_default_max_tx_fee) {
nFeeCalc = pwallet->m_default_max_tx_fee;
if (nFeeCalc > m_wallet->m_default_max_tx_fee) {
nFeeCalc = m_wallet->m_default_max_tx_fee;
}
return nFeeCalc;
}
Expand Down Expand Up @@ -273,9 +274,9 @@ bool CTransactionBuilder::Commit(bilingual_str& strResult)

CTransactionRef tx;
{
LOCK2(pwallet->cs_wallet, cs_main);
LOCK2(m_wallet->cs_wallet, cs_main);
FeeCalculation fee_calc_out;
if (!pwallet->CreateTransaction(vecSend, tx, nFeeRet, nChangePosRet, strResult, coinControl, fee_calc_out)) {
if (!m_wallet->CreateTransaction(vecSend, tx, nFeeRet, nChangePosRet, strResult, coinControl, fee_calc_out)) {
return false;
}
}
Expand Down Expand Up @@ -312,8 +313,8 @@ bool CTransactionBuilder::Commit(bilingual_str& strResult)
}

{
LOCK2(pwallet->cs_wallet, cs_main);
pwallet->CommitTransaction(tx, {}, {});
LOCK2(m_wallet->cs_wallet, cs_main);
m_wallet->CommitTransaction(tx, {}, {});
}

fKeepKeys = true;
Expand Down
6 changes: 3 additions & 3 deletions src/coinjoin/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class CTransactionBuilderOutput
CScript script;

public:
CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, std::shared_ptr<CWallet> pwalletIn, CAmount nAmountIn);
CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, const std::shared_ptr<CWallet>& wallet, CAmount nAmountIn);
CTransactionBuilderOutput(CTransactionBuilderOutput&&) = delete;
CTransactionBuilderOutput& operator=(CTransactionBuilderOutput&&) = delete;
/// Get the scriptPubKey of this output
Expand All @@ -77,7 +77,7 @@ class CTransactionBuilderOutput
class CTransactionBuilder
{
/// Wallet the transaction will be build for
std::shared_ptr<CWallet> pwallet;
const std::shared_ptr<CWallet>& m_wallet;
/// See CTransactionBuilder() for initialization
CCoinControl coinControl;
/// Dummy since we anyway use tallyItem's destination as change destination in coincontrol.
Expand All @@ -100,7 +100,7 @@ class CTransactionBuilder
friend class CTransactionBuilderOutput;

public:
CTransactionBuilder(std::shared_ptr<CWallet> pwalletIn, const CompactTallyItem& tallyItemIn);
CTransactionBuilder(const std::shared_ptr<CWallet>& wallet, const CompactTallyItem& tallyItemIn);
~CTransactionBuilder();
/// Check it would be possible to add a single output with the amount nAmount. Returns true if its possible and false if not.
bool CouldAddOutput(CAmount nAmountOutput) const EXCLUSIVE_LOCKS_REQUIRED(!cs_outputs);
Expand Down
2 changes: 1 addition & 1 deletion src/interfaces/coinjoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Loader
public:
virtual ~Loader() {}
//! Add new wallet to CoinJoin client manager
virtual void AddWallet(CWallet&) = 0;
virtual void AddWallet(const std::shared_ptr<CWallet>&) = 0;
//! Remove wallet from CoinJoin client manager
virtual void RemoveWallet(const std::string&) = 0;
virtual void FlushWallet(const std::string&) = 0;
Expand Down
4 changes: 4 additions & 0 deletions src/wallet/test/wallet_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ static std::shared_ptr<CWallet> TestLoadWallet(interfaces::Chain* chain, interfa
std::vector<bilingual_str> warnings;
auto database = MakeWalletDatabase("", options, status, error);
auto wallet = CWallet::Create(chain, coinjoin_loader, "", std::move(database), options.create_flags, error, warnings);
if (coinjoin_loader) {
// TODO: see CreateWalletWithoutChain
AddWallet(wallet);
}
if (chain) {
wallet->postInitProcess();
}
Expand Down
24 changes: 12 additions & 12 deletions src/wallet/wallet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ bool AddWallet(const std::shared_ptr<CWallet>& wallet)
}
wallet->ConnectScriptPubKeyManNotifiers();
wallet->AutoLockMasternodeCollaterals();
wallet->coinjoin_loader().AddWallet(*wallet);
wallet->coinjoin_loader().AddWallet(wallet);
wallet->NotifyCanGetAddressesChanged();
return true;
}
Expand Down Expand Up @@ -1432,46 +1432,46 @@ int CWallet::GetRealOutpointCoinJoinRounds(const COutPoint& outpoint, int nRound
if (wtx == nullptr || wtx->tx == nullptr) {
// no such tx in this wallet
*nRoundsRef = -1;
WalletCJLogPrint((*this), "%s FAILED %-70s %3d\n", __func__, outpoint.ToStringShort(), -1);
WalletCJLogPrint(this, "%s FAILED %-70s %3d\n", __func__, outpoint.ToStringShort(), -1);
return *nRoundsRef;
}

// bounds check
if (outpoint.n >= wtx->tx->vout.size()) {
// should never actually hit this
*nRoundsRef = -4;
WalletCJLogPrint((*this), "%s FAILED %-70s %3d\n", __func__, outpoint.ToStringShort(), -4);
WalletCJLogPrint(this, "%s FAILED %-70s %3d\n", __func__, outpoint.ToStringShort(), -4);
return *nRoundsRef;
}

auto txOutRef = &wtx->tx->vout[outpoint.n];

if (CoinJoin::IsCollateralAmount(txOutRef->nValue)) {
*nRoundsRef = -3;
WalletCJLogPrint((*this), "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
WalletCJLogPrint(this, "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
return *nRoundsRef;
}

// make sure the final output is non-denominate
if (!CoinJoin::IsDenominatedAmount(txOutRef->nValue)) { //NOT DENOM
*nRoundsRef = -2;
WalletCJLogPrint((*this), "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
WalletCJLogPrint(this, "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
return *nRoundsRef;
}

for (const auto& out : wtx->tx->vout) {
if (!CoinJoin::IsDenominatedAmount(out.nValue)) {
// this one is denominated but there is another non-denominated output found in the same tx
*nRoundsRef = 0;
WalletCJLogPrint((*this), "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
WalletCJLogPrint(this, "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
return *nRoundsRef;
}
}

// make sure we spent all of it with 0 fee, reset to 0 rounds otherwise
if (wtx->GetDebit(ISMINE_SPENDABLE) != wtx->GetCredit(ISMINE_SPENDABLE)) {
*nRoundsRef = 0;
WalletCJLogPrint((*this), "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
WalletCJLogPrint(this, "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
return *nRoundsRef;
}

Expand All @@ -1491,7 +1491,7 @@ int CWallet::GetRealOutpointCoinJoinRounds(const COutPoint& outpoint, int nRound
*nRoundsRef = fDenomFound
? (nShortest >= nRoundsMax - 1 ? nRoundsMax : nShortest + 1) // good, we a +1 to the shortest one but only nRoundsMax rounds max allowed
: 0; // too bad, we are the fist one in that chain
WalletCJLogPrint((*this), "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
WalletCJLogPrint(this, "%s UPDATED %-70s %3d\n", __func__, outpoint.ToStringShort(), *nRoundsRef);
return *nRoundsRef;
}

Expand Down Expand Up @@ -3253,7 +3253,7 @@ bool CWallet::SelectTxDSInsByDenomination(int nDenom, CAmount nValueMax, std::ve
CCoinControl coin_control;
coin_control.nCoinType = CoinType::ONLY_READY_TO_MIX;
AvailableCoins(vCoins, &coin_control);
WalletCJLogPrint((*this), "CWallet::%s -- vCoins.size(): %d\n", __func__, vCoins.size());
WalletCJLogPrint(this, "CWallet::%s -- vCoins.size(): %d\n", __func__, vCoins.size());

Shuffle(vCoins.rbegin(), vCoins.rend(), FastRandomContext());

Expand All @@ -3271,11 +3271,11 @@ bool CWallet::SelectTxDSInsByDenomination(int nDenom, CAmount nValueMax, std::ve
nValueTotal += nValue;
vecTxDSInRet.emplace_back(CTxDSIn(txin, scriptPubKey, nRounds));
setRecentTxIds.emplace(txHash);
WalletCJLogPrint((*this), "CWallet::%s -- hash: %s, nValue: %d.%08d\n",
WalletCJLogPrint(this, "CWallet::%s -- hash: %s, nValue: %d.%08d\n",
__func__, txHash.ToString(), nValue / COIN, nValue % COIN);
}

WalletCJLogPrint((*this), "CWallet::%s -- setRecentTxIds.size(): %d\n", __func__, setRecentTxIds.size());
WalletCJLogPrint(this, "CWallet::%s -- setRecentTxIds.size(): %d\n", __func__, setRecentTxIds.size());

return nValueTotal > 0;
}
Expand Down Expand Up @@ -4980,7 +4980,7 @@ std::shared_ptr<CWallet> CWallet::Create(interfaces::Chain* chain, interfaces::C
}

if (coinjoin_loader) {
coinjoin_loader->AddWallet(*walletInstance);
coinjoin_loader->AddWallet(walletInstance);
}

{
Expand Down
2 changes: 1 addition & 1 deletion src/wallet/wallet.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ extern const std::map<uint64_t,std::string> WALLET_FLAG_CAVEATS;
#define WalletCJLogPrint(wallet, ...) \
do { \
if (LogAcceptCategory(BCLog::COINJOIN, BCLog::Level::Debug)) { \
wallet.WalletLogPrintf(__VA_ARGS__); \
wallet->WalletLogPrintf(__VA_ARGS__); \
} \
} while (0)

Expand Down
Loading