diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 673bb29590..0ad1f5eb73 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -61,6 +61,7 @@ cmake cmdlet cmdlets cmp +CMSG CNG cnt codepage @@ -187,6 +188,9 @@ getline github githubusercontent globals +HCCE +hcertstore +HCRYPTMSG hfile HGLOBAL HIDECANCEL @@ -207,6 +211,7 @@ hstring html http https +HWND Hyperlink IActivation IApplication @@ -280,6 +285,7 @@ logsql logto LONGLONG LPCGUID +LPCSTR LPVOID mailto MAJORVERSION @@ -352,6 +358,7 @@ nuspec OAuth ODR ofstream +Oid opencode opensource openxmlformats @@ -383,6 +390,7 @@ PGP php PII pipssource +PKCS placeholders png posix @@ -494,6 +502,7 @@ src srwlock sscanf sstream +STATEACTION STATFLAG STATSTG STDAPI @@ -518,6 +527,7 @@ STRINGID STRINGIFY STRINGIZE stringstream +STRSAFE strstr strtoll subcontext @@ -630,6 +640,7 @@ wcout wcsicmp webpage wekyb +WHOLECHAIN wil wildcards WINAPI @@ -641,6 +652,7 @@ winmeta winres winrt winsqlite +WINTRUST wix wmain WNS @@ -655,6 +667,7 @@ wrl WStr wstring wstringstream +WTD www xamarin xaml diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt index 37d743d84a..cb15a52032 100644 --- a/.github/actions/spelling/expect.txt +++ b/.github/actions/spelling/expect.txt @@ -120,6 +120,7 @@ EXEHASH experimentalfeatures fcb fd +fdw fedorapeople fileinuse fintimes @@ -281,7 +282,9 @@ parametermap pathparts pathpaths Patil +pb PCs +pcwsz pdp PEGI pfn diff --git a/src/AppInstallerCLICore/Workflows/MSStoreInstallerHandler.cpp b/src/AppInstallerCLICore/Workflows/MSStoreInstallerHandler.cpp index 99d11224eb..33f067a209 100644 --- a/src/AppInstallerCLICore/Workflows/MSStoreInstallerHandler.cpp +++ b/src/AppInstallerCLICore/Workflows/MSStoreInstallerHandler.cpp @@ -88,6 +88,7 @@ namespace AppInstaller::CLI::Workflow if (result.Status() == GetEntitlementStatus::Succeeded) { context.Reporter.Info() << Resource::String::MSStoreInstallGetEntitlementSuccess << std::endl; + AICLI_LOG(CLI, Info, << "Get entitlement succeeded."); } else if (result.Status() == GetEntitlementStatus::NetworkError) { diff --git a/src/AppInstallerCLIPackage/AppInstallerCLIPackage.wapproj b/src/AppInstallerCLIPackage/AppInstallerCLIPackage.wapproj index 0c49ab12e0..d659efee6b 100644 --- a/src/AppInstallerCLIPackage/AppInstallerCLIPackage.wapproj +++ b/src/AppInstallerCLIPackage/AppInstallerCLIPackage.wapproj @@ -71,7 +71,9 @@ - + + Designer + diff --git a/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj b/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj index 407bcc0e32..04a985704d 100644 --- a/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj +++ b/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj @@ -320,9 +320,15 @@ true + + true + true + + true + true diff --git a/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj.filters b/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj.filters index 6d2df2ae11..092f901c9d 100644 --- a/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj.filters +++ b/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj.filters @@ -408,9 +408,15 @@ TestData + + TestData + TestData + + TestData + TestData diff --git a/src/AppInstallerCLITests/MsixInfo.cpp b/src/AppInstallerCLITests/MsixInfo.cpp index 84732d34a7..7bcc3dca6d 100644 --- a/src/AppInstallerCLITests/MsixInfo.cpp +++ b/src/AppInstallerCLITests/MsixInfo.cpp @@ -3,6 +3,8 @@ #include "pch.h" #include "TestCommon.h" #include +#include +#include using namespace std::string_literals; using namespace std::string_view_literals; @@ -11,11 +13,12 @@ using namespace AppInstaller; constexpr std::string_view s_MsixFile_1 = "index.1.0.0.0.msix"; constexpr std::string_view s_MsixFile_2 = "index.2.0.0.0.msix"; +constexpr std::string_view s_MsixFileSigned_1 = "index.1.0.0.0.signed.msix"; TEST_CASE("MsixInfo_GetPackageFamilyName", "[msixinfo]") { TestDataFile index(s_MsixFile_1); - Msix::MsixInfo msix(index.GetPath().u8string()); + Msix::MsixInfo msix(index.GetPath()); std::string expectedFullName = "AppInstallerCLITestsFakeIndex_1.0.0.0_neutral__125rzkzqaqjwj"; std::string actualFullName = msix.GetPackageFullName(); @@ -23,39 +26,27 @@ TEST_CASE("MsixInfo_GetPackageFamilyName", "[msixinfo]") REQUIRE(expectedFullName == actualFullName); } -TEST_CASE("MsixInfo_WriteManifestAndCompareToSelf", "[msixinfo]") +TEST_CASE("MsixInfo_CompareToSelf", "[msixinfo]") { TestDataFile index(s_MsixFile_1); - Msix::MsixInfo msix(index.GetPath().u8string()); + Msix::MsixInfo msix(index.GetPath()); - TempFile manifest{ "msixtest_manifest"s, ".xml"s }; - ProgressCallback callback; - - msix.WriteManifestToFile(manifest, callback); - - REQUIRE(!msix.IsNewerThan(manifest)); + REQUIRE(!msix.IsNewerThan(index.GetPath().u8string())); } -TEST_CASE("MsixInfo_WriteManifestAndCompareToOlder", "[msixinfo]") +TEST_CASE("MsixInfo_CompareToOlder", "[msixinfo]") { TestDataFile index1(s_MsixFile_1); - Msix::MsixInfo msix1(index1.GetPath().u8string()); - - TempFile manifest{ "msixtest_manifest"s, ".xml"s }; - ProgressCallback callback; - - msix1.WriteManifestToFile(manifest, callback); - TestDataFile index2(s_MsixFile_2); - Msix::MsixInfo msix2(index2.GetPath().u8string()); + Msix::MsixInfo msix2(index2.GetPath()); - REQUIRE(msix2.IsNewerThan(manifest)); + REQUIRE(msix2.IsNewerThan(index1)); } TEST_CASE("MsixInfo_WriteFile", "[msixinfo]") { TestDataFile index(s_MsixFile_1); - Msix::MsixInfo msix(index.GetPath().u8string()); + Msix::MsixInfo msix(index.GetPath()); TempFile file{ "msixtest_file"s, ".bin"s }; ProgressCallback callback; @@ -64,3 +55,42 @@ TEST_CASE("MsixInfo_WriteFile", "[msixinfo]") REQUIRE(1 == std::filesystem::file_size(file)); } + +TEST_CASE("MsixInfo_ValidateMsixTrustInfo", "[msixinfo]") +{ + if (!Runtime::IsRunningAsAdmin()) + { + WARN("Test requires admin privilege. Skipped."); + return; + } + + TestDataFile notSigned{ s_MsixFile_1 }; + Msix::WriteLockedMsixFile notSignedWriteLocked{ notSigned }; + REQUIRE_FALSE(notSignedWriteLocked.ValidateTrustInfo(false)); + + TestDataFile testSigned{ s_MsixFileSigned_1 }; + Msix::WriteLockedMsixFile testSignedWriteLocked{ testSigned }; + + // Remove the cert if already trusted + bool certExistsBeforeTest = UninstallCertFromSignedPackage(testSigned); + + REQUIRE_FALSE(testSignedWriteLocked.ValidateTrustInfo(false)); + + // Add the cert to trusted + InstallCertFromSignedPackage(testSigned); + + REQUIRE(testSignedWriteLocked.ValidateTrustInfo(false)); + REQUIRE_FALSE(testSignedWriteLocked.ValidateTrustInfo(true)); + + TestCommon::TempFile microsoftSigned{ "testIndex"s, ".msix"s }; + ProgressCallback callback; + Utility::Download("https://cdn.winget.microsoft.com/cache/source.msix", microsoftSigned.GetPath(), Utility::DownloadType::Index, callback); + + Msix::WriteLockedMsixFile microsoftSignedWriteLocked{ microsoftSigned }; + REQUIRE(microsoftSignedWriteLocked.ValidateTrustInfo(true)); + + if (!certExistsBeforeTest) + { + UninstallCertFromSignedPackage(testSigned); + } +} \ No newline at end of file diff --git a/src/AppInstallerCLITests/PreIndexedPackageSource.cpp b/src/AppInstallerCLITests/PreIndexedPackageSource.cpp index 591fe99152..cdffb2945e 100644 --- a/src/AppInstallerCLITests/PreIndexedPackageSource.cpp +++ b/src/AppInstallerCLITests/PreIndexedPackageSource.cpp @@ -23,12 +23,10 @@ namespace fs = std::filesystem; constexpr std::string_view s_RepositorySettings_UserSources = "usersources"sv; -constexpr std::string_view s_MsixFile_1 = "index.1.0.0.0.msix"; -constexpr std::string_view s_MsixFile_2 = "index.2.0.0.0.msix"; -constexpr std::string_view s_Msix_FamilyName = "AppInstallerCLITestsFakeIndex_125rzkzqaqjwj"; -constexpr std::string_view s_AppxManifestFileName = "AppxManifest.xml"sv; +constexpr std::string_view s_MsixFile_1 = "index.1.0.0.0.signed.msix"; +constexpr std::string_view s_MsixFile_2 = "index.2.0.0.0.signed.msix"; +constexpr std::string_view s_Msix_FamilyName = "AppInstallerCLITestsFakeIndex_8wekyb3d8bbwe"; constexpr std::string_view s_IndexMsixName = "source.msix"sv; -constexpr std::string_view s_IndexFileName = "index.db"sv; void CopyIndexFileToDirectory(const fs::path& from, const fs::path& to) { @@ -65,12 +63,20 @@ void CleanSources() TEST_CASE("PIPS_Add", "[pips]") { + if (!Runtime::IsRunningAsAdmin()) + { + WARN("Test requires admin privilege. Skipped."); + return; + } + CleanSources(); TempDirectory dir("pipssource"); TestDataFile index(s_MsixFile_1); CopyIndexFileToDirectory(index, dir); + bool shouldCleanCert = InstallCertFromSignedPackage(index); + SourceDetails details; details.Name = "TestName"; details.Type = AppInstaller::Repository::Microsoft::PreIndexedPackageSourceFactory::Type(); @@ -82,25 +88,33 @@ TEST_CASE("PIPS_Add", "[pips]") fs::path state = GetPathToFileDir(); REQUIRE(fs::exists(state)); - fs::path manifest = state; - manifest /= s_AppxManifestFileName; - REQUIRE(fs::exists(manifest)); - REQUIRE(fs::file_size(manifest) > 0); + fs::path indexMsix = state; + indexMsix /= s_IndexMsixName; + REQUIRE(fs::exists(indexMsix)); + REQUIRE(fs::file_size(indexMsix) > 0); - fs::path indexFile = state; - indexFile /= s_IndexFileName; - REQUIRE(fs::exists(indexFile)); - REQUIRE(fs::file_size(indexFile) > 0); + if (shouldCleanCert) + { + UninstallCertFromSignedPackage(index); + } } TEST_CASE("PIPS_UpdateSameVersion", "[pips]") { + if (!Runtime::IsRunningAsAdmin()) + { + WARN("Test requires admin privilege. Skipped."); + return; + } + CleanSources(); TempDirectory dir("pipssource"); TestDataFile index(s_MsixFile_1); CopyIndexFileToDirectory(index, dir); + bool shouldCleanCert = InstallCertFromSignedPackage(index); + SourceDetails details; details.Name = "TestName"; details.Type = AppInstaller::Repository::Microsoft::PreIndexedPackageSourceFactory::Type(); @@ -117,16 +131,29 @@ TEST_CASE("PIPS_UpdateSameVersion", "[pips]") UpdateSource(details.Name, callback); REQUIRE(!progressCalled); + + if (shouldCleanCert) + { + UninstallCertFromSignedPackage(index); + } } TEST_CASE("PIPS_UpdateNewVersion", "[pips]") { + if (!Runtime::IsRunningAsAdmin()) + { + WARN("Test requires admin privilege. Skipped."); + return; + } + CleanSources(); TempDirectory dir("pipssource"); TestDataFile indexMsix1(s_MsixFile_1); CopyIndexFileToDirectory(indexMsix1, dir); + bool shouldCleanCert = InstallCertFromSignedPackage(indexMsix1); + SourceDetails details; details.Name = "TestName"; details.Type = AppInstaller::Repository::Microsoft::PreIndexedPackageSourceFactory::Type(); @@ -138,13 +165,9 @@ TEST_CASE("PIPS_UpdateNewVersion", "[pips]") fs::path state = GetPathToFileDir(); REQUIRE(fs::exists(state)); - fs::path manifestPath = state; - manifestPath /= s_AppxManifestFileName; - std::string manifestContents1 = GetContents(manifestPath); - - fs::path indexPath = state; - indexPath /= s_IndexFileName; - std::string indexContents1 = GetContents(indexPath); + fs::path indexMsix = state; + indexMsix /= s_IndexMsixName; + std::string indexContents1 = GetContents(indexMsix); TestDataFile indexMsix2(s_MsixFile_2); CopyIndexFileToDirectory(indexMsix2, dir); @@ -155,21 +178,31 @@ TEST_CASE("PIPS_UpdateNewVersion", "[pips]") UpdateSource(details.Name, callback); REQUIRE(progressCalled); - std::string manifestContents2 = GetContents(manifestPath); - REQUIRE(manifestContents1 != manifestContents2); - - std::string indexContents2 = GetContents(indexPath); + std::string indexContents2 = GetContents(indexMsix); REQUIRE(indexContents1 != indexContents2); + + if (shouldCleanCert) + { + UninstallCertFromSignedPackage(indexMsix1); + } } TEST_CASE("PIPS_Remove", "[pips]") { + if (!Runtime::IsRunningAsAdmin()) + { + WARN("Test requires admin privilege. Skipped."); + return; + } + CleanSources(); TempDirectory dir("pipssource"); TestDataFile index(s_MsixFile_1); CopyIndexFileToDirectory(index, dir); + bool shouldCleanCert = InstallCertFromSignedPackage(index); + SourceDetails details; details.Name = "TestName"; details.Type = AppInstaller::Repository::Microsoft::PreIndexedPackageSourceFactory::Type(); @@ -181,14 +214,15 @@ TEST_CASE("PIPS_Remove", "[pips]") fs::path state = GetPathToFileDir(); REQUIRE(fs::exists(state)); - fs::path manifest = state; - manifest /= s_AppxManifestFileName; - REQUIRE(fs::exists(manifest)); - - fs::path indexFile = state; - indexFile /= s_IndexFileName; - REQUIRE(fs::exists(indexFile)); + fs::path indexMsix = state; + indexMsix /= s_IndexMsixName; + REQUIRE(fs::exists(indexMsix)); RemoveSource(details.Name, callback); REQUIRE(!fs::exists(state)); + + if (shouldCleanCert) + { + UninstallCertFromSignedPackage(index); + } } diff --git a/src/AppInstallerCLITests/TestCommon.cpp b/src/AppInstallerCLITests/TestCommon.cpp index 67c0b1e074..b060300047 100644 --- a/src/AppInstallerCLITests/TestCommon.cpp +++ b/src/AppInstallerCLITests/TestCommon.cpp @@ -3,8 +3,9 @@ #include "pch.h" #include "TestCommon.h" #include "TestHooks.h" -#include "winget/GroupPolicy.h" -#include "winget/UserSettings.h" +#include +#include +#include namespace TestCommon { @@ -214,4 +215,74 @@ namespace TestCommon { AppInstaller::Settings::SetUserSettingsOverride(nullptr); } + + bool InstallCertFromSignedPackage(const std::filesystem::path& package) + { + auto [certContext, certStore] = AppInstaller::Msix::GetCertContextFromMsix(package); + + wil::unique_hcertstore trustedPeopleStore; + trustedPeopleStore.reset(CertOpenStore( + CERT_STORE_PROV_SYSTEM_W, + PKCS_7_ASN_ENCODING | X509_ASN_ENCODING, + NULL, + CERT_SYSTEM_STORE_LOCAL_MACHINE, + L"TrustedPeople")); + THROW_LAST_ERROR_IF(!trustedPeopleStore.get()); + + wil::unique_cert_context existingCert; + existingCert.reset(CertFindCertificateInStore( + trustedPeopleStore.get(), + PKCS_7_ASN_ENCODING | X509_ASN_ENCODING, + 0, + CERT_FIND_EXISTING, + certContext.get(), + nullptr)); + + // Add if it does not already exist in the store + if (!existingCert.get()) + { + THROW_LAST_ERROR_IF(!CertAddCertificateContextToStore( + trustedPeopleStore.get(), + certContext.get(), + CERT_STORE_ADD_NEW, + nullptr)); + + return true; + } + + return false; + } + + bool UninstallCertFromSignedPackage(const std::filesystem::path& package) + { + auto [certContext, certStore] = AppInstaller::Msix::GetCertContextFromMsix(package); + + wil::unique_hcertstore trustedPeopleStore; + trustedPeopleStore.reset(CertOpenStore( + CERT_STORE_PROV_SYSTEM_W, + PKCS_7_ASN_ENCODING | X509_ASN_ENCODING, + NULL, + CERT_SYSTEM_STORE_LOCAL_MACHINE, + L"TrustedPeople")); + THROW_LAST_ERROR_IF(!trustedPeopleStore.get()); + + wil::unique_cert_context existingCert; + existingCert.reset(CertFindCertificateInStore( + trustedPeopleStore.get(), + PKCS_7_ASN_ENCODING | X509_ASN_ENCODING, + 0, + CERT_FIND_EXISTING, + certContext.get(), + nullptr)); + + // Remove if it exists in the store + if (existingCert.get()) + { + THROW_LAST_ERROR_IF(!CertDeleteCertificateFromStore(existingCert.get())); + + return true; + } + + return false; + } } diff --git a/src/AppInstallerCLITests/TestCommon.h b/src/AppInstallerCLITests/TestCommon.h index d714073f49..41925af841 100644 --- a/src/AppInstallerCLITests/TestCommon.h +++ b/src/AppInstallerCLITests/TestCommon.h @@ -130,4 +130,9 @@ namespace TestCommon m_settings[S].emplace(std::move(value)); } }; + + // Below cert installation/uninstallation methods require admin privilege, + // tests calling these functions should skip when not running with admin. + bool InstallCertFromSignedPackage(const std::filesystem::path& package); + bool UninstallCertFromSignedPackage(const std::filesystem::path& package); } diff --git a/src/AppInstallerCLITests/TestData/index.1.0.0.0.signed.msix b/src/AppInstallerCLITests/TestData/index.1.0.0.0.signed.msix new file mode 100644 index 0000000000..0e97a4b40a Binary files /dev/null and b/src/AppInstallerCLITests/TestData/index.1.0.0.0.signed.msix differ diff --git a/src/AppInstallerCLITests/TestData/index.2.0.0.0.signed.msix b/src/AppInstallerCLITests/TestData/index.2.0.0.0.signed.msix new file mode 100644 index 0000000000..4aa495cfeb Binary files /dev/null and b/src/AppInstallerCLITests/TestData/index.2.0.0.0.signed.msix differ diff --git a/src/AppInstallerCommonCore/AppInstallerCommonCore.vcxproj b/src/AppInstallerCommonCore/AppInstallerCommonCore.vcxproj index 043e8091f4..cc99c18a24 100644 --- a/src/AppInstallerCommonCore/AppInstallerCommonCore.vcxproj +++ b/src/AppInstallerCommonCore/AppInstallerCommonCore.vcxproj @@ -308,6 +308,7 @@ + @@ -367,6 +368,7 @@ + diff --git a/src/AppInstallerCommonCore/AppInstallerCommonCore.vcxproj.filters b/src/AppInstallerCommonCore/AppInstallerCommonCore.vcxproj.filters index 86d71438cc..df7c55bd74 100644 --- a/src/AppInstallerCommonCore/AppInstallerCommonCore.vcxproj.filters +++ b/src/AppInstallerCommonCore/AppInstallerCommonCore.vcxproj.filters @@ -195,6 +195,9 @@ Public\winget + + Public\winget + @@ -338,6 +341,9 @@ Source Files + + Source Files + diff --git a/src/AppInstallerCommonCore/ManagedFile.cpp b/src/AppInstallerCommonCore/ManagedFile.cpp new file mode 100644 index 0000000000..9d0ad500c8 --- /dev/null +++ b/src/AppInstallerCommonCore/ManagedFile.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include "pch.h" +#include "winget/ManagedFile.h" +#include "AppInstallerLogging.h" + +namespace AppInstaller::Utility +{ + ManagedFile ManagedFile::CreateWriteLockedFile(const std::filesystem::path& path, DWORD desiredAccess, bool deleteOnExit) + { + ManagedFile file; + file.m_fileHandle.reset(CreateFileW(path.c_str(), desiredAccess, FILE_SHARE_READ, nullptr, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, nullptr)); + THROW_LAST_ERROR_IF(!file.m_fileHandle); + file.m_filePath = path; + file.m_deleteFileOnExit = deleteOnExit; + + return file; + } + + ManagedFile ManagedFile::OpenWriteLockedFile(const std::filesystem::path& path, DWORD desiredAccess) + { + ManagedFile file; + file.m_fileHandle.reset(CreateFileW(path.c_str(), desiredAccess, FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr)); + THROW_LAST_ERROR_IF(!file.m_fileHandle); + file.m_filePath = path; + + return file; + } + + ManagedFile::~ManagedFile() + { + if (m_deleteFileOnExit) + { + if (m_fileHandle) + { + m_fileHandle.reset(); + } + + try + { + std::filesystem::remove(m_filePath); + } + catch (...) + { + AICLI_LOG(Core, Info, << "Failed to remove managed file at: " << m_filePath); + } + } + } +} \ No newline at end of file diff --git a/src/AppInstallerCommonCore/MsixInfo.cpp b/src/AppInstallerCommonCore/MsixInfo.cpp index 0f99ae75dc..8dc7b93bf3 100644 --- a/src/AppInstallerCommonCore/MsixInfo.cpp +++ b/src/AppInstallerCommonCore/MsixInfo.cpp @@ -16,6 +16,10 @@ namespace AppInstaller::Msix { namespace { + // MSIX-specific header placed in the P7X file, before the actual signature + const byte P7xFileId[] = { 0x50, 0x4b, 0x43, 0x58 }; + const DWORD P7xFileIdSize = sizeof(P7xFileId); + // Gets the version from the manifest reader. UINT64 GetVersionFromManifestReader(IAppxManifestReader* reader) { @@ -63,6 +67,8 @@ namespace AppInstaller::Msix // If we got bytes, just accept them and keep going. LOG_IF_FAILED(hr); + THROW_HR_IF_MSG(E_UNEXPECTED, expectedSize && totalBytesRead + bytesRead > expectedSize, "Read more bytes than expected size"); + file.write(buffer.get(), bytesRead); totalBytesRead += bytesRead; progress.OnProgress(totalBytesRead, expectedSize, ProgressType::Bytes); @@ -117,6 +123,161 @@ namespace AppInstaller::Msix WriteStreamToFile(stream.Get(), size, target, progress); } + + // Writes the stream (from current location) to the given file handle. + void WriteStreamToFileHandle(IStream* stream, UINT64 expectedSize, HANDLE target, IProgressCallback& progress) + { + constexpr ULONG bufferSize = 1 << 20; + std::unique_ptr buffer = std::make_unique(bufferSize); + + UINT64 totalBytesRead = 0; + + while (!progress.IsCancelled()) + { + ULONG bytesRead = 0; + HRESULT hr = stream->Read(buffer.get(), bufferSize, &bytesRead); + + if (bytesRead) + { + // If we got bytes, just accept them and keep going. + LOG_IF_FAILED(hr); + + THROW_HR_IF_MSG(E_UNEXPECTED, expectedSize && totalBytesRead + bytesRead > expectedSize, "Read more bytes than expected size"); + + DWORD bytesWritten = 0; + THROW_LAST_ERROR_IF(!WriteFile(target, buffer.get(), bytesRead, &bytesWritten, nullptr)); + THROW_HR_IF(E_UNEXPECTED, bytesRead != bytesWritten); + totalBytesRead += bytesRead; + progress.OnProgress(totalBytesRead, expectedSize, ProgressType::Bytes); + } + else + { + // If given a size, and we have read it all, quit + if (expectedSize && totalBytesRead == expectedSize) + { + break; + } + + // If the stream returned an error, throw it + THROW_IF_FAILED(hr); + + // If we were given a size and didn't reach it, throw our own error; + // otherwise assume that this is just normal EOF. + if (expectedSize) + { + THROW_WIN32(ERROR_HANDLE_EOF); + } + else + { + break; + } + } + } + } + + // Writes the appx file to the given file handle. + void WriteAppxFileToFileHandle(IAppxFile* appxFile, HANDLE target, IProgressCallback& progress) + { + UINT64 size = 0; + THROW_IF_FAILED(appxFile->GetSize(&size)); + + ComPtr stream; + THROW_IF_FAILED(appxFile->GetStream(&stream)); + + WriteStreamToFileHandle(stream.Get(), size, target, progress); + } + + bool ValidateMsixTrustInfo(const std::filesystem::path& msixPath, bool verifyMicrosoftOrigin) + { + bool result = false; + AICLI_LOG(Core, Info, << "Started trust validation of msix at: " << msixPath); + + try + { + bool verifyChainResult = false; + + // First verify certificate chain if requested. + if (verifyMicrosoftOrigin) + { + auto [certContext, certStore] = GetCertContextFromMsix(msixPath); + + // Get certificate chain context for validation + CERT_CHAIN_PARA certChainParameters = { 0 }; + certChainParameters.cbSize = sizeof(CERT_CHAIN_PARA); + certChainParameters.RequestedUsage.dwType = USAGE_MATCH_TYPE_AND; + DWORD certChainFlags = CERT_CHAIN_CACHE_ONLY_URL_RETRIEVAL; + + wil::unique_cert_chain_context certChainContext; + THROW_LAST_ERROR_IF(!CertGetCertificateChain( + HCCE_LOCAL_MACHINE, + certContext.get(), + NULL, // Use the current system time for CRL validation + certStore.get(), + &certChainParameters, + certChainFlags, + NULL, // Reserved parameter; must be NULL + &certChainContext)); + + // Validate that the certificate chain is rooted in one of the well-known Microsoft root certs + CERT_CHAIN_POLICY_PARA policyParameters = { 0 }; + policyParameters.cbSize = sizeof(CERT_CHAIN_POLICY_PARA); + policyParameters.dwFlags = MICROSOFT_ROOT_CERT_CHAIN_POLICY_CHECK_APPLICATION_ROOT_FLAG; + CERT_CHAIN_POLICY_STATUS policyStatus = { 0 }; + policyStatus.cbSize = sizeof(CERT_CHAIN_POLICY_STATUS); + LPCSTR policyOid = CERT_CHAIN_POLICY_MICROSOFT_ROOT; + BOOL certChainVerifySucceeded = CertVerifyCertificateChainPolicy( + policyOid, + certChainContext.get(), + &policyParameters, + &policyStatus); + + AICLI_LOG(Core, Info, << "Result for certificate chain validation of Microsoft origin: " << policyStatus.dwError); + + verifyChainResult = certChainVerifySucceeded && policyStatus.dwError == ERROR_SUCCESS; + } + else + { + verifyChainResult = true; + } + + // If certificate chain origin validation is success or not requested, then validate the trust info of the file. + if (verifyChainResult) + { + // Set up the structures needed for the WinVerifyTrust call + WINTRUST_FILE_INFO fileInfo = { 0 }; + fileInfo.cbStruct = sizeof(WINTRUST_FILE_INFO); + fileInfo.pcwszFilePath = msixPath.c_str(); + + WINTRUST_DATA trustData = { 0 }; + trustData.cbStruct = sizeof(WINTRUST_DATA); + trustData.dwUIChoice = WTD_UI_NONE; + trustData.fdwRevocationChecks = WTD_REVOKE_WHOLECHAIN; + trustData.dwUnionChoice = WTD_CHOICE_FILE; + trustData.dwStateAction = WTD_STATEACTION_VERIFY; + trustData.dwProvFlags = WTD_CACHE_ONLY_URL_RETRIEVAL; + trustData.pFile = &fileInfo; + + GUID verifyActionId = WINTRUST_ACTION_GENERIC_VERIFY_V2; + + HRESULT verifyTrustResult = static_cast(WinVerifyTrust(static_cast(INVALID_HANDLE_VALUE), &verifyActionId, &trustData)); + AICLI_LOG(Core, Info, << "Result for trust info validation of the msix: " << verifyTrustResult); + + result = verifyTrustResult == S_OK; + } + } + catch (const wil::ResultException& re) + { + AICLI_LOG(Core, Error, << "Failed during msix trust validation. Error: " << re.GetErrorCode()); + result = false; + } + catch (...) + { + AICLI_LOG(Core, Error, << "Failed during msix trust validation."); + result = false; + } + + return result; + } } bool GetBundleReader( @@ -278,6 +439,69 @@ namespace AppInstaller::Msix return { result }; } + GetCertContextResult GetCertContextFromMsix(const std::filesystem::path& msixPath) + { + // Retrieve raw signature from msix + MsixInfo msixInfo{ msixPath }; + auto signature = msixInfo.GetSignature(true); + + // Get the cert content + wil::unique_any signedMessage; + wil::unique_hcertstore certStore; + CRYPT_DATA_BLOB signatureBlob = { 0 }; + signatureBlob.cbData = static_cast(signature.size()); + signatureBlob.pbData = signature.data(); + THROW_LAST_ERROR_IF(!CryptQueryObject( + CERT_QUERY_OBJECT_BLOB, + &signatureBlob, + CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED, + CERT_QUERY_FORMAT_FLAG_BINARY, + 0, // Reserved parameter + NULL, // No encoding info needed + NULL, + NULL, + &certStore, + &signedMessage, + NULL)); + + // Get the signer size and information from the signed data message + // The properties of the signer info will be used to uniquely identify the signing certificate in the certificate store + DWORD signerInfoSize = 0; + THROW_LAST_ERROR_IF(!CryptMsgGetParam( + signedMessage.get(), + CMSG_SIGNER_INFO_PARAM, + 0, + NULL, + &signerInfoSize)); + + // Check that the signer info size is within reasonable bounds; under the max length of a string for the issuer field + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_DATA), !(signerInfoSize > 0 && signerInfoSize < STRSAFE_MAX_CCH)); + + std::vector signerInfoBuffer; + signerInfoBuffer.resize(signerInfoSize); + THROW_LAST_ERROR_IF(!CryptMsgGetParam( + signedMessage.get(), + CMSG_SIGNER_INFO_PARAM, + 0, + signerInfoBuffer.data(), + &signerInfoSize)); + + // Get the signing certificate from the certificate store based on the issuer and serial number of the signer info + CMSG_SIGNER_INFO* signerInfo = reinterpret_cast(signerInfoBuffer.data()); + CERT_INFO certInfo; + certInfo.Issuer = signerInfo->Issuer; + certInfo.SerialNumber = signerInfo->SerialNumber; + + wil::unique_cert_context certContext; + certContext.reset(CertGetSubjectCertificateFromStore( + certStore.get(), + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + &certInfo)); + THROW_LAST_ERROR_IF(!certContext.get()); + + return { std::move(certContext), std::move(certStore) }; + } + MsixInfo::MsixInfo(std::string_view uriStr) { if (Utility::IsUrlRemote(uriStr)) @@ -313,7 +537,7 @@ namespace AppInstaller::Msix } } - std::vector MsixInfo::GetSignature() + std::vector MsixInfo::GetSignature(bool skipP7xFileId) { ComPtr signatureFile; if (m_isBundle) @@ -335,6 +559,18 @@ namespace AppInstaller::Msix THROW_IF_FAILED(signatureStream->Stat(&stat, STATFLAG_NONAME)); THROW_HR_IF(E_UNEXPECTED, stat.cbSize.HighPart != 0); // Signature size should be small signatureSize = stat.cbSize.LowPart; + THROW_HR_IF(E_UNEXPECTED, signatureSize <= P7xFileIdSize); + + if (skipP7xFileId) + { + // Validate msix signature header + byte headerBuffer[P7xFileIdSize]; + DWORD headerRead; + THROW_IF_FAILED(signatureStream->Read(headerBuffer, P7xFileIdSize, &headerRead)); + THROW_HR_IF_MSG(E_UNEXPECTED, headerRead != P7xFileIdSize, "Failed to read signature header"); + THROW_HR_IF_MSG(E_UNEXPECTED, !std::equal(P7xFileId, P7xFileId + P7xFileIdSize, headerBuffer), "Unexpected msix signature header"); + signatureSize -= P7xFileIdSize; + } signatureContent.resize(signatureSize); @@ -372,16 +608,16 @@ namespace AppInstaller::Msix return Utility::ConvertToUTF8(GetPackageFullNameWide()); } - bool MsixInfo::IsNewerThan(const std::filesystem::path& otherManifest) + bool MsixInfo::IsNewerThan(const std::filesystem::path& otherPackage) { THROW_HR_IF(E_NOT_VALID_STATE, m_isBundle); - ComPtr otherStream; - THROW_IF_FAILED(SHCreateStreamOnFileEx(otherManifest.c_str(), - STGM_READ | STGM_SHARE_DENY_WRITE | STGM_FAILIFTHERE, 0, FALSE, nullptr, &otherStream)); + MsixInfo other{ otherPackage }; + + THROW_HR_IF(E_INVALIDARG, other.m_isBundle); ComPtr otherReader; - GetManifestReader(otherStream.Get(), &otherReader); + THROW_IF_FAILED(other.m_packageReader->GetManifest(&otherReader)); ComPtr manifestReader; THROW_IF_FAILED(m_packageReader->GetManifest(&manifestReader)); @@ -430,4 +666,31 @@ namespace AppInstaller::Msix WriteAppxFileToFile(appxFile.Get(), target, progress); } + + void MsixInfo::WriteToFileHandle(std::string_view packageFile, HANDLE target, IProgressCallback& progress) + { + std::wstring fileUTF16 = Utility::ConvertToUTF16(packageFile); + + ComPtr appxFile; + if (m_isBundle) + { + THROW_IF_FAILED(m_bundleReader->GetPayloadPackage(fileUTF16.c_str(), &appxFile)); + } + else + { + THROW_IF_FAILED(m_packageReader->GetPayloadFile(fileUTF16.c_str(), &appxFile)); + } + + WriteAppxFileToFileHandle(appxFile.Get(), target, progress); + } + + WriteLockedMsixFile::WriteLockedMsixFile(const std::filesystem::path& path) + { + m_file = Utility::ManagedFile::OpenWriteLockedFile(path, 0); + } + + bool WriteLockedMsixFile::ValidateTrustInfo(bool checkMicrosoftOrigin) const + { + return ValidateMsixTrustInfo(m_file.GetFilePath(), checkMicrosoftOrigin); + } } diff --git a/src/AppInstallerCommonCore/Public/AppInstallerMsixInfo.h b/src/AppInstallerCommonCore/Public/AppInstallerMsixInfo.h index dd95b7e21e..2118ba37b5 100644 --- a/src/AppInstallerCommonCore/Public/AppInstallerMsixInfo.h +++ b/src/AppInstallerCommonCore/Public/AppInstallerMsixInfo.h @@ -1,7 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #pragma once -#include +#include "AppInstallerProgress.h" +#include "winget/ManagedFile.h" + #include #include @@ -47,6 +49,9 @@ namespace AppInstaller::Msix { MsixInfo(std::string_view uriStr); + template, int> = 0> + MsixInfo(const T& path) : MsixInfo(path.u8string()) {} + MsixInfo(const MsixInfo&) = default; MsixInfo& operator=(const MsixInfo&) = default; @@ -59,14 +64,15 @@ namespace AppInstaller::Msix } // Full content of AppxSignature.p7x - std::vector GetSignature(); + // If skipP7xFileId is true, returns content of converted .p7s + std::vector GetSignature(bool skipP7xFileId = false); // Gets the package full name. std::wstring GetPackageFullNameWide(); std::string GetPackageFullName(); - // Gets a value indicating whether the referenced info is newer than the given manifest. - bool IsNewerThan(const std::filesystem::path& otherManifest); + // Gets a value indicating whether the referenced info is newer than the given package. + bool IsNewerThan(const std::filesystem::path& otherPackage); bool IsNewerThan(const winrt::Windows::ApplicationModel::PackageVersion& otherVersion); @@ -76,10 +82,32 @@ namespace AppInstaller::Msix // Writes the package's manifest to the given path. void WriteManifestToFile(const std::filesystem::path& target, IProgressCallback& progress); + // Writes the package file to the given file handle. + void WriteToFileHandle(std::string_view packageFile, HANDLE target, IProgressCallback& progress); + private: bool m_isBundle; Microsoft::WRL::ComPtr m_stream; Microsoft::WRL::ComPtr m_bundleReader; Microsoft::WRL::ComPtr m_packageReader; }; + + struct GetCertContextResult + { + wil::unique_cert_context CertContext; + wil::unique_hcertstore CertStore; + }; + + // Get cert context from a signed msix/msixbundle file. + GetCertContextResult GetCertContextFromMsix(const std::filesystem::path& msixPath); + + struct WriteLockedMsixFile + { + WriteLockedMsixFile(const std::filesystem::path& path); + + bool ValidateTrustInfo(bool checkMicrosoftOrigin) const; + + private: + Utility::ManagedFile m_file; + }; } \ No newline at end of file diff --git a/src/AppInstallerCommonCore/Public/AppInstallerRuntime.h b/src/AppInstallerCommonCore/Public/AppInstallerRuntime.h index aff42a9258..9d757b1411 100644 --- a/src/AppInstallerCommonCore/Public/AppInstallerRuntime.h +++ b/src/AppInstallerCommonCore/Public/AppInstallerRuntime.h @@ -64,6 +64,9 @@ namespace AppInstaller::Runtime // Gets the path to the requested location. std::filesystem::path GetPathTo(PathName path); + // Gets a new temp file path. + std::filesystem::path GetNewTempFilePath(); + // Determines whether the current OS version is >= the given one. // We treat the given Version struct as a standard 4 part Windows OS version. bool IsCurrentOSVersionGreaterThanOrEqual(const Utility::Version& version); diff --git a/src/AppInstallerCommonCore/Public/winget/ManagedFile.h b/src/AppInstallerCommonCore/Public/winget/ManagedFile.h new file mode 100644 index 0000000000..7288c828a7 --- /dev/null +++ b/src/AppInstallerCommonCore/Public/winget/ManagedFile.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#pragma once +#include +#include + +namespace AppInstaller::Utility +{ + // Struct that holds a file handle and may perform additional operations on exit + struct ManagedFile + { + ManagedFile() = default; + + ManagedFile(const ManagedFile&) = delete; + ManagedFile& operator=(const ManagedFile&) = delete; + + ManagedFile(ManagedFile&&) = default; + ManagedFile& operator=(ManagedFile&&) = default; + + HANDLE GetFileHandle() const { return m_fileHandle.get(); } + const std::filesystem::path& GetFilePath() const { return m_filePath; } + + // Always creates a new write locked file at the path given. desiredAccess is passed to CreateFile call. + static ManagedFile CreateWriteLockedFile(const std::filesystem::path& path, DWORD desiredAccess, bool deleteOnExit); + + // Always opens an existing file at the path given with write locked. desiredAccess is passed to CreateFile call. + static ManagedFile OpenWriteLockedFile(const std::filesystem::path& path, DWORD desiredAccess); + + ~ManagedFile(); + + private: + std::filesystem::path m_filePath; + wil::unique_handle m_fileHandle; + bool m_deleteFileOnExit = false; + }; +} \ No newline at end of file diff --git a/src/AppInstallerCommonCore/Runtime.cpp b/src/AppInstallerCommonCore/Runtime.cpp index 955f7904e4..655f62a704 100644 --- a/src/AppInstallerCommonCore/Runtime.cpp +++ b/src/AppInstallerCommonCore/Runtime.cpp @@ -504,6 +504,18 @@ namespace AppInstaller::Runtime return result; } + std::filesystem::path GetNewTempFilePath() + { + GUID guid; + THROW_IF_FAILED(CoCreateGuid(&guid)); + WCHAR tempFileName[256]; + THROW_HR_IF(E_UNEXPECTED, StringFromGUID2(guid, tempFileName, ARRAYSIZE(tempFileName)) == 0); + auto tempFilePath = Runtime::GetPathTo(Runtime::PathName::Temp); + tempFilePath /= tempFileName; + + return tempFilePath; + } + bool IsCurrentOSVersionGreaterThanOrEqual(const Utility::Version& version) { DWORD versionParts[3] = {}; diff --git a/src/AppInstallerCommonCore/pch.h b/src/AppInstallerCommonCore/pch.h index 0ef4c263c1..a4986d3f1e 100644 --- a/src/AppInstallerCommonCore/pch.h +++ b/src/AppInstallerCommonCore/pch.h @@ -13,6 +13,8 @@ #include #include #include +#include +#include #include "TraceLogging.h" diff --git a/src/AppInstallerRepositoryCore/Microsoft/PreIndexedPackageSourceFactory.cpp b/src/AppInstallerRepositoryCore/Microsoft/PreIndexedPackageSourceFactory.cpp index 83eaa1ae5c..e80b2ede73 100644 --- a/src/AppInstallerRepositoryCore/Microsoft/PreIndexedPackageSourceFactory.cpp +++ b/src/AppInstallerRepositoryCore/Microsoft/PreIndexedPackageSourceFactory.cpp @@ -7,6 +7,7 @@ #include #include +#include using namespace std::string_literals; using namespace std::string_view_literals; @@ -16,7 +17,6 @@ namespace AppInstaller::Repository::Microsoft namespace { static constexpr std::string_view s_PreIndexedPackageSourceFactory_PackageFileName = "source.msix"sv; - static constexpr std::string_view s_PreIndexedPackageSourceFactory_AppxManifestFileName = "AppxManifest.xml"sv; static constexpr std::string_view s_PreIndexedPackageSourceFactory_IndexFileName = "index.db"sv; // TODO: This being hard coded to force using the Public directory name is not ideal. static constexpr std::string_view s_PreIndexedPackageSourceFactory_IndexFilePath = "Public\\index.db"sv; @@ -361,7 +361,7 @@ namespace AppInstaller::Repository::Microsoft } std::filesystem::path packageLocation = GetStatePathFromDetails(m_details); - packageLocation /= s_PreIndexedPackageSourceFactory_IndexFileName; + packageLocation /= s_PreIndexedPackageSourceFactory_PackageFileName; if (!std::filesystem::exists(packageLocation)) { @@ -369,7 +369,27 @@ namespace AppInstaller::Repository::Microsoft THROW_HR(APPINSTALLER_CLI_ERROR_SOURCE_DATA_MISSING); } - SQLiteIndex index = SQLiteIndex::Open(packageLocation.u8string(), SQLiteIndex::OpenDisposition::Read); + // Put a write exclusive lock on the index package. + Msix::WriteLockedMsixFile indexPackage{ packageLocation }; + + // Validate index package trust info. + THROW_HR_IF(APPINSTALLER_CLI_ERROR_SOURCE_DATA_INTEGRITY_FAILURE, !indexPackage.ValidateTrustInfo(WI_IsFlagSet(m_details.TrustLevel, SourceTrustLevel::StoreOrigin))); + + // Create a temp lock exclusive index file. + auto tempIndexFilePath = Runtime::GetNewTempFilePath(); + auto tempIndexFile = Utility::ManagedFile::CreateWriteLockedFile(tempIndexFilePath, GENERIC_WRITE, true); + + // Populate temp index file. + Msix::MsixInfo packageInfo(packageLocation); + packageInfo.WriteToFileHandle(s_PreIndexedPackageSourceFactory_IndexFilePath, tempIndexFile.GetFileHandle(), progress); + + if (progress.IsCancelled()) + { + AICLI_LOG(Repo, Info, << "Cancelling open upon request"); + return {}; + } + + SQLiteIndex index = SQLiteIndex::Open(tempIndexFile.GetFilePath().u8string(), SQLiteIndex::OpenDisposition::Immutable, std::move(tempIndexFile)); // We didn't use to store the source identifier, so we compute it here in case it's // missing from the details. @@ -389,35 +409,77 @@ namespace AppInstaller::Repository::Microsoft return std::make_shared(details); } - bool UpdateInternal(const std::string&, Msix::MsixInfo& packageInfo, const SourceDetails& details, IProgressCallback& progress) override + bool UpdateInternal(const std::string& packageLocation, Msix::MsixInfo& packageInfo, const SourceDetails& details, IProgressCallback& progress) override { // We will extract the manifest and index files directly to this location std::filesystem::path packageState = GetStatePathFromDetails(details); std::filesystem::create_directories(packageState); - std::filesystem::path manifestPath = packageState / s_PreIndexedPackageSourceFactory_AppxManifestFileName; - std::filesystem::path indexPath = packageState / s_PreIndexedPackageSourceFactory_IndexFileName; + std::filesystem::path packagePath = packageState / s_PreIndexedPackageSourceFactory_PackageFileName; - if (std::filesystem::exists(manifestPath) && std::filesystem::exists(indexPath)) + if (std::filesystem::exists(packagePath)) { - // If we already have a manifest, use it to determine if we need to update or not. - if (!packageInfo.IsNewerThan(manifestPath)) + // If we already have a trusted index package, use it to determine if we need to update or not. + Msix::WriteLockedMsixFile indexPackage{ packagePath }; + if (indexPackage.ValidateTrustInfo(WI_IsFlagSet(details.TrustLevel, SourceTrustLevel::StoreOrigin)) && + !packageInfo.IsNewerThan(packagePath)) { AICLI_LOG(Repo, Info, << "Remote source data was not newer than existing, no update needed"); return true; } } + std::filesystem::path tempPackagePath = packagePath.u8string() + ".dnld.msix"; + if (Utility::IsUrlRemote(packageLocation)) + { + AppInstaller::Utility::Download(packageLocation, tempPackagePath, AppInstaller::Utility::DownloadType::Index, progress); + } + else + { + std::filesystem::copy(packageLocation, tempPackagePath); + progress.OnProgress(100, 100, ProgressType::Percent); + } + + bool updateSuccess = false; if (progress.IsCancelled()) { AICLI_LOG(Repo, Info, << "Cancelling update upon request"); - return false; } + else + { + bool tempIndexPackageTrusted = false; - packageInfo.WriteToFile(s_PreIndexedPackageSourceFactory_IndexFilePath, indexPath, progress); - packageInfo.WriteManifestToFile(manifestPath, progress); + { + // Extra scope to release the file lock right after trust validation. + Msix::WriteLockedMsixFile tempIndexPackage{ tempPackagePath }; + tempIndexPackageTrusted = tempIndexPackage.ValidateTrustInfo(WI_IsFlagSet(details.TrustLevel, SourceTrustLevel::StoreOrigin)); + } - return true; + if (tempIndexPackageTrusted) + { + std::filesystem::rename(tempPackagePath, packagePath); + AICLI_LOG(Repo, Info, << "Source update success."); + updateSuccess = true; + } + else + { + AICLI_LOG(Repo, Error, << "Source update failed. Source package failed trust validation."); + } + } + + if (!updateSuccess) + { + try + { + std::filesystem::remove(tempPackagePath); + } + catch (...) + { + AICLI_LOG(Repo, Info, << "Failed to remove temp index file at: " << tempPackagePath); + } + } + + return updateSuccess; } bool RemoveInternal(const SourceDetails& details, IProgressCallback&) override diff --git a/src/AppInstallerRepositoryCore/Microsoft/SQLiteIndex.cpp b/src/AppInstallerRepositoryCore/Microsoft/SQLiteIndex.cpp index 5dd2c0c7b2..1a23bc3cfc 100644 --- a/src/AppInstallerRepositoryCore/Microsoft/SQLiteIndex.cpp +++ b/src/AppInstallerRepositoryCore/Microsoft/SQLiteIndex.cpp @@ -45,15 +45,15 @@ namespace AppInstaller::Repository::Microsoft return result; } - SQLiteIndex SQLiteIndex::Open(const std::string& filePath, OpenDisposition disposition) + SQLiteIndex SQLiteIndex::Open(const std::string& filePath, OpenDisposition disposition, Utility::ManagedFile&& indexFile) { AICLI_LOG(Repo, Info, << "Opening SQLite Index for " << GetOpenDispositionString(disposition) << " at '" << filePath << "'"); switch (disposition) { case AppInstaller::Repository::Microsoft::SQLiteIndex::OpenDisposition::Read: - return { filePath, SQLite::Connection::OpenDisposition::ReadOnly, SQLite::Connection::OpenFlags::None }; + return { filePath, SQLite::Connection::OpenDisposition::ReadOnly, SQLite::Connection::OpenFlags::None, std::move(indexFile) }; case AppInstaller::Repository::Microsoft::SQLiteIndex::OpenDisposition::ReadWrite: - return { filePath, SQLite::Connection::OpenDisposition::ReadWrite, SQLite::Connection::OpenFlags::None }; + return { filePath, SQLite::Connection::OpenDisposition::ReadWrite, SQLite::Connection::OpenFlags::None, std::move(indexFile) }; case AppInstaller::Repository::Microsoft::SQLiteIndex::OpenDisposition::Immutable: { // Following the algorithm set forth at https://sqlite.org/uri.html [3.1] to convert to a URI path @@ -99,15 +99,15 @@ namespace AppInstaller::Repository::Microsoft target += "?immutable=1"; - return { target, SQLite::Connection::OpenDisposition::ReadOnly, SQLite::Connection::OpenFlags::Uri }; + return { target, SQLite::Connection::OpenDisposition::ReadOnly, SQLite::Connection::OpenFlags::Uri, std::move(indexFile) }; } default: THROW_HR(E_UNEXPECTED); } } - SQLiteIndex::SQLiteIndex(const std::string& target, SQLite::Connection::OpenDisposition disposition, SQLite::Connection::OpenFlags flags) : - m_dbconn(SQLite::Connection::Create(target, disposition, flags)) + SQLiteIndex::SQLiteIndex(const std::string& target, SQLite::Connection::OpenDisposition disposition, SQLite::Connection::OpenFlags flags, Utility::ManagedFile&& indexFile) : + m_dbconn(SQLite::Connection::Create(target, disposition, flags)), m_indexFile(std::move(indexFile)) { m_dbconn.EnableICU(); m_version = Schema::Version::GetSchemaVersion(m_dbconn); diff --git a/src/AppInstallerRepositoryCore/Microsoft/SQLiteIndex.h b/src/AppInstallerRepositoryCore/Microsoft/SQLiteIndex.h index 092abc4153..05fe0e613d 100644 --- a/src/AppInstallerRepositoryCore/Microsoft/SQLiteIndex.h +++ b/src/AppInstallerRepositoryCore/Microsoft/SQLiteIndex.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -57,7 +58,7 @@ namespace AppInstaller::Repository::Microsoft }; // Opens an existing index database. - static SQLiteIndex Open(const std::string& filePath, OpenDisposition disposition); + static SQLiteIndex Open(const std::string& filePath, OpenDisposition disposition, Utility::ManagedFile&& indexFile = {}); // Gets the schema version of the index. Schema::Version GetVersion() const { return m_version; } @@ -151,7 +152,7 @@ namespace AppInstaller::Repository::Microsoft std::vector> GetDependentsById(AppInstaller::Manifest::string_t packageId) const; private: // Constructor used to open an existing index. - SQLiteIndex(const std::string& target, SQLite::Connection::OpenDisposition disposition, SQLite::Connection::OpenFlags flags); + SQLiteIndex(const std::string& target, SQLite::Connection::OpenDisposition disposition, SQLite::Connection::OpenFlags flags, Utility::ManagedFile&& indexFile); // Constructor used to create a new index. SQLiteIndex(const std::string& target, Schema::Version version); @@ -163,6 +164,7 @@ namespace AppInstaller::Repository::Microsoft // Sets the last write time metadata value in the index. void SetLastWriteTime(); + Utility::ManagedFile m_indexFile; SQLite::Connection m_dbconn; Schema::Version m_version; std::unique_ptr m_interface;