diff --git a/apps/CMakeLists.txt b/apps/CMakeLists.txt index 554ca13f1955..4ce116ae1b01 100644 --- a/apps/CMakeLists.txt +++ b/apps/CMakeLists.txt @@ -158,6 +158,9 @@ if (BUILD_APPS) add_executable(gdalasyncread EXCLUDE_FROM_ALL gdalasyncread.cpp) add_executable(gdalwarpsimple EXCLUDE_FROM_ALL gdalwarpsimple.c) add_executable(multireadtest EXCLUDE_FROM_ALL multireadtest.cpp) + if(NOT MSVC AND CMAKE_THREAD_LIBS_INIT) + target_link_libraries(multireadtest PRIVATE ${CMAKE_THREAD_LIBS_INIT}) + endif() add_executable(test_ogrsf test_ogrsf.cpp) add_executable(testreprojmulti EXCLUDE_FROM_ALL testreprojmulti.cpp) diff --git a/apps/multireadtest.cpp b/apps/multireadtest.cpp index d8a4e1cc24b8..e0c737bcf521 100644 --- a/apps/multireadtest.cpp +++ b/apps/multireadtest.cpp @@ -30,27 +30,32 @@ #include "gdal_alg.h" #include "cpl_multiproc.h" #include "cpl_string.h" + +#include +#include +#include #include static int nIterations = 1; static bool bLockOnOpen = false; static int nOpenIterations = 1; static volatile int nPendingThreads = 0; +static bool bThreadCanFinish = false; +static std::mutex oMutex; +static std::condition_variable oCond; static const char *pszFilename = nullptr; static int nChecksum = 0; static int nWidth = 0; static int nHeight = 0; -static CPLMutex *pGlobalMutex = nullptr; - /************************************************************************/ /* Usage() */ /************************************************************************/ static void Usage() { - printf("multireadtest [-lock_on_open] [-open_in_main] [-t ]\n" - " [-i ] [-oi ]\n" + printf("multireadtest [[-thread_safe] | [[-lock_on_open] [-open_in_main]]\n" + " [-t ] [-i ] [-oi ]\n" " [-width ] [-height ]\n" " filename\n"); exit(1); @@ -74,12 +79,12 @@ static void WorkerFunc(void *arg) else { if (bLockOnOpen) - CPLAcquireMutex(pGlobalMutex, 100.0); + oMutex.lock(); hDS = GDALOpen(pszFilename, GA_ReadOnly); if (bLockOnOpen) - CPLReleaseMutex(pGlobalMutex); + oMutex.unlock(); } for (int iIter = 0; iIter < nIterations && hDS != nullptr; iIter++) @@ -99,10 +104,10 @@ static void WorkerFunc(void *arg) if (hDS && hDSIn == nullptr) { if (bLockOnOpen) - CPLAcquireMutex(pGlobalMutex, 100.0); + oMutex.lock(); GDALClose(hDS); if (bLockOnOpen) - CPLReleaseMutex(pGlobalMutex); + oMutex.unlock(); } else if (hDSIn != nullptr) { @@ -110,9 +115,13 @@ static void WorkerFunc(void *arg) } } - CPLAcquireMutex(pGlobalMutex, 100.0); - nPendingThreads--; - CPLReleaseMutex(pGlobalMutex); + { + std::unique_lock oLock(oMutex); + nPendingThreads--; + oCond.notify_all(); + while (!bThreadCanFinish) + oCond.wait(oLock); + } } /************************************************************************/ @@ -131,6 +140,10 @@ int main(int argc, char **argv) int nThreadCount = 4; bool bOpenInThreads = true; + bool bThreadSafe = false; + bool bJoinAfterClosing = false; + bool bDetach = false; + bool bClose = true; for (int iArg = 1; iArg < argc; iArg++) { @@ -154,6 +167,10 @@ int main(int argc, char **argv) { nHeight = atoi(argv[++iArg]); } + else if (EQUAL(argv[iArg], "-thread_safe")) + { + bThreadSafe = true; + } else if (EQUAL(argv[iArg], "-lock_on_open")) { bLockOnOpen = true; @@ -162,6 +179,18 @@ int main(int argc, char **argv) { bOpenInThreads = false; } + else if (EQUAL(argv[iArg], "-join_after_closing")) + { + bJoinAfterClosing = true; + } + else if (EQUAL(argv[iArg], "-detach")) + { + bDetach = true; + } + else if (EQUAL(argv[iArg], "-do_not_close")) + { + bClose = false; + } else if (pszFilename == nullptr) { pszFilename = argv[iArg]; @@ -186,12 +215,10 @@ int main(int argc, char **argv) /* -------------------------------------------------------------------- */ /* Get the checksum of band1. */ /* -------------------------------------------------------------------- */ - GDALDatasetH hDS = nullptr; - GDALAllRegister(); for (int i = 0; i < 2; i++) { - hDS = GDALOpen(pszFilename, GA_ReadOnly); + GDALDatasetH hDS = GDALOpen(pszFilename, GA_ReadOnly); if (hDS == nullptr) exit(1); @@ -210,39 +237,83 @@ int main(int argc, char **argv) /* -------------------------------------------------------------------- */ /* Fire off worker threads. */ /* -------------------------------------------------------------------- */ - pGlobalMutex = CPLCreateMutex(); - CPLReleaseMutex(pGlobalMutex); nPendingThreads = nThreadCount; + GDALDatasetH hThreadSafeDS = nullptr; + if (bThreadSafe) + { + hThreadSafeDS = + GDALOpenEx(pszFilename, GDAL_OF_RASTER | GDAL_OF_THREAD_SAFE, + nullptr, nullptr, nullptr); + if (!hThreadSafeDS) + exit(1); + } + std::vector aoThreads; std::vector aoDS; for (int iThread = 0; iThread < nThreadCount; iThread++) { - hDS = nullptr; - if (!bOpenInThreads) + GDALDatasetH hDS = nullptr; + if (bThreadSafe) { - hDS = GDALOpen(pszFilename, GA_ReadOnly); - if (!hDS) + hDS = hThreadSafeDS; + } + else + { + if (!bOpenInThreads) { - printf("GDALOpen() failed.\n"); - exit(1); + hDS = GDALOpen(pszFilename, GA_ReadOnly); + if (!hDS) + { + printf("GDALOpen() failed.\n"); + exit(1); + } + aoDS.push_back(hDS); } - aoDS.push_back(hDS); } - if (CPLCreateThread(WorkerFunc, hDS) == -1) + aoThreads.push_back(std::thread([hDS]() { WorkerFunc(hDS); })); + } + + { + std::unique_lock oLock(oMutex); + while (nPendingThreads > 0) { - printf("CPLCreateThread() failed.\n"); - exit(1); + // printf("nPendingThreads = %d\n", nPendingThreads); + oCond.wait(oLock); } } - while (nPendingThreads > 0) - CPLSleep(0.5); - - CPLDestroyMutex(pGlobalMutex); + if (!bJoinAfterClosing && !bDetach) + { + { + std::lock_guard oLock(oMutex); + bThreadCanFinish = true; + oCond.notify_all(); + } + for (auto &oThread : aoThreads) + oThread.join(); + } for (size_t i = 0; i < aoDS.size(); ++i) GDALClose(aoDS[i]); + if (bClose) + GDALClose(hThreadSafeDS); + + if (bDetach) + { + for (auto &oThread : aoThreads) + oThread.detach(); + } + else if (bJoinAfterClosing) + { + { + std::lock_guard oLock(oMutex); + bThreadCanFinish = true; + oCond.notify_all(); + } + for (auto &oThread : aoThreads) + oThread.join(); + } printf("All threads complete.\n"); @@ -250,5 +321,13 @@ int main(int argc, char **argv) GDALDestroyDriverManager(); + { + std::lock_guard oLock(oMutex); + bThreadCanFinish = true; + oCond.notify_all(); + } + + printf("End of main.\n"); + return 0; } diff --git a/autotest/gcore/thread_test.py b/autotest/gcore/thread_test.py index 9fe563757aae..8199dd900048 100755 --- a/autotest/gcore/thread_test.py +++ b/autotest/gcore/thread_test.py @@ -75,3 +75,44 @@ def test_thread_test_1(): ret = False assert ret + + +def launch_threads(ds, expected_cs): + def verify_checksum(): + for i in range(1000): + assert ds.GetRasterBand(1).Checksum() == expected_cs + + threads = [threading.Thread(target=verify_checksum)] + for t in threads: + t.start() + for t in threads: + t.join() + + +def test_thread_safe_open(): + + ds = gdal.OpenEx("data/byte.tif", gdal.OF_RASTER | gdal.OF_THREAD_SAFE) + assert ds.IsThreadSafe(gdal.OF_RASTER) + assert not ds.IsThreadSafe(gdal.OF_RASTER | gdal.OF_UPDATE) + launch_threads(ds, 4672) + + +def test_thread_safe_create(): + + ds = gdal.OpenEx("data/byte.tif", gdal.OF_RASTER) + assert not ds.IsThreadSafe(gdal.OF_RASTER) + assert ds.GetRefCount() == 1 + thread_safe_ds = ds.CreateThreadSafeDataset(gdal.OF_RASTER) + assert thread_safe_ds.IsThreadSafe(gdal.OF_RASTER) + assert ds.GetRefCount() == 2 + del ds + launch_threads(thread_safe_ds, 4672) + + +def test_thread_safe_create_close_src_ds(): + + ds = gdal.OpenEx("data/byte.tif", gdal.OF_RASTER) + thread_safe_ds = ds.CreateThreadSafeDataset(gdal.OF_RASTER) + ds.Close() + with pytest.raises(Exception): + thread_safe_ds.RasterCount diff --git a/doc/source/spelling_wordlist.txt b/doc/source/spelling_wordlist.txt index 46bea3a8339d..95170ba53d6c 100644 --- a/doc/source/spelling_wordlist.txt +++ b/doc/source/spelling_wordlist.txt @@ -2047,6 +2047,7 @@ nRecurseDepth nRefCount nReqOrder nSampleStep +nScopeFlags nSecond nSize nSrcBufferAllocSize diff --git a/frmts/mem/memdataset.cpp b/frmts/mem/memdataset.cpp index 2d624dafeab4..493da955f996 100644 --- a/frmts/mem/memdataset.cpp +++ b/frmts/mem/memdataset.cpp @@ -89,10 +89,10 @@ GDALRasterBandH MEMCreateRasterBandEx(GDALDataset *poDS, int nBand, /************************************************************************/ MEMRasterBand::MEMRasterBand(GByte *pabyDataIn, GDALDataType eTypeIn, - int nXSizeIn, int nYSizeIn) + int nXSizeIn, int nYSizeIn, bool bOwnDataIn) : GDALPamRasterBand(FALSE), pabyData(pabyDataIn), nPixelOffset(GDALGetDataTypeSizeBytes(eTypeIn)), nLineOffset(0), - bOwnData(true) + bOwnData(bOwnDataIn) { eAccess = GA_Update; eDataType = eTypeIn; @@ -463,7 +463,7 @@ int MEMRasterBand::GetOverviewCount() MEMDataset *poMemDS = dynamic_cast(poDS); if (poMemDS == nullptr) return 0; - return poMemDS->m_nOverviewDSCount; + return static_cast(poMemDS->m_apoOverviewDS.size()); } /************************************************************************/ @@ -476,9 +476,9 @@ GDALRasterBand *MEMRasterBand::GetOverview(int i) MEMDataset *poMemDS = dynamic_cast(poDS); if (poMemDS == nullptr) return nullptr; - if (i < 0 || i >= poMemDS->m_nOverviewDSCount) + if (i < 0 || i >= static_cast(poMemDS->m_apoOverviewDS.size())) return nullptr; - return poMemDS->m_papoOverviewDS[i]->GetRasterBand(nBand); + return poMemDS->m_apoOverviewDS[i]->GetRasterBand(nBand); } /************************************************************************/ @@ -504,8 +504,8 @@ CPLErr MEMRasterBand::CreateMaskBand(int nFlagsIn) return CE_Failure; nMaskFlags = nFlagsIn; - auto poMemMaskBand = - new MEMRasterBand(pabyMaskData, GDT_Byte, nRasterXSize, nRasterYSize); + auto poMemMaskBand = new MEMRasterBand(pabyMaskData, GDT_Byte, nRasterXSize, + nRasterYSize, /* bOwnData= */ true); poMemMaskBand->m_bIsMask = true; poMask.reset(poMemMaskBand, true); if ((nFlagsIn & GMF_PER_DATASET) != 0 && nBand == 1 && poMemDS != nullptr) @@ -542,8 +542,7 @@ bool MEMRasterBand::IsMaskBand() const /************************************************************************/ MEMDataset::MEMDataset() - : GDALDataset(FALSE), bGeoTransformSet(FALSE), m_nOverviewDSCount(0), - m_papoOverviewDS(nullptr), m_poPrivate(new Private()) + : GDALDataset(FALSE), bGeoTransformSet(FALSE), m_poPrivate(new Private()) { adfGeoTransform[0] = 0.0; adfGeoTransform[1] = 1.0; @@ -565,10 +564,6 @@ MEMDataset::~MEMDataset() bSuppressOnClose = true; FlushCache(true); bSuppressOnClose = bSuppressOnCloseBackup; - - for (int i = 0; i < m_nOverviewDSCount; ++i) - delete m_papoOverviewDS[i]; - CPLFree(m_papoOverviewDS); } #if 0 @@ -824,11 +819,7 @@ CPLErr MEMDataset::IBuildOverviews(const char *pszResampling, int nOverviews, if (nOverviews == 0) { // Cleanup existing overviews - for (int i = 0; i < m_nOverviewDSCount; ++i) - delete m_papoOverviewDS[i]; - CPLFree(m_papoOverviewDS); - m_nOverviewDSCount = 0; - m_papoOverviewDS = nullptr; + m_apoOverviewDS.clear(); return CE_None; } @@ -901,7 +892,7 @@ CPLErr MEMDataset::IBuildOverviews(const char *pszResampling, int nOverviews, // Create new overview dataset if needed. if (!bExisting) { - MEMDataset *poOvrDS = new MEMDataset(); + auto poOvrDS = std::make_unique(); poOvrDS->eAccess = GA_Update; poOvrDS->nRasterXSize = (nRasterXSize + panOverviewList[i] - 1) / panOverviewList[i]; @@ -913,14 +904,10 @@ CPLErr MEMDataset::IBuildOverviews(const char *pszResampling, int nOverviews, GetRasterBand(iBand + 1)->GetRasterDataType(); if (poOvrDS->AddBand(eDT, nullptr) != CE_None) { - delete poOvrDS; return CE_Failure; } } - m_nOverviewDSCount++; - m_papoOverviewDS = (GDALDataset **)CPLRealloc( - m_papoOverviewDS, sizeof(GDALDataset *) * m_nOverviewDSCount); - m_papoOverviewDS[m_nOverviewDSCount - 1] = poOvrDS; + m_apoOverviewDS.emplace_back(std::move(poOvrDS)); } } @@ -1053,6 +1040,71 @@ CPLErr MEMDataset::CreateMaskBand(int nFlagsIn) return poFirstBand->CreateMaskBand(nFlagsIn | GMF_PER_DATASET); } +/************************************************************************/ +/* CanBeCloned() */ +/************************************************************************/ + +bool MEMDataset::CanBeCloned(int nScopeFlags, bool bCanShareState) const +{ + return nScopeFlags == GDAL_OF_RASTER && bCanShareState && + typeid(this) == typeid(const MEMDataset *); +} + +/************************************************************************/ +/* Clone() */ +/************************************************************************/ + +std::unique_ptr MEMDataset::Clone(int nScopeFlags, + bool bCanShareState) const +{ + if (MEMDataset::CanBeCloned(nScopeFlags, bCanShareState)) + { + auto poNewDS = std::make_unique(); + poNewDS->nRasterXSize = nRasterXSize; + poNewDS->nRasterYSize = nRasterYSize; + poNewDS->bGeoTransformSet = bGeoTransformSet; + memcpy(poNewDS->adfGeoTransform, adfGeoTransform, + sizeof(adfGeoTransform)); + poNewDS->m_oSRS = m_oSRS; + poNewDS->m_aoGCPs = m_aoGCPs; + poNewDS->m_oGCPSRS = m_oGCPSRS; + for (const auto &poOvrDS : m_apoOverviewDS) + { + poNewDS->m_apoOverviewDS.emplace_back( + poOvrDS->Clone(nScopeFlags, bCanShareState)); + } + + for (int i = 1; i <= nBands; ++i) + { + auto poSrcMEMBand = + dynamic_cast(papoBands[i - 1]); + CPLAssert(poSrcMEMBand); + auto poNewBand = std::make_unique( + poNewDS.get(), i, poSrcMEMBand->pabyData, + poSrcMEMBand->GetRasterDataType(), poSrcMEMBand->nPixelOffset, + poSrcMEMBand->nLineOffset, + /* bAssumeOwnership = */ false); + + auto poSrcMaskBand = + dynamic_cast(poSrcMEMBand->poMask.get()); + if (poSrcMaskBand) + { + auto poMaskBand = new MEMRasterBand( + poSrcMaskBand->pabyData, GDT_Byte, nRasterXSize, + nRasterYSize, /* bOwnData = */ false); + poMaskBand->m_bIsMask = true; + poNewBand->poMask.reset(poMaskBand, true); + poNewBand->nMaskFlags = poSrcMaskBand->nMaskFlags; + } + + poNewDS->SetBand(i, std::move(poNewBand)); + } + + return poNewDS; + } + return GDALDataset::Clone(nScopeFlags, bCanShareState); +} + /************************************************************************/ /* Open() */ /************************************************************************/ diff --git a/frmts/mem/memdataset.h b/frmts/mem/memdataset.h index f2843cbf78dd..8a52c0a69f9d 100644 --- a/frmts/mem/memdataset.h +++ b/frmts/mem/memdataset.h @@ -66,8 +66,7 @@ class CPL_DLL MEMDataset CPL_NON_FINAL : public GDALDataset std::vector m_aoGCPs{}; OGRSpatialReference m_oGCPSRS{}; - int m_nOverviewDSCount; - GDALDataset **m_papoOverviewDS; + std::vector> m_apoOverviewDS{}; struct Private; std::unique_ptr m_poPrivate; @@ -85,6 +84,12 @@ class CPL_DLL MEMDataset CPL_NON_FINAL : public GDALDataset int nYSize, int nBands, GDALDataType eType, char **papszParamList); + protected: + bool CanBeCloned(int nScopeFlags, bool bCanShareState) const override; + + std::unique_ptr Clone(int nScopeFlags, + bool bCanShareState) const override; + public: MEMDataset(); virtual ~MEMDataset(); @@ -141,9 +146,6 @@ class CPL_DLL MEMDataset CPL_NON_FINAL : public GDALDataset class CPL_DLL MEMRasterBand CPL_NON_FINAL : public GDALPamRasterBand { private: - MEMRasterBand(GByte *pabyDataIn, GDALDataType eTypeIn, int nXSizeIn, - int nYSizeIn); - CPL_DISALLOW_COPY_ASSIGN(MEMRasterBand) protected: @@ -156,6 +158,9 @@ class CPL_DLL MEMRasterBand CPL_NON_FINAL : public GDALPamRasterBand bool m_bIsMask = false; + MEMRasterBand(GByte *pabyDataIn, GDALDataType eTypeIn, int nXSizeIn, + int nYSizeIn, bool bOwnDataIn); + public: MEMRasterBand(GDALDataset *poDS, int nBand, GByte *pabyData, GDALDataType eType, GSpacing nPixelOffset, diff --git a/gcore/CMakeLists.txt b/gcore/CMakeLists.txt index 7253fd9cea9e..a629c7f163b2 100644 --- a/gcore/CMakeLists.txt +++ b/gcore/CMakeLists.txt @@ -40,6 +40,7 @@ add_library( gdalrelationship.cpp gdalsubdatasetinfo.cpp gdalorienteddataset.cpp + gdalthreadsafedataset.cpp overview.cpp rasterio.cpp rawdataset.cpp diff --git a/gcore/gdal.h b/gcore/gdal.h index 21d1a4da18d7..2ca44f0cc0cd 100644 --- a/gcore/gdal.h +++ b/gcore/gdal.h @@ -1030,6 +1030,14 @@ GDALDatasetH CPL_DLL CPL_STDCALL GDALOpenShared(const char *, GDALAccess) #define GDAL_OF_FROM_GDALOPEN 0x400 #endif +/** Open in thread-safe mode. Not compatible with + * GDAL_OF_VECTOR, GDAL_OF_MULTIDIM_RASTER or GDAL_OF_UPDATE + * + * Used by GDALOpenEx(). + * @since GDAL 3.10 + */ +#define GDAL_OF_THREAD_SAFE 0x800 + GDALDatasetH CPL_DLL CPL_STDCALL GDALOpenEx( const char *pszFilename, unsigned int nOpenFlags, const char *const *papszAllowedDrivers, const char *const *papszOpenOptions, @@ -1140,6 +1148,11 @@ int CPL_DLL CPL_STDCALL GDALGetRasterYSize(GDALDatasetH); int CPL_DLL CPL_STDCALL GDALGetRasterCount(GDALDatasetH); GDALRasterBandH CPL_DLL CPL_STDCALL GDALGetRasterBand(GDALDatasetH, int); +bool CPL_DLL GDALDatasetIsThreadSafe(GDALDatasetH, int nScopeFlags, + CSLConstList papszOptions); +GDALDatasetH CPL_DLL GDALCreateThreadSafeDataset(GDALDatasetH, int nScopeFlags, + CSLConstList papszOptions); + CPLErr CPL_DLL CPL_STDCALL GDALAddBand(GDALDatasetH hDS, GDALDataType eType, CSLConstList papszOptions); diff --git a/gcore/gdal_priv.h b/gcore/gdal_priv.h index a801d5a94891..6c7313ae4820 100644 --- a/gcore/gdal_priv.h +++ b/gcore/gdal_priv.h @@ -620,6 +620,15 @@ class CPL_DLL GDALDataset : public GDALMajorObject void ShareLockWithParentDataset(GDALDataset *poParentDataset); + bool m_bCanBeReopened = false; + + virtual bool CanBeCloned(int nScopeFlags, bool bCanShareState) const; + + friend class GDALThreadSafeDataset; + friend class MEMDataset; + virtual std::unique_ptr Clone(int nScopeFlags, + bool bCanShareState) const; + //! @endcond void CleanupPostFileClosing(); @@ -831,7 +840,7 @@ class CPL_DLL GDALDataset : public GDALMajorObject /** Return MarkSuppressOnClose flag. * @return MarkSuppressOnClose flag. */ - bool IsMarkedSuppressOnClose() + bool IsMarkedSuppressOnClose() const { return bSuppressOnClose; } @@ -844,6 +853,8 @@ class CPL_DLL GDALDataset : public GDALMajorObject return papszOpenOptions; } + bool IsThreadSafe(int nScopeFlags) const; + #ifndef DOXYGEN_SKIP /** Return open options. * @return open options. @@ -4480,6 +4491,12 @@ void CPL_DLL GDALCopyRasterIOExtraArg(GDALRasterIOExtraArg *psDestArg, CPL_C_END +std::unique_ptr CPL_DLL +GDALCreateThreadSafeDataset(std::unique_ptr poDS, int nScopeFlags); + +std::unique_ptr + CPL_DLL GDALCreateThreadSafeDataset(GDALDataset *poDS, int nScopeFlags); + void GDALNullifyOpenDatasetsList(); CPLMutex **GDALGetphDMMutex(); CPLMutex **GDALGetphDLMutex(); diff --git a/gcore/gdaldataset.cpp b/gcore/gdaldataset.cpp index a2ec99616591..d4e733ed2a62 100644 --- a/gcore/gdaldataset.cpp +++ b/gcore/gdaldataset.cpp @@ -3575,6 +3575,28 @@ GDALDatasetH CPL_STDCALL GDALOpenEx(const char *pszFilename, { VALIDATE_POINTER1(pszFilename, "GDALOpen", nullptr); + if ((nOpenFlags & GDAL_OF_THREAD_SAFE) != 0) + { + if ((nOpenFlags & GDAL_OF_UPDATE) != 0) + { + CPLError(CE_Failure, CPLE_IllegalArg, + "GDAL_OF_THREAD_SAFE and GDAL_OF_UPDATE are exclusive"); + return nullptr; + } + if ((nOpenFlags & GDAL_OF_VECTOR) != 0) + { + CPLError(CE_Failure, CPLE_IllegalArg, + "GDAL_OF_THREAD_SAFE and GDAL_OF_VECTOR are exclusive"); + return nullptr; + } + if ((nOpenFlags & GDAL_OF_MULTIDIM_RASTER) != 0) + { + CPLError(CE_Failure, CPLE_IllegalArg, + "GDAL_OF_THREAD_SAFE and GDAL_OF_VECTOR are exclusive"); + return nullptr; + } + } + // If no driver kind is specified, assume all are to be probed. if ((nOpenFlags & GDAL_OF_KIND_MASK) == 0) nOpenFlags |= GDAL_OF_KIND_MASK & ~GDAL_OF_MULTIDIM_RASTER; @@ -3866,7 +3888,8 @@ GDALDatasetH CPL_STDCALL GDALOpenEx(const char *pszFilename, } else { - if (!(nOpenFlags & GDAL_OF_INTERNAL)) + if (!(nOpenFlags & GDAL_OF_INTERNAL) && + !(nOpenFlags & GDAL_OF_THREAD_SAFE)) { poDS->AddToDatasetOpenList(); } @@ -3875,7 +3898,8 @@ GDALDatasetH CPL_STDCALL GDALOpenEx(const char *pszFilename, CSLDestroy(poDS->papszOpenOptions); poDS->papszOpenOptions = CSLDuplicate(papszOpenOptions); poDS->nOpenFlags = nOpenFlags; - poDS->MarkAsShared(); + if (!(nOpenFlags & GDAL_OF_THREAD_SAFE)) + poDS->MarkAsShared(); } } } @@ -3889,7 +3913,7 @@ GDALDatasetH CPL_STDCALL GDALOpenEx(const char *pszFilename, "and description (%s)", pszFilename, poDS->GetDescription()); } - else + else if (!(nOpenFlags & GDAL_OF_THREAD_SAFE)) { poDS->MarkAsShared(); } @@ -3908,6 +3932,29 @@ GDALDatasetH CPL_STDCALL GDALOpenEx(const char *pszFilename, } #endif + if (poDS) + { + poDS->m_bCanBeReopened = true; + + if ((nOpenFlags & GDAL_OF_THREAD_SAFE) != 0) + { + poDS = + GDALCreateThreadSafeDataset( + std::unique_ptr(poDS), GDAL_OF_RASTER) + .release(); + if (poDS) + { + poDS->m_bCanBeReopened = true; + poDS->poDriver = poDriver; + poDS->nOpenFlags = nOpenFlags; + if (!(nOpenFlags & GDAL_OF_INTERNAL)) + poDS->AddToDatasetOpenList(); + if (nOpenFlags & GDAL_OF_SHARED) + poDS->MarkAsShared(); + } + } + } + return poDS; } @@ -8174,7 +8221,8 @@ bool GDALDataset::SetQueryLoggerFunc(CPL_UNUSED GDALQueryLoggerFunc callback, int GDALDataset::EnterReadWrite(GDALRWFlag eRWFlag) { - if (m_poPrivate == nullptr) + if (m_poPrivate == nullptr || + IsThreadSafe(GDAL_OF_RASTER | (nOpenFlags & GDAL_OF_UPDATE))) return FALSE; if (m_poPrivate->poParentDataset) @@ -10130,3 +10178,34 @@ CPLErr GDALDatasetReadCompressedData(GDALDatasetH hDS, const char *pszFormat, pszFormat, nXOff, nYOff, nXSize, nYSize, nBandCount, panBandList, ppBuffer, pnBufferSize, ppszDetailedFormat); } + +/************************************************************************/ +/* CanBeCloned() */ +/************************************************************************/ + +//! @cond Doxygen_Suppress +bool GDALDataset::CanBeCloned(int nScopeFlags, bool /* bCanShareState */) const +{ + return m_bCanBeReopened && nScopeFlags == GDAL_OF_RASTER; +} + +//! @endcond + +/************************************************************************/ +/* Clone() */ +/************************************************************************/ + +//! @cond Doxygen_Suppress +std::unique_ptr GDALDataset::Clone(int nScopeFlags, + bool /* bCanShareState */) const +{ + CPLStringList aosAllowedDrivers; + if (poDriver) + aosAllowedDrivers.AddString(poDriver->GetDescription()); + return std::unique_ptr(GDALDataset::Open( + GetDescription(), + nScopeFlags | GDAL_OF_INTERNAL | GDAL_OF_VERBOSE_ERROR, + aosAllowedDrivers.List(), papszOpenOptions)); +} + +//! @endcond diff --git a/gcore/gdalthreadsafedataset.cpp b/gcore/gdalthreadsafedataset.cpp new file mode 100644 index 000000000000..4d9b2f90b5a3 --- /dev/null +++ b/gcore/gdalthreadsafedataset.cpp @@ -0,0 +1,752 @@ +/****************************************************************************** + * + * Project: GDAL Core + * Purpose: Base class for thread safe dataset + * Author: Even Rouault + * + ****************************************************************************** + * Copyright (c) 2024, Even Rouault + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + ****************************************************************************/ + +#ifndef DOXYGEN_SKIP + +#include "cpl_mem_cache.h" +#include "gdal_proxy.h" +#include "gdal_rat.h" +#include "gdal_priv.h" + +#include +#include +#include +#include +#include + +/************************************************************************/ +/* GDALThreadLocalDatasetCache */ +/************************************************************************/ + +class GDALThreadSafeDataset; + +class GDALThreadLocalDatasetCache +{ + private: + std::unique_ptr>> + m_poCache{}; + + GDALThreadLocalDatasetCache(const GDALThreadLocalDatasetCache &) = delete; + GDALThreadLocalDatasetCache & + operator=(const GDALThreadLocalDatasetCache &) = delete; + + public: + GIntBig m_nThreadID = 0; + std::mutex m_oMutex{}; + lru11::Cache> + &m_oCache; + std::map, CPLStringList>> + m_oReferencedDS{}; + std::map m_oReferencedDSFromBand{}; + + GDALThreadLocalDatasetCache(); + ~GDALThreadLocalDatasetCache(); +}; + +/************************************************************************/ +/* GDALThreadSafeDataset */ +/************************************************************************/ + +class GDALThreadSafeDataset final : public GDALProxyDataset +{ + public: + GDALThreadSafeDataset(std::unique_ptr poPrototypeDSUniquePtr, + GDALDataset *poPrototypeDS); + ~GDALThreadSafeDataset() override; + + static std::unique_ptr + Create(std::unique_ptr poPrototypeDSUniquePtr, + GDALDataset *poPrototypeDS, int nScopeFlags); + + GDALDataset *RefUnderlyingDataset() const override; + void + UnrefUnderlyingDataset(GDALDataset *poUnderlyingDataset) const override; + + int CloseDependentDatasets() override; + + private: + friend class GDALThreadSafeRasterBand; + friend class GDALThreadLocalDatasetCache; + + CPLStringList m_aosThreadLocalConfigOptions{}; + + std::mutex m_oPrototypeDSMutex{}; + GDALDataset *m_poPrototypeDS = nullptr; + std::unique_ptr m_poPrototypeDSUniquePtr{}; + + // Static variables + static std::mutex g_Mutex; + static std::set g_oSetOfCache; + + // Thread local variables + static thread_local std::unique_ptr tl_poCache; + + void UnrefUnderlyingDataset(GDALDataset *poUnderlyingDataset, + GDALThreadLocalDatasetCache *poCache) const; + + GDALThreadSafeDataset(const GDALThreadSafeDataset &) = delete; + GDALThreadSafeDataset &operator=(const GDALThreadSafeDataset &) = delete; +}; + +/************************************************************************/ +/* GDALThreadSafeRasterBand */ +/************************************************************************/ + +class GDALThreadSafeRasterBand final : public GDALProxyRasterBand +{ + public: + GDALThreadSafeRasterBand(GDALThreadSafeDataset *poTSDS, + GDALDataset *poParentDS, int nBandIn, + GDALRasterBand *poPrototypeBand, + int nBaseBandOfMaskBand, int nOvrIdx); + + GDALRasterBand *RefUnderlyingRasterBand(bool bForceOpen) const override; + void UnrefUnderlyingRasterBand( + GDALRasterBand *poUnderlyingRasterBand) const override; + + GDALRasterBand *GetMaskBand() override; + int GetOverviewCount() override; + GDALRasterBand *GetOverview(int idx) override; + GDALRasterBand *GetRasterSampleOverview(GUIntBig nDesiredSamples) override; + + GDALRasterAttributeTable *GetDefaultRAT() override; + + private: + GDALThreadSafeRasterBand(const GDALThreadSafeRasterBand &) = delete; + GDALThreadSafeRasterBand & + operator=(const GDALThreadSafeRasterBand &) = delete; + + GDALThreadSafeDataset *m_poTSDS = nullptr; + GDALRasterBand *m_poPrototypeBand = nullptr; + const int m_nBaseBandOfMaskBand; + const int m_nOvrIdx; + std::unique_ptr m_poMaskBand{}; + std::vector> m_apoOverviews{}; +}; + +/************************************************************************/ +/* Global variables initialization. */ +/************************************************************************/ + +/* static */ std::mutex GDALThreadSafeDataset::g_Mutex; +/* static */ std::set + GDALThreadSafeDataset::g_oSetOfCache; + +thread_local std::unique_ptr + GDALThreadSafeDataset::tl_poCache; + +/************************************************************************/ +/* GDALThreadLocalDatasetCache() */ +/************************************************************************/ + +GDALThreadLocalDatasetCache::GDALThreadLocalDatasetCache() + : m_poCache(std::make_unique>>()), + m_nThreadID(CPLGetPID()), m_oCache(*m_poCache.get()) +{ + CPLDebug("GDAL", + "Registering thread-safe dataset cache for thread " CPL_FRMT_GIB, + m_nThreadID); + std::lock_guard oLock(GDALThreadSafeDataset::g_Mutex); + GDALThreadSafeDataset::g_oSetOfCache.insert(this); +} + +/************************************************************************/ +/* ~GDALThreadLocalDatasetCache() */ +/************************************************************************/ + +GDALThreadLocalDatasetCache::~GDALThreadLocalDatasetCache() +{ + const bool bDriverManagerDestroyed = *GDALGetphDMMutex() == nullptr; + if (bDriverManagerDestroyed) + { + // Leak datasets when GDAL has been de-initialized + if (!m_poCache->empty()) + CPL_IGNORE_RET_VAL(m_poCache.release()); + return; + } + + CPLDebug("GDAL", + "Unregistering thread-safe dataset cache for thread " CPL_FRMT_GIB, + m_nThreadID); + { + std::lock_guard oLock(GDALThreadSafeDataset::g_Mutex); + GDALThreadSafeDataset::g_oSetOfCache.erase(this); + } + const auto lambda = + [this](const lru11::KeyValuePair> &kv) + { + CPLDebug("GDAL", + "~GDALThreadLocalDatasetCache(): GDALClose(%s, this=%p) " + "for thread " CPL_FRMT_GIB, + kv.value->GetDescription(), kv.value.get(), m_nThreadID); + }; + m_oCache.cwalk(lambda); +} + +/************************************************************************/ +/* GDALThreadSafeDataset() */ +/************************************************************************/ + +GDALThreadSafeDataset::GDALThreadSafeDataset( + std::unique_ptr poPrototypeDSUniquePtr, + GDALDataset *poPrototypeDS) + : m_poPrototypeDS(poPrototypeDS) +{ + nRasterXSize = poPrototypeDS->GetRasterXSize(); + nRasterYSize = poPrototypeDS->GetRasterYSize(); + for (int i = 1; i <= poPrototypeDS->GetRasterCount(); ++i) + { + SetBand(i, std::make_unique( + this, this, i, poPrototypeDS->GetRasterBand(i), 0, -1)); + } + nOpenFlags = GDAL_OF_RASTER | GDAL_OF_THREAD_SAFE; + SetDescription(poPrototypeDS->GetDescription()); + papszOpenOptions = CSLDuplicate(poPrototypeDS->GetOpenOptions()); + m_aosThreadLocalConfigOptions = CPLGetThreadLocalConfigOptions(); + m_poPrototypeDSUniquePtr = std::move(poPrototypeDSUniquePtr); + + if (!m_poPrototypeDSUniquePtr) + m_poPrototypeDS->Reference(); +} + +/************************************************************************/ +/* Create() */ +/************************************************************************/ + +/* static */ std::unique_ptr +GDALThreadSafeDataset::Create( + std::unique_ptr poPrototypeDSUniquePtr, + GDALDataset *poPrototypeDS, int nScopeFlags) +{ + if (nScopeFlags != GDAL_OF_RASTER) + { + CPLError(CE_Failure, CPLE_NotSupported, + "GDALCreateThreadSafeDataset(): Only nScopeFlags == " + "GDAL_OF_RASTER is supported"); + return nullptr; + } + if (!poPrototypeDS->CanBeCloned(nScopeFlags, /* bCanShareState = */ true)) + { + CPLError(CE_Failure, CPLE_NotSupported, + "GDALCreateThreadSafeDataset(): Source dataset cannot be " + "cloned"); + return nullptr; + } + return std::make_unique( + std::move(poPrototypeDSUniquePtr), poPrototypeDS); +} + +/************************************************************************/ +/* ~GDALThreadSafeDataset() */ +/************************************************************************/ + +GDALThreadSafeDataset::~GDALThreadSafeDataset() +{ + if (m_poPrototypeDS) + { + GDALThreadSafeDataset::FlushCache(true); + } + + // Collect TLS datasets in a vector, and free them after releasing + // g_nInDestructorCounter to limit contention + std::vector, GIntBig>> aoDSToFree; + { + std::lock_guard oLock(g_Mutex); + for (auto *poCache : g_oSetOfCache) + { + std::unique_lock oLockCache(poCache->m_oMutex); + std::shared_ptr poDS; + if (poCache->m_oCache.tryGet(this, poDS)) + { + aoDSToFree.emplace_back(std::move(poDS), poCache->m_nThreadID); + poCache->m_oCache.remove(this); + } + } + } + + for (const auto &oEntry : aoDSToFree) + { + CPLDebug("GDAL", + "~GDALThreadSafeDataset(): GDALClose(%s, this=%p) for " + "thread " CPL_FRMT_GIB, + GetDescription(), oEntry.first.get(), oEntry.second); + } + // Actually release TLS datasets + aoDSToFree.clear(); + + GDALThreadSafeDataset::CloseDependentDatasets(); +} + +/************************************************************************/ +/* CloseDependentDatasets() */ +/************************************************************************/ + +int GDALThreadSafeDataset::CloseDependentDatasets() +{ + int bRet = false; + if (m_poPrototypeDSUniquePtr) + { + bRet = true; + } + else if (m_poPrototypeDS) + { + if (m_poPrototypeDS->ReleaseRef()) + { + bRet = true; + } + } + + m_poPrototypeDSUniquePtr.reset(); + m_poPrototypeDS = nullptr; + + return bRet; +} + +/************************************************************************/ +/* RefUnderlyingDataset() */ +/************************************************************************/ + +GDALDataset *GDALThreadSafeDataset::RefUnderlyingDataset() const +{ + CPLStringList aosTLConfigOptionsBackup(CPLGetThreadLocalConfigOptions()); + const CPLStringList aosMerged( + CSLMerge(CSLDuplicate(m_aosThreadLocalConfigOptions.List()), + aosTLConfigOptionsBackup.List())); + CPLSetThreadLocalConfigOptions(aosMerged.List()); + + std::shared_ptr poDS; + + GDALThreadLocalDatasetCache *poCache = tl_poCache.get(); + if (!poCache) + { + auto poCacheUniquePtr = std::make_unique(); + poCache = poCacheUniquePtr.get(); + tl_poCache = std::move(poCacheUniquePtr); + } + std::unique_lock oLock(poCache->m_oMutex); + if (poCache->m_oCache.tryGet(this, poDS)) + { + CPLAssert(!cpl::contains(poCache->m_oReferencedDS, this)); + auto poDSRet = poDS.get(); + poCache->m_oReferencedDS[this] = std::make_pair( + std::move(poDS), std::move(aosTLConfigOptionsBackup)); + return poDSRet; + } + oLock.unlock(); + poDS = m_poPrototypeDS->Clone(GDAL_OF_RASTER, /* bCanShareState=*/true); + if (poDS) + { + CPLDebug("GDAL", "GDALOpen(%s, this=%p) for thread " CPL_FRMT_GIB, + GetDescription(), poDS.get(), CPLGetPID()); + } + oLock.lock(); + if (!poDS) + { + CPLSetThreadLocalConfigOptions(aosTLConfigOptionsBackup.List()); + return nullptr; + } + + auto poDSRet = poDS.get(); + { + poCache->m_oCache.insert(this, poDS); + CPLAssert(!cpl::contains(poCache->m_oReferencedDS, this)); + poCache->m_oReferencedDS[this] = std::make_pair( + std::move(poDS), std::move(aosTLConfigOptionsBackup)); + } + return poDSRet; +} + +/************************************************************************/ +/* UnrefUnderlyingDataset() */ +/************************************************************************/ + +void GDALThreadSafeDataset::UnrefUnderlyingDataset( + GDALDataset *poUnderlyingDataset) const +{ + GDALThreadLocalDatasetCache *poCache = tl_poCache.get(); + CPLAssert(poCache); + std::unique_lock oLock(poCache->m_oMutex); + UnrefUnderlyingDataset(poUnderlyingDataset, poCache); +} + +/************************************************************************/ +/* UnrefUnderlyingDataset() */ +/************************************************************************/ + +void GDALThreadSafeDataset::UnrefUnderlyingDataset( + [[maybe_unused]] GDALDataset *poUnderlyingDataset, + GDALThreadLocalDatasetCache *poCache) const +{ + auto oIter = poCache->m_oReferencedDS.find(this); + CPLAssert(oIter != poCache->m_oReferencedDS.end()); + CPLAssert(oIter->second.first.get() == poUnderlyingDataset); + CPLSetThreadLocalConfigOptions(oIter->second.second.List()); + poCache->m_oReferencedDS.erase(oIter); +} + +/************************************************************************/ +/* GDALThreadSafeRasterBand() */ +/************************************************************************/ + +GDALThreadSafeRasterBand::GDALThreadSafeRasterBand( + GDALThreadSafeDataset *poTSDS, GDALDataset *poParentDS, int nBandIn, + GDALRasterBand *poPrototypeBand, int nBaseBandOfMaskBand, int nOvrIdx) + : m_poTSDS(poTSDS), m_poPrototypeBand(poPrototypeBand), + m_nBaseBandOfMaskBand(nBaseBandOfMaskBand), m_nOvrIdx(nOvrIdx) +{ + poDS = poParentDS; + nBand = nBandIn; + eDataType = poPrototypeBand->GetRasterDataType(); + nRasterXSize = poPrototypeBand->GetXSize(); + nRasterYSize = poPrototypeBand->GetYSize(); + poPrototypeBand->GetBlockSize(&nBlockXSize, &nBlockYSize); + + if (nBandIn > 0) + { + m_poMaskBand = std::make_unique( + poTSDS, nullptr, 0, poPrototypeBand->GetMaskBand(), nBandIn, + nOvrIdx); + if (nOvrIdx < 0) + { + const int nOvrCount = poPrototypeBand->GetOverviewCount(); + for (int iOvrIdx = 0; iOvrIdx < nOvrCount; ++iOvrIdx) + { + m_apoOverviews.emplace_back( + std::make_unique( + poTSDS, nullptr, nBandIn, + poPrototypeBand->GetOverview(iOvrIdx), + nBaseBandOfMaskBand, iOvrIdx)); + } + } + } + else if (nBaseBandOfMaskBand > 0) + { + m_poMaskBand = std::make_unique( + poTSDS, nullptr, 0, poPrototypeBand->GetMaskBand(), + -nBaseBandOfMaskBand, nOvrIdx); + } +} + +/************************************************************************/ +/* RefUnderlyingRasterBand() */ +/************************************************************************/ + +GDALRasterBand * +GDALThreadSafeRasterBand::RefUnderlyingRasterBand(bool /*bForceOpen*/) const +{ + // Get a thread-local dataset + auto poTLDS = m_poTSDS->RefUnderlyingDataset(); + if (!poTLDS) + return nullptr; + + // Get corresponding thread-local band + auto poTLRasterBand = poTLDS->GetRasterBand( + m_nBaseBandOfMaskBand ? std::abs(m_nBaseBandOfMaskBand) : nBand); + if (!poTLRasterBand) + return nullptr; + if (m_nOvrIdx >= 0) + poTLRasterBand = poTLRasterBand->GetOverview(m_nOvrIdx); + if (m_nBaseBandOfMaskBand) + { + poTLRasterBand = poTLRasterBand->GetMaskBand(); + if (m_nBaseBandOfMaskBand < 0) + poTLRasterBand = poTLRasterBand->GetMaskBand(); + } + + // Registers the association between the thread-local band and the + // thread-local dataset + { + GDALThreadLocalDatasetCache *poCache = + GDALThreadSafeDataset::tl_poCache.get(); + CPLAssert(poCache); + std::unique_lock oLock(poCache->m_oMutex); + CPLAssert( + !cpl::contains(poCache->m_oReferencedDSFromBand, poTLRasterBand)); + poCache->m_oReferencedDSFromBand[poTLRasterBand] = poTLDS; + } + // CPLDebug("GDAL", "%p->RefUnderlyingRasterBand() return %p", this, poTLRasterBand); + return poTLRasterBand; +} + +/************************************************************************/ +/* UnrefUnderlyingRasterBand() */ +/************************************************************************/ + +void GDALThreadSafeRasterBand::UnrefUnderlyingRasterBand( + GDALRasterBand *poUnderlyingRasterBand) const +{ + // CPLDebug("GDAL", "%p->UnrefUnderlyingRasterBand(%p)", this, poUnderlyingRasterBand); + + // Unregisters the association between the thread-local band and the + // thread-local dataset + { + GDALThreadLocalDatasetCache *poCache = + GDALThreadSafeDataset::tl_poCache.get(); + CPLAssert(poCache); + std::unique_lock oLock(poCache->m_oMutex); + auto oIter = + poCache->m_oReferencedDSFromBand.find(poUnderlyingRasterBand); + CPLAssert(oIter != poCache->m_oReferencedDSFromBand.end()); + + m_poTSDS->UnrefUnderlyingDataset(oIter->second, poCache); + poCache->m_oReferencedDSFromBand.erase(oIter); + } +} + +/************************************************************************/ +/* GetMaskBand() */ +/************************************************************************/ + +GDALRasterBand *GDALThreadSafeRasterBand::GetMaskBand() +{ + return m_poMaskBand ? m_poMaskBand.get() : this; +} + +/************************************************************************/ +/* GetOverviewCount() */ +/************************************************************************/ + +int GDALThreadSafeRasterBand::GetOverviewCount() +{ + return static_cast(m_apoOverviews.size()); +} + +/************************************************************************/ +/* GetOverview() */ +/************************************************************************/ + +GDALRasterBand *GDALThreadSafeRasterBand::GetOverview(int nIdx) +{ + if (nIdx < 0 || nIdx >= static_cast(m_apoOverviews.size())) + return nullptr; + return m_apoOverviews[nIdx].get(); +} + +/************************************************************************/ +/* GetRasterSampleOverview() */ +/************************************************************************/ + +GDALRasterBand * +GDALThreadSafeRasterBand::GetRasterSampleOverview(GUIntBig nDesiredSamples) + +{ + GDALRasterBand *poBestBand = this; + + double dfBestSamples = GetXSize() * static_cast(GetYSize()); + + for (int iOverview = 0; iOverview < GetOverviewCount(); iOverview++) + { + GDALRasterBand *poOBand = GetOverview(iOverview); + + if (poOBand == nullptr) + continue; + + const double dfOSamples = + poOBand->GetXSize() * static_cast(poOBand->GetYSize()); + + if (dfOSamples < dfBestSamples && dfOSamples > nDesiredSamples) + { + dfBestSamples = dfOSamples; + poBestBand = poOBand; + } + } + + return poBestBand; +} + +/************************************************************************/ +/* GetDefaultRAT() */ +/************************************************************************/ + +GDALRasterAttributeTable *GDALThreadSafeRasterBand::GetDefaultRAT() +{ + std::lock_guard oGuard(m_poTSDS->m_oPrototypeDSMutex); + const auto poRAT = m_poPrototypeBand->GetDefaultRAT(); + if (!poRAT) + return nullptr; + + if (dynamic_cast(poRAT)) + return poRAT; + + CPLError(CE_Failure, CPLE_AppDefined, + "GDALThreadSafeRasterBand::GetDefaultRAT() not supporting a " + "non-GDALDefaultRasterAttributeTable implementation"); + return nullptr; +} + +#endif // DOXYGEN_SKIP + +/************************************************************************/ +/* GDALDataset::IsThreadSafe() */ +/************************************************************************/ + +/** Return whether this dataset, and its related objects (typically raster + * bands), can be called for the intended scope. + * + * Note that in the current implementation, nScopeFlags should be set to + * GDAL_OF_RASTER, as thread-safety is limited to read-only operations and + * excludes operations on vector layers (OGRLayer) or multidimensional API + * (GDALGroup, GDALMDArray, etc.) + * + * This is the same as the C function GDALDatasetIsThreadSafe(). + * + * @since 3.10 + */ +bool GDALDataset::IsThreadSafe(int nScopeFlags) const +{ + return (nOpenFlags & GDAL_OF_THREAD_SAFE) != 0 && + nScopeFlags == GDAL_OF_RASTER && (nOpenFlags & GDAL_OF_RASTER) != 0; +} + +/************************************************************************/ +/* GDALDatasetIsThreadSafe() */ +/************************************************************************/ + +/** Return whether this dataset, and its related objects (typically raster + * bands), can be called for the intended scope. + * + * Note that in the current implementation, nScopeFlags should be set to + * GDAL_OF_RASTER, as thread-safety is limited to read-only operations and + * excludes operations on vector layers (OGRLayer) or multidimensional API + * (GDALGroup, GDALMDArray, etc.) + * + * This is the same as the C++ method GDALDataset::IsThreadSafe(). + * + * @param hDS Source dataset + * @param nScopeFlags Intended scope of use. + * Only GDAL_OF_RASTER is supported currently. + * @param papszOptions Options. None currently. + * + * @since 3.10 + */ +bool GDALDatasetIsThreadSafe(GDALDatasetH hDS, int nScopeFlags, + CSLConstList papszOptions) +{ + VALIDATE_POINTER1(hDS, __func__, false); + + CPL_IGNORE_RET_VAL(papszOptions); + + return GDALDataset::FromHandle(hDS)->IsThreadSafe(nScopeFlags); +} + +/************************************************************************/ +/* GDALCreateThreadSafeDataset() */ +/************************************************************************/ + +/** Return a thread-safe dataset. + * + * Ownership of the passed dataset is transferred to the thread-safe dataset. + * + * @param poDS Source dataset + * @param nScopeFlags Intended scope of use. + * Only GDAL_OF_RASTER is supported currently. + * + * @return a new thread-safe dataset, or nullptr in case of error. + * + * @since 3.10 + */ +std::unique_ptr +GDALCreateThreadSafeDataset(std::unique_ptr poDS, int nScopeFlags) +{ + auto poDSRaw = poDS.get(); + return GDALThreadSafeDataset::Create(std::move(poDS), poDSRaw, nScopeFlags); +} + +/************************************************************************/ +/* GDALCreateThreadSafeDataset() */ +/************************************************************************/ + +/** Return a thread-safe dataset. + * + * The life-time of the passed dataset must be longer than the one of + * the returned thread-safe dataset. + * + * Note that this function does increase the reference count on poDS while + * it is being used, so patterns like the following one are valid: + * \code{.cpp} + * auto poDS = GDALDataset::Open(...); + * auto poThreadSafeDS = GDALCreateThreadSafeDataset(poDS, GDAL_OF_RASTER | GDAL_OF_THREAD_SAFE); + * poDS->ReleaseRef(); + * // ... do something with poThreadSafeDS ... + * poThreadSafeDS.reset(); // optional + * \endcode + * + * @param poDS Source dataset + * @param nScopeFlags Intended scope of use. + * Only GDAL_OF_RASTER is supported currently. + * + * @return a new thread-safe dataset, or nullptr in case of error. + + * @since 3.10 + */ +std::unique_ptr GDALCreateThreadSafeDataset(GDALDataset *poDS, + int nScopeFlags) +{ + return GDALThreadSafeDataset::Create(nullptr, poDS, nScopeFlags); +} + +/************************************************************************/ +/* GDALCreateThreadSafeDataset() */ +/************************************************************************/ + +/** Return a thread-safe dataset. + * + * The life-time of the passed dataset must be longer than the one of + * the returned thread-safe dataset. + * + * Note that this function does increase the reference count on hDS while + * it is being used, so patterns like the following one are valid: + * \code{.cpp} + * hDS = GDALOpenEx(...); + * hThreadSafeDS = GDALCreateThreadSafeDataset(hDS, GDAL_OF_RASTER | GDAL_OF_THREAD_SAFE, NULL); + * GDALReleaseDataset(hDS); + * // ... do something with hThreadSafeDS ... + * GDALReleaseDataset(hThreadSafeDS); + * \endcode + * + * @param hDS Source dataset + * @param nScopeFlags Intended scope of use. + * Only GDAL_OF_RASTER is supported currently. + * @param papszOptions Options. None currently. + * + * @since 3.10 + */ +GDALDatasetH GDALCreateThreadSafeDataset(GDALDatasetH hDS, int nScopeFlags, + CSLConstList papszOptions) +{ + VALIDATE_POINTER1(hDS, __func__, nullptr); + + CPL_IGNORE_RET_VAL(papszOptions); + return GDALDataset::ToHandle( + GDALCreateThreadSafeDataset(GDALDataset::FromHandle(hDS), nScopeFlags) + .release()); +} diff --git a/port/cpl_mem_cache.h b/port/cpl_mem_cache.h index 6fecca4b239f..0e021b762748 100644 --- a/port/cpl_mem_cache.h +++ b/port/cpl_mem_cache.h @@ -86,6 +86,10 @@ template struct KeyValuePair KeyValuePair(const K &k, V &&v) : key(k), value(std::move(v)) { } + + private: + KeyValuePair(const KeyValuePair &) = delete; + KeyValuePair &operator=(const KeyValuePair &) = delete; }; /** diff --git a/swig/include/Dataset.i b/swig/include/Dataset.i index 2f3b2e9b6ab7..7f1abec98aff 100644 --- a/swig/include/Dataset.i +++ b/swig/include/Dataset.i @@ -300,6 +300,17 @@ public: return (GDALRasterBandShadow*) GDALGetRasterBand( self, nBand ); } + bool IsThreadSafe(int nScopeFlags) + { + return GDALDatasetIsThreadSafe(self, nScopeFlags, nullptr); + } + +%newobject CreateThreadSafeDataset; + GDALDatasetShadow* CreateThreadSafeDataset(int nScopeFlags) + { + return GDALCreateThreadSafeDataset(self, nScopeFlags, nullptr); + } + %newobject GetRootGroup; GDALGroupHS* GetRootGroup() { return GDALDatasetGetRootGroup(self); diff --git a/swig/include/gdalconst.i b/swig/include/gdalconst.i index 426b2b7b76e9..fb7da99a25b0 100644 --- a/swig/include/gdalconst.i +++ b/swig/include/gdalconst.i @@ -172,6 +172,7 @@ %constant OF_UPDATE = GDAL_OF_UPDATE; %constant OF_SHARED = GDAL_OF_SHARED; %constant OF_VERBOSE_ERROR = GDAL_OF_VERBOSE_ERROR; +%constant OF_THREAD_SAFE = GDAL_OF_THREAD_SAFE; #if !defined(SWIGCSHARP) && !defined(SWIGJAVA) diff --git a/swig/include/python/gdal_python.i b/swig/include/python/gdal_python.i index f6fc9feeea18..d96b952b7549 100644 --- a/swig/include/python/gdal_python.i +++ b/swig/include/python/gdal_python.i @@ -1614,6 +1614,16 @@ CPLErr ReadRaster1( double xoff, double yoff, double xsize, double ysize, return get(value) %} +%feature("pythonappend") CreateThreadSafeDataset %{ + if val: + val._parent_ds = self + + import weakref + if not hasattr(self, '_child_references'): + self._child_references = weakref.WeakSet() + self._child_references.add(val) +%} + %feature("pythonprepend") Close %{ self._invalidate_children() %}