Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support DOS-like wildcards in -requires #280

Merged
merged 2 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 71 additions & 2 deletions src/vswhere.lib/CommandArgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ static wstring ParseArgument(IteratorType& it, const IteratorType& end, const Co
template <class IteratorType>
static void ParseArgumentArray(IteratorType& it, const IteratorType& end, const CommandParser::Token& arg, vector<wstring>& arr);

template <class IteratorType>
static void ParseRequiresArray(IteratorType& it, const IteratorType& end, const CommandParser::Token& arg, vector<wstring>& literals, vector<wregex>& patterns);

const vector<wstring> CommandArgs::s_Products
{
L"Microsoft.VisualStudio.Product.Enterprise",
Expand Down Expand Up @@ -71,7 +74,7 @@ void CommandArgs::Parse(_In_ vector<CommandParser::Token> args)
}
else if (ArgumentEquals(arg.Value, L"requires"))
{
ParseArgumentArray(it, args.end(), arg, m_requires);
ParseRequiresArray(it, args.end(), arg, m_requires, m_requiresPattern);
hasSelection = true;
}
else if (ArgumentEquals(arg.Value, L"requiresAny"))
Expand Down Expand Up @@ -218,7 +221,7 @@ void CommandArgs::Parse(_In_ vector<CommandParser::Token> args)
void CommandArgs::Usage(_In_ Console& console) const
{
auto pos = m_path.find_last_of(L"\\");
auto path = ++pos != wstring::npos ? m_path.substr(pos) : m_path;
auto& path = ++pos != wstring::npos ? m_path.substr(pos) : m_path;

console.WriteLine(ResourceManager::FormatString(IDS_USAGE, path.c_str()));

Expand All @@ -231,6 +234,37 @@ void CommandArgs::Usage(_In_ Console& console) const
}
}

std::wregex CommandArgs::ParseRegex(_In_ const std::wstring& pattern) noexcept
{
// Reserve ~125% of the incoming pattern to hold any changes.
wstring accumulator;
accumulator.reserve(pattern.size() * 1.25);

for (auto it = pattern.begin(); it != pattern.end(); ++it)
{
switch (*it)
{
case L'.':
accumulator += L"\\.";
break;

case L'*':
accumulator += L".*";
break;

case L'?':
accumulator += L".";
break;

default:
heaths marked this conversation as resolved.
Show resolved Hide resolved
accumulator += *it;
break;
}
}

return std::move(wregex(accumulator, wregex::basic | wregex::icase | wregex::nosubs));
}

static bool ArgumentEquals(_In_ const wstring& name, _In_ LPCWSTR expect)
{
_ASSERT(expect && *expect);
Expand Down Expand Up @@ -281,3 +315,38 @@ static void ParseArgumentArray(IteratorType& it, const IteratorType& end, const
arr.push_back(it->Value);
}
}

template <class IteratorType>
static void ParseRequiresArray(IteratorType& it, const IteratorType& end, const CommandParser::Token& arg, vector<wstring>& literals, vector<wregex>& patterns)
{
wstring& param = it->Value;
auto nit = next(it);

// Require arguments if the parameter is specified.
if (nit == end || CommandParser::Token::eArgument != nit->Type)
{
auto message = ResourceManager::FormatString(IDS_E_ARGREQUIRED, param.c_str());
throw win32_error(ERROR_INVALID_PARAMETER, message);
}

while (nit != end)
{
if (CommandParser::Token::eParameter == nit->Type)
{
break;
}

++it;
++nit;

if (it->Value.find(L'*', 0) == wstring::npos && it->Value.find(L'?', 0) == wstring::npos)
{
literals.push_back(it->Value);
}
else
{
auto pattern = CommandArgs::ParseRegex(it->Value);
patterns.push_back(std::move(pattern));
}
}
}
9 changes: 9 additions & 0 deletions src/vswhere.lib/CommandArgs.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class CommandArgs
m_productsAll(obj.m_productsAll),
m_products(obj.m_products),
m_requires(obj.m_requires),
m_requiresPattern(obj.m_requiresPattern),
m_version(obj.m_version),
m_latest(obj.m_latest),
m_legacy(obj.m_legacy),
Expand Down Expand Up @@ -72,6 +73,11 @@ class CommandArgs
return m_requires;
}

const std::vector<std::wregex>& get_RequiresPattern() const noexcept
{
return m_requiresPattern;
}

const bool get_RequiresAny() const noexcept
{
return m_requiresAny;
Expand Down Expand Up @@ -157,6 +163,8 @@ class CommandArgs
void Parse(_In_ int argc, _In_ LPCWSTR argv[]);
void Usage(_In_ Console& console) const;

static std::wregex ParseRegex(_In_ const std::wstring& pattern) noexcept;

private:
static const std::vector<std::wstring> s_Products;
static const std::wstring s_Format;
Expand All @@ -168,6 +176,7 @@ class CommandArgs
bool m_productsAll;
std::vector<std::wstring> m_products;
std::vector<std::wstring> m_requires;
std::vector<std::wregex> m_requiresPattern;
bool m_requiresAny;
std::wstring m_version;
bool m_latest;
Expand Down
3 changes: 2 additions & 1 deletion src/vswhere.lib/Formatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ void Formatter::WritePackages(_In_ ISetupInstance* pInstance)
StartArray(L"packages");

SafeArray<ISetupPackageReference*> saPackages(psaPackages);
const auto packages = saPackages.Elements();
const auto& packages = saPackages.Elements();

for (const auto& package : packages)
{
Expand Down Expand Up @@ -431,6 +431,7 @@ bool Formatter::WriteProperties(_In_ ISetupPropertyStore* pProperties, _In_opt_

SafeArray<BSTR> saNames(psaNames);

// Copy the elements so we can sort them.
auto elems = saNames.Elements();
sort(elems.begin(), elems.end(), less);

Expand Down
55 changes: 32 additions & 23 deletions src/vswhere.lib/InstanceSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using namespace std;
using std::placeholders::_1;

ci_equal InstanceSelector::s_comparer;

InstanceSelector::InstanceSelector(_In_ const CommandArgs& args, _In_ ILegacyProvider& provider, _In_opt_ ISetupHelper* pHelper) :
m_args(args),
m_provider(provider),
Expand All @@ -17,7 +19,7 @@ InstanceSelector::InstanceSelector(_In_ const CommandArgs& args, _In_ ILegacyPro
m_helper = pHelper;
if (m_helper)
{
auto version = args.get_Version();
auto& version = args.get_Version();
if (!version.empty())
{
auto hr = m_helper->ParseVersionRange(version.c_str(), &m_ullMinimumVersion, &m_ullMaximumVersion);
Expand Down Expand Up @@ -224,7 +226,7 @@ bool InstanceSelector::IsProductMatch(_In_ ISetupInstance2* pInstance) const
}

// Asterisk on command line will clear the array to find any products.
const auto products = m_args.get_Products();
const auto& products = m_args.get_Products();
if (products.empty())
{
return true;
Expand All @@ -250,21 +252,19 @@ bool InstanceSelector::IsWorkloadMatch(_In_ ISetupInstance2* pInstance) const
{
_ASSERT(pInstance);

const auto requires = m_args.get_Requires();
if (requires.empty())
// Create copies and erase elements as found.
auto literals = m_args.get_Requires();
auto literals_count = literals.size();

auto patterns = m_args.get_RequiresPattern();
auto patterns_count = patterns.size();

if (literals.empty() && patterns.empty())
{
// No workloads required matches every instance.
return true;
}

// Keep track of which requirements we matched.
typedef map<wstring, bool, ci_less> MapType;
MapType found;
for (const auto& require : requires)
{
found.emplace(make_pair(require, false));
}

LPSAFEARRAY psa = NULL;
auto hr = pInstance->GetPackages(&psa);
if (FAILED(hr))
Expand All @@ -277,25 +277,34 @@ bool InstanceSelector::IsWorkloadMatch(_In_ ISetupInstance2* pInstance) const
{
auto id = GetId(package);

auto it = found.find(id);
if (it != found.end())
for (auto it = literals.cbegin(); it != literals.cend(); ++it)
{
if (s_comparer(id, *it))
{
literals.erase(it);
goto next;
}
}

for (auto it = patterns.cbegin(); it != patterns.cend(); ++it)
{
it->second = true;
if (regex_match(id, *it))
{
patterns.erase(it);
goto next;
}
}

next: continue;
}

if (m_args.get_RequiresAny())
{
return any_of(found.begin(), found.end(), [](MapType::const_reference pair) -> bool
{
return pair.second;
});
return literals.size() < literals_count
|| patterns.size() < patterns_count;
}

return all_of(found.begin(), found.end(), [](MapType::const_reference pair) -> bool
{
return pair.second;
});
return literals.empty() && patterns.empty();
}

bool InstanceSelector::IsVersionMatch(_In_ ISetupInstance* pInstance) const
Expand Down
2 changes: 2 additions & 0 deletions src/vswhere.lib/InstanceSelector.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class InstanceSelector
std::vector<ISetupInstancePtr> Select(_In_opt_ IEnumSetupInstances* pEnum) const;

private:
static ci_equal s_comparer;

static std::wstring GetId(_In_ ISetupPackageReference* pPackageReference);
bool IsMatch(_In_ ISetupInstance* pInstance) const;
bool IsProductMatch(_In_ ISetupInstance2* pInstance) const;
Expand Down
2 changes: 1 addition & 1 deletion src/vswhere.lib/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const wstring& Module::get_Path() noexcept

const wstring& Module::get_FileVersion() noexcept
{
auto path = get_Path();
auto& path = get_Path();
if (path.empty())
{
return m_fileVersion;
Expand Down
3 changes: 3 additions & 0 deletions src/vswhere.lib/vswhere.lib.rc
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,12 @@ BEGIN
\n See https://aka.ms/vs/workloads for a list of product IDs.\
\n -requires arg One or more workload or component IDs required when finding instances.\
\n All specified IDs must be installed unless -requiresAny is specified.\
\n You can specify wildcards including ""?"" to match any one character,\
\n or ""*"" to match zero or more of any characters.\
\n See https://aka.ms/vs/workloads for a list of workload and component IDs.\
\n -requiresAny Find instances with any one or more workload or components IDs passed to -requires.\
\n -version arg A version range for instances to find. Example: [15.0,16.0) will find versions 15.*.\
\n See https://aka.ms/vswhere/versions for more information about versions.\
\n -latest Return only the newest version and last installed.\
\n -sort Sorts the instances from newest version and last installed to oldest.\
\n When used with ""find"", first instances are sorted then files are sorted lexigraphically.\
Expand Down
2 changes: 1 addition & 1 deletion src/vswhere/Program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void WriteLogo(_In_ const CommandArgs& args, _In_ Console& console, _In_ Module&
{
if (args.get_Logo())
{
const auto version = module.get_FileVersion();
const auto& version = module.get_FileVersion();
const auto nID = version.empty() ? IDS_PROGRAMINFO : IDS_PROGRAMINFOEX;

console.WriteLine(ResourceManager::FormatString(nID, NBGV_INFORMATIONAL_VERSION, version.c_str()));
Expand Down
52 changes: 52 additions & 0 deletions test/vswhere.test/CommandArgsTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,4 +401,56 @@ TEST_CLASS(CommandArgsTests)
Assert::IsFalse(args.get_Logo());
Assert::IsTrue(args.get_UTF8());
}

BEGIN_TEST_METHOD_ATTRIBUTE(Parse_Requires_Patterns)
TEST_WORKITEM(276)
END_TEST_METHOD_ATTRIBUTE()
TEST_METHOD(Parse_Requires_Patterns)
{
CommandArgs args;
args.Parse(L"vswhere.exe -requires foo ba* qux");

const auto& literals = args.get_Requires();
const auto& patterns = args.get_RequiresPattern();

Assert::AreEqual(1, count(literals.cbegin(), literals.cend(), wstring(L"foo")));
Assert::AreEqual(1, count(literals.cbegin(), literals.cend(), wstring(L"qux")));
Assert::AreEqual<size_t>(1, patterns.size());
}

BEGIN_TEST_METHOD_ATTRIBUTE(ParseRegex_Theory)
TEST_WORKITEM(276)
END_TEST_METHOD_ATTRIBUTE()
TEST_METHOD(ParseRegex_Theory)
{
const wstring id = L"Foo.Bar";
vector<tuple<wstring, bool>> data =
{
{ L"Foo.Bar", true },
{ L"Foo.*", true },
{ L"*.Bar", true },
{ L"F*R", true },
{ L"foo?bar", true },
{ L"f??", false },
{ L"f??.??r", true },
{ L"*", true },
{ L".*", false },
{ L"?", false },
{ L"Baz", false },
{ L"*baz", false },
{ L"foo.bar*", true },
};

for (const auto& item : data)
{
wstring pattern;
bool expected;

tie(pattern, expected) = item;
auto re = CommandArgs::ParseRegex(pattern);
bool actual = regex_match(id, re);

Assert::AreEqual(expected, actual, format(L"\"%ls\" =~ /%ls/", id.c_str(), pattern.c_str()).c_str());
}
}
};
Loading