Skip to content

Commit

Permalink
Improve getting command ast and parse parameters (#13864)
Browse files Browse the repository at this point in the history
* Improve getting command ast and parse parameters.

- We use the RelatedAsts from PSReadLine to get the command AST as the
  user input. This is what our prediction will try to match.
- Now we can parse the parameter in the format of "-Name:Value".
- We don't parse  positioinal parameters yet.

* Incorporate PR feedback.
  • Loading branch information
kceiw authored Jan 25, 2021
1 parent efd4321 commit 860dc51
Show file tree
Hide file tree
Showing 18 changed files with 330 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ public void VerifyParameterValues()
Action actual = () => this._service.GetSuggestion(null, 1, 1, CancellationToken.None);
Assert.Throws<ArgumentNullException>(actual);

actual = () => this._service.GetSuggestion(predictionContext.InputAst, 0, 1, CancellationToken.None);
actual = () => this._service.GetSuggestion(predictionContext, 0, 1, CancellationToken.None);
Assert.Throws<ArgumentOutOfRangeException>(actual);

actual = () => this._service.GetSuggestion(predictionContext.InputAst, 1, 0, CancellationToken.None);
actual = () => this._service.GetSuggestion(predictionContext, 1, 0, CancellationToken.None);
Assert.Throws<ArgumentOutOfRangeException>(actual);
}

Expand All @@ -110,8 +110,8 @@ public void VerifyParameterValues()
public void VerifyUsingCommandBasedPredictor(string userInput)
{
var predictionContext = PredictionContext.Create(userInput);
var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst;
var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value;
var commandAst = predictionContext.RelatedAsts.OfType<CommandAst>().LastOrDefault();
var commandName = commandAst?.GetCommandName();
var inputParameterSet = new ParameterSet(commandAst);
var rawUserInput = predictionContext.InputAst.Extent.Text;
var presentCommands = new Dictionary<string, int>();
Expand All @@ -123,7 +123,7 @@ public void VerifyUsingCommandBasedPredictor(string userInput)
1,
CancellationToken.None);

var actual = this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
var actual = this._service.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.NotNull(actual);
Assert.True(actual.Count > 0);
Assert.NotNull(actual.PredictiveSuggestions.First());
Expand All @@ -133,7 +133,7 @@ public void VerifyUsingCommandBasedPredictor(string userInput)
Assert.Equal<string>(expected.SourceTexts, actual.SourceTexts);
Assert.All<SuggestionSource>(actual.SuggestionSources, (source) => Assert.Equal(SuggestionSource.CurrentCommand, source));

actual = this._noFallbackPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
actual = this._noFallbackPredictorService.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.NotNull(actual);
Assert.True(actual.Count > 0);
Assert.NotNull(actual.PredictiveSuggestions.First());
Expand All @@ -153,7 +153,7 @@ public void VerifyUsingCommandBasedPredictor(string userInput)
public void VerifyUsingFallbackPredictor(string userInput)
{
var predictionContext = PredictionContext.Create(userInput);
var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst;
var commandAst = predictionContext.RelatedAsts.OfType<CommandAst>().LastOrDefault();
var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value;
var inputParameterSet = new ParameterSet(commandAst);
var rawUserInput = predictionContext.InputAst.Extent.Text;
Expand All @@ -166,7 +166,7 @@ public void VerifyUsingFallbackPredictor(string userInput)
1,
CancellationToken.None);

var actual = this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
var actual = this._service.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.NotNull(actual);
Assert.True(actual.Count > 0);
Assert.NotNull(actual.PredictiveSuggestions.First());
Expand All @@ -176,7 +176,7 @@ public void VerifyUsingFallbackPredictor(string userInput)
Assert.Equal<string>(expected.SourceTexts, actual.SourceTexts);
Assert.All<SuggestionSource>(actual.SuggestionSources, (source) => Assert.Equal(SuggestionSource.StaticCommands, source));

actual = this._noCommandBasedPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
actual = this._noCommandBasedPredictorService.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.NotNull(actual);
Assert.True(actual.Count > 0);
Assert.NotNull(actual.PredictiveSuggestions.First());
Expand All @@ -199,33 +199,43 @@ public void VerifyUsingFallbackPredictor(string userInput)
public void VerifyNoPrediction(string userInput)
{
var predictionContext = PredictionContext.Create(userInput);
var actual = this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
var actual = this._service.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.Equal(0, actual.Count);

actual = this._noFallbackPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
actual = this._noFallbackPredictorService.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.Equal(0, actual.Count);

actual = this._noCommandBasedPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
actual = this._noCommandBasedPredictorService.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.Equal(0, actual.Count);

actual = this._noPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
actual = this._noPredictorService.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.Null(actual);
}

/// <summary>
/// Verify that it returns null when we cannot parse the user input.
/// </summary>
[Theory]
[InlineData("git status")]
public void VerifyFailToParseUserInput(string userInput)
{
var predictionContext = PredictionContext.Create(userInput);
var actual = this._service.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.Null(actual);
}

/// <summary>
/// Verify when we cannot parse the user input correctly.
/// </summary>
/// <remarks>
/// When we can parse them correctly, please move the InlineData to the corresponding test methods, for example, "git status"
/// doesn't have any prediction so it should move to <see cref="VerifyNoPrediction"/>.
/// When we can parse them correctly, please move the InlineData to the corresponding test methods.
/// </remarks>
[Theory]
[InlineData("git status")]
[InlineData("Get-AzContext Name")]
public void VerifyMalFormattedCommandLine(string userInput)
{
var predictionContext = PredictionContext.Create(userInput);
Action actual = () => this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
Action actual = () => this._service.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
_ = Assert.Throws<InvalidOperationException>(actual);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ public void VerifySupportedAndUnsupportedCommands()
public void VerifySuggestion(string userInput)
{
var predictionContext = PredictionContext.Create(userInput);
var expected = _service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
var actual = _azPredictor.GetSuggestion(predictionContext, CancellationToken.None);
var expected = this._service.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
var actual = this._azPredictor.GetSuggestion(predictionContext, CancellationToken.None);

Assert.Equal(expected.Count, actual.Count);
Assert.Equal(expected.PredictiveSuggestions.First().SuggestionText, actual.First().SuggestionText);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
// limitations under the License.
// ----------------------------------------------------------------------------------

using System;
using Microsoft.Azure.PowerShell.Tools.AzPredictor.Telemetry;
using System;

namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Test.Mocks
{
Expand Down Expand Up @@ -58,5 +58,10 @@ public void OnGetSuggestion(GetSuggestionTelemetryData telemetryData)
{
}

/// <inheritdoc/>
public void OnLoadParameterMap(ParameterMapTelemetryData telemetryData)
{
}

}
}
5 changes: 4 additions & 1 deletion tools/Az.Tools.Predictor/Az.Tools.Predictor.sln
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Az.Tools.Predictor", "Az.To
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Az.Tools.Predictor.Test", "Az.Tools.Predictor.Test\Az.Tools.Predictor.Test.csproj", "{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MockPSConsole", "MockPSConsole\MockPSConsole.csproj", "{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MockPSConsole", "MockPSConsole\MockPSConsole.csproj", "{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}"
ProjectSection(ProjectDependencies) = postProject
{E4A5F697-086C-4908-B90E-A31EE47ECF13} = {E4A5F697-086C-4908-B90E-A31EE47ECF13}
EndProjectSection
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Expand Down
4 changes: 2 additions & 2 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
using System.Runtime.CompilerServices;
using System.Threading;

[assembly:InternalsVisibleTo("Microsoft.Azure.PowerShell.Tools.AzPredictor.Test")]
[assembly: InternalsVisibleTo("Microsoft.Azure.PowerShell.Tools.AzPredictor.Test")]

namespace Microsoft.Azure.PowerShell.Tools.AzPredictor
{
Expand Down Expand Up @@ -195,7 +195,7 @@ public List<PredictiveSuggestion> GetSuggestion(PredictionContext context, Cance
{
var localCancellationToken = Settings.ContinueOnTimeout ? CancellationToken.None : cancellationToken;

suggestions = _service.GetSuggestion(context.InputAst, _settings.SuggestionCount.Value, _settings.MaxAllowedCommandDuplicate.Value, localCancellationToken);
suggestions = _service.GetSuggestion(context, _settings.SuggestionCount.Value, _settings.MaxAllowedCommandDuplicate.Value, localCancellationToken);

var returnedValue = suggestions?.PredictiveSuggestions?.ToList();
return returnedValue ?? new List<PredictiveSuggestion>();
Expand Down
41 changes: 34 additions & 7 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Management.Automation.Language;
using System.Management.Automation.Subsystem;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text;
Expand Down Expand Up @@ -81,7 +82,7 @@ private sealed class CommandRequestContext
/// </summary>
private HashSet<string> _allPredictiveCommands;
private CancellationTokenSource _predictionRequestCancellationSource;
private readonly ParameterValuePredictor _parameterValuePredictor = new ParameterValuePredictor();
private readonly ParameterValuePredictor _parameterValuePredictor;

private readonly ITelemetryClient _telemetryClient;
private readonly IAzContext _azContext;
Expand All @@ -98,6 +99,8 @@ public AzPredictorService(string serviceUri, ITelemetryClient telemetryClient, I
Validation.CheckArgument(telemetryClient, $"{nameof(telemetryClient)} cannot be null.");
Validation.CheckArgument(azContext, $"{nameof(azContext)} cannot be null.");

_parameterValuePredictor = new ParameterValuePredictor(telemetryClient);

_commandsEndpoint = $"{serviceUri}{AzPredictorConstants.CommandsEndpoint}?clientType={AzPredictorService.ClientType}&context.versionNumber={azContext.AzVersion}";
_predictionsEndpoint = serviceUri + AzPredictorConstants.PredictionsEndpoint;
_telemetryClient = telemetryClient;
Expand Down Expand Up @@ -143,22 +146,46 @@ protected virtual void Dispose(bool disposing)
/// Tries to get the suggestions for the user input from the command history. If that doesn't find
/// <paramref name="suggestionCount"/> suggestions, it'll fallback to find the suggestion regardless of command history.
/// </remarks>
public CommandLineSuggestion GetSuggestion(Ast input, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken)
public CommandLineSuggestion GetSuggestion(PredictionContext context, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken)
{
Validation.CheckArgument(input, $"{nameof(input)} cannot be null");
Validation.CheckArgument(context, $"{nameof(context)} cannot be null");
Validation.CheckArgument<ArgumentOutOfRangeException>(suggestionCount > 0, $"{nameof(suggestionCount)} must be larger than 0.");
Validation.CheckArgument<ArgumentOutOfRangeException>(maxAllowedCommandDuplicate > 0, $"{nameof(maxAllowedCommandDuplicate)} must be larger than 0.");

var commandAst = input.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst;
var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value;
var relatedAsts = context.RelatedAsts;
CommandAst commandAst = null;

for (var i = relatedAsts.Count - 1; i >= 0; --i)
{
if (relatedAsts[i] is CommandAst c)
{
commandAst = c;
break;
}
}

var commandName = commandAst?.GetCommandName();

if (string.IsNullOrWhiteSpace(commandName))
{
return null;
}

var inputParameterSet = new ParameterSet(commandAst);
var rawUserInput = input.Extent.Text;
ParameterSet inputParameterSet = null;

try
{
inputParameterSet = new ParameterSet(commandAst);
}
catch when (!IsSupportedCommand(commandName))
{
// We only ignore the exception when the command name is not supported.
// For the supported ones, this most likely happens when positional parameters are used.
// We want to collect the telemetry about the exception how common a positional parameter is used.
return null;
}

var rawUserInput = context.InputAst.ToString();
var presentCommands = new Dictionary<string, int>();
var commandBasedPredictor = _commandBasedPredictor;
var commandToRequestPrediction = _commandToRequestPrediction;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,13 @@ public CommandLinePredictor(IList<PredictiveCommand> modelPredictions, Parameter
{
var predictionText = CommandLineUtilities.EscapePredictionText(predictiveCommand.Command);
Ast ast = Parser.ParseInput(predictionText, out Token[] tokens, out _);
var commandAst = (ast.Find((ast) => ast is CommandAst, searchNestedScriptBlocks: false) as CommandAst);
var commandAst = ast.Find((ast) => ast is CommandAst, searchNestedScriptBlocks: false) as CommandAst;
var commandName = commandAst?.GetCommandName();

if (commandAst?.CommandElements[0] is StringConstantExpressionAst commandName)
if (!string.IsNullOrWhiteSpace(commandName))
{
var parameterSet = new ParameterSet(commandAst);
this._commandLinePredictions.Add(new CommandLine(commandName.Value, predictiveCommand.Description, parameterSet));
this._commandLinePredictions.Add(new CommandLine(commandName, predictiveCommand.Description, parameterSet));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
// limitations under the License.
// ----------------------------------------------------------------------------------

using System;
using System.Collections.Generic;
using System.Management.Automation.Language;
using System.Management.Automation.Subsystem;
using System.Threading;

namespace Microsoft.Azure.PowerShell.Tools.AzPredictor
Expand All @@ -27,12 +27,12 @@ public interface IAzPredictorService
/// <summary>
/// Gest the suggestions for the user input.
/// </summary>
/// <param name="input">User input from PSReadLine.</param>
/// <param name="context">User input context from PSReadLine.</param>
/// <param name="suggestionCount">The number of suggestion to return.</param>
/// <param name="cancellationToken">The cancellation token</param>
/// <param name="maxAllowedCommandDuplicate">The maximum amount of the same commnds in the list of predictions.</param>
/// <returns>The suggestions for <paramref name="input"/>. The maximum number of suggestions is <paramref name="suggestionCount"/>.</returns>
public CommandLineSuggestion GetSuggestion(Ast input, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken);
/// <returns>The suggestions for <paramref name="context"/>. The maximum number of suggestions is <paramref name="suggestionCount"/>. A null will be returned if there the user input context isn't valid/supported at all.</returns>
public CommandLineSuggestion GetSuggestion(PredictionContext context, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken);

/// <summary>
/// Requests predictions, given a command string.
Expand Down
27 changes: 22 additions & 5 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/ParameterSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ namespace Microsoft.Azure.PowerShell.Tools.AzPredictor
/// does not matter to resulting prediction - the prediction should adapt to the
/// order of the parameters typed by the user.
/// </summary>
/// <remarks>
/// This doesn't handle the positional parameters yet.
/// </remarks>
sealed class ParameterSet
{
/// <summary>
Expand All @@ -36,11 +39,14 @@ public ParameterSet(CommandAst commandAst)
Validation.CheckArgument(commandAst, $"{nameof(commandAst)} cannot be null.");

var parameters = new List<Parameter>();
var elements = commandAst.CommandElements.Skip(1);
CommandParameterAst param = null;
Ast arg = null;
foreach (Ast elem in elements)

// Loop through all the parameters. The first element of CommandElements is the command name, so skip it.
for (var i = 1; i < commandAst.CommandElements.Count(); ++i)
{
var elem = commandAst.CommandElements[i];

if (elem is CommandParameterAst p)
{
AddParameter(param, arg);
Expand Down Expand Up @@ -68,11 +74,22 @@ public ParameterSet(CommandAst commandAst)

Parameters = parameters;

void AddParameter(CommandParameterAst parameterName, Ast parameterValue)
void AddParameter(CommandParameterAst parameter, Ast parameterValue)
{
if (parameterName != null)
if (parameter != null)
{
parameters.Add(new Parameter(parameterName.ParameterName, (parameterValue == null) ? null : CommandLineUtilities.UnescapePredictionText(parameterValue.ToString())));
var value = parameterValue?.ToString();
if (value == null)
{
value = parameter.Argument?.ToString();
}

if (value != null)
{
value = CommandLineUtilities.UnescapePredictionText(value);
}

parameters.Add(new Parameter(parameter.ParameterName, value));
}
}
}
Expand Down
Loading

0 comments on commit 860dc51

Please sign in to comment.