Skip to content

Commit

Permalink
Add cancellation token parameter to sync convenience methods (#5337)
Browse files Browse the repository at this point in the history
Previously, only async convenience methods had cancellation tokens, but
the guideline is that both sync and async methods should have
cancellation tokens.
  • Loading branch information
JoshLove-msft authored Dec 11, 2024
1 parent 5383844 commit 530c657
Show file tree
Hide file tree
Showing 242 changed files with 750 additions and 722 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ private ScmMethodProvider BuildConvenienceMethod(MethodProvider protocolMethod,
methodModifier,
GetResponseType(Operation.Responses, true, isAsync, out var responseBodyType),
null,
isAsync ? [.. ConvenienceMethodParameters, ScmKnownParameters.CancellationToken] : ConvenienceMethodParameters);
[.. ConvenienceMethodParameters, ScmKnownParameters.CancellationToken]);

MethodBodyStatement[] methodBody;

Expand All @@ -84,15 +84,15 @@ private ScmMethodProvider BuildConvenienceMethod(MethodProvider protocolMethod,
methodBody =
[
.. GetStackVariablesForProtocolParamConversion(ConvenienceMethodParameters, out var declarations),
Return(This.Invoke(protocolMethod.Signature, [.. GetProtocolMethodArguments(ConvenienceMethodParameters, declarations, isAsync)], isAsync))
Return(This.Invoke(protocolMethod.Signature, [.. GetProtocolMethodArguments(ConvenienceMethodParameters, declarations)], isAsync))
];
}
else
{
methodBody =
[
.. GetStackVariablesForProtocolParamConversion(ConvenienceMethodParameters, out var paramDeclarations),
Declare("result", This.Invoke(protocolMethod.Signature, [.. GetProtocolMethodArguments(ConvenienceMethodParameters, paramDeclarations, isAsync)], isAsync).ToApi<ClientResponseApi>(), out ClientResponseApi result),
Declare("result", This.Invoke(protocolMethod.Signature, [.. GetProtocolMethodArguments(ConvenienceMethodParameters, paramDeclarations)], isAsync).ToApi<ClientResponseApi>(), out ClientResponseApi result),
.. GetStackVariablesForReturnValueConversion(result, responseBodyType, isAsync, out var resultDeclarations),
Return(result.FromValue(GetResultConversion(result, result.GetRawResponse(), responseBodyType, resultDeclarations), result.GetRawResponse())),
];
Expand Down Expand Up @@ -314,8 +314,7 @@ private ValueExpression GetResultConversion(ClientResponseApi result, HttpRespon

private IReadOnlyList<ValueExpression> GetProtocolMethodArguments(
IReadOnlyList<ParameterProvider> convenienceMethodParameters,
Dictionary<string, ValueExpression> declarations,
bool isAsync)
Dictionary<string, ValueExpression> declarations)
{
List<ValueExpression> conversions = new List<ValueExpression>();
bool addedSpreadSource = false;
Expand Down Expand Up @@ -362,12 +361,10 @@ private IReadOnlyList<ValueExpression> GetProtocolMethodArguments(
conversions.Add(param);
}
}

// RequestOptions argument
conversions.Add(IHttpRequestOptionsApiSnippets.FromCancellationToken(ScmKnownParameters.CancellationToken));

conversions.Add(
isAsync
? IHttpRequestOptionsApiSnippets.FromCancellationToken(ScmKnownParameters.CancellationToken)
: ScmKnownParameters.RequestOptions.PositionalReference(Null));
return conversions;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public async Task CanAddMethodSameName()
var helloAgainMethod = clientProviderMethods.FirstOrDefault(m
=> m.Signature.Name == "HelloAgain" && m.Signature.Parameters.Count > 0 && m.Signature.Parameters[0].Name == "p1");
Assert.IsNotNull(helloAgainMethod);
Assert.AreEqual(1, helloAgainMethod!.Signature.Parameters.Count);
Assert.AreEqual(2, helloAgainMethod!.Signature.Parameters.Count);

// The custom code view should contain the method
var customCodeView = clientProvider.CustomCodeView;
Expand Down Expand Up @@ -187,7 +187,7 @@ public async Task CanReplaceStructMethod(bool isStructCustomized)
Assert.AreEqual("HelloAgain", customMethods[0].Signature.Name);

var customMethodParams = customMethods[0].Signature.Parameters;
Assert.AreEqual(1, customMethodParams.Count);
Assert.AreEqual(2, customMethodParams.Count);
Assert.AreEqual("p1", customMethodParams[0].Name);
Assert.AreEqual("MyStruct", customMethodParams[0].Type.Name);
Assert.AreEqual(isStructCustomized ? "Sample.TestClient" : string.Empty, customMethodParams[0].Type.Namespace);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ public void ValidateClientWithSpread(InputClient inputClient)

var convenienceMethods = methods.Where(m => m.Signature.Parameters.Any(p => p.Type.Equals(typeof(string)))).ToList();
Assert.AreEqual(2, convenienceMethods.Count);
Assert.AreEqual(1, convenienceMethods[0].Signature.Parameters.Count);
Assert.AreEqual(2, convenienceMethods[0].Signature.Parameters.Count);

Assert.AreEqual(new CSharpType(typeof(string)), convenienceMethods[0].Signature.Parameters[0].Type);
Assert.AreEqual("p1", convenienceMethods[0].Signature.Parameters[0].Name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Sample
/// <summary></summary>
public partial class TestClient
{
public virtual ClientResult HelloAgain(MyStruct? p1)
public virtual ClientResult HelloAgain(MyStruct? p1, CancellationToken cancellationToken = default)
{

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Sample
/// <summary></summary>
public partial class TestClient
{
public virtual ClientResult HelloAgain(MyStruct? p1)
public virtual ClientResult HelloAgain(MyStruct? p1, CancellationToken cancellationToken = default)
{

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
#nullable disable

using System.ClientModel;
using System.ClientModel.Primitives;
using System.Threading;
using Sample.Models;

namespace Sample
{
/// <summary></summary>
public partial class TestClient
{
public virtual global::System.ClientModel.ClientResult Operation(global::Sample.Models.InputEnum queryParam)
public virtual global::System.ClientModel.ClientResult Operation(global::Sample.Models.InputEnum queryParam, global::System.Threading.CancellationToken cancellationToken = default)
{
return this.Operation(queryParam.ToString(), options: null);
return this.Operation(queryParam.ToString(), cancellationToken.CanBeCanceled ? new global::System.ClientModel.Primitives.RequestOptions { CancellationToken = cancellationToken } : null);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,15 @@ public void TestDefaultCSharpMethodCollection(InputOperation inputOperation)
if (spreadInputParameter != null)
{
var spreadModelProperties = _spreadModel.Properties;
Assert.AreEqual(spreadModelProperties.Count + 1, convenienceMethodParams.Count);
// model properties + 2 (parameter and cancellation token)
Assert.AreEqual(spreadModelProperties.Count + 2, convenienceMethodParams.Count);
Assert.AreEqual("p1", convenienceMethodParams[0].Name);
Assert.AreEqual(spreadModelProperties[0].Name, convenienceMethodParams[1].Name);
}
}

[TestCaseSource(nameof(DefaultCSharpMethodCollectionTestCases))]
public void AsyncMethodsHaveOptionalCancellationToken(InputOperation inputOperation)
public void ConvenienceMethodsHaveOptionalCancellationToken(InputOperation inputOperation)
{
var inputClient = InputFactory.Client("TestClient", operations: [inputOperation]);

Expand All @@ -87,6 +88,19 @@ public void AsyncMethodsHaveOptionalCancellationToken(InputOperation inputOperat
Assert.IsTrue(lastParameter.Type.Equals(typeof(CancellationToken)));
Assert.IsFalse(lastParameter.Type.IsNullable);
Assert.AreEqual(Snippet.Default, lastParameter.DefaultValue);

var syncConvenienceMethod = methodCollection.FirstOrDefault(m
=> !m.Signature.Parameters.Any(p => p.Name == "content")
&& m.Signature.Name == inputOperation.Name.ToCleanName());
Assert.IsNotNull(syncConvenienceMethod);

var syncConvenienceMethodParameters = syncConvenienceMethod!.Signature.Parameters;
Assert.IsNotNull(syncConvenienceMethodParameters);

lastParameter = syncConvenienceMethodParameters.Last();
Assert.IsTrue(lastParameter.Type.Equals(typeof(CancellationToken)));
Assert.IsFalse(lastParameter.Type.IsNullable);
Assert.AreEqual(Snippet.Default, lastParameter.DefaultValue);
}

public static IEnumerable<TestCaseData> DefaultCSharpMethodCollectionTestCases
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,9 @@ private static void ValidateConvenienceMethodParameters(MethodInfo method, IEnum
{
if (IsProtocolMethod(method))
return;
if (method.Name.EndsWith("Async"))
{
expected = expected.Append((typeof(CancellationToken), "cancellationToken", false));
}

expected = expected.Append((typeof(CancellationToken), "cancellationToken", false));

var parameters = method.GetParameters().Where(p => !p.ParameterType.Equals(typeof(RequestOptions)));
var parameterTypes = parameters.Select(p => p.ParameterType);
var parameterNames = parameters.Select(p => p.Name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ public partial class ApiKeyClient

public virtual Task<ClientResult> ValidAsync(RequestOptions options) => throw null;

public virtual ClientResult Valid() => throw null;
public virtual ClientResult Valid(CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> ValidAsync(CancellationToken cancellationToken = default) => throw null;

public virtual ClientResult Invalid(RequestOptions options) => throw null;

public virtual Task<ClientResult> InvalidAsync(RequestOptions options) => throw null;

public virtual ClientResult Invalid() => throw null;
public virtual ClientResult Invalid(CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> InvalidAsync(CancellationToken cancellationToken = default) => throw null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ public partial class CustomClient

public virtual Task<ClientResult> ValidAsync(RequestOptions options) => throw null;

public virtual ClientResult Valid() => throw null;
public virtual ClientResult Valid(CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> ValidAsync(CancellationToken cancellationToken = default) => throw null;

public virtual ClientResult Invalid(RequestOptions options) => throw null;

public virtual Task<ClientResult> InvalidAsync(RequestOptions options) => throw null;

public virtual ClientResult Invalid() => throw null;
public virtual ClientResult Invalid(CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> InvalidAsync(CancellationToken cancellationToken = default) => throw null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ public partial class OAuth2Client

public virtual Task<ClientResult> ValidAsync(RequestOptions options) => throw null;

public virtual ClientResult Valid() => throw null;
public virtual ClientResult Valid(CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> ValidAsync(CancellationToken cancellationToken = default) => throw null;

public virtual ClientResult Invalid(RequestOptions options) => throw null;

public virtual Task<ClientResult> InvalidAsync(RequestOptions options) => throw null;

public virtual ClientResult Invalid() => throw null;
public virtual ClientResult Invalid(CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> InvalidAsync(CancellationToken cancellationToken = default) => throw null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ public partial class UnionClient

public virtual Task<ClientResult> ValidKeyAsync(RequestOptions options) => throw null;

public virtual ClientResult ValidKey() => throw null;
public virtual ClientResult ValidKey(CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> ValidKeyAsync(CancellationToken cancellationToken = default) => throw null;

public virtual ClientResult ValidToken(RequestOptions options) => throw null;

public virtual Task<ClientResult> ValidTokenAsync(RequestOptions options) => throw null;

public virtual ClientResult ValidToken() => throw null;
public virtual ClientResult ValidToken(CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> ValidTokenAsync(CancellationToken cancellationToken = default) => throw null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ public partial class ClientModel

public virtual Task<ClientResult> ClientAsync(BinaryContent content, RequestOptions options = null) => throw null;

public virtual ClientResult Client(Models.ClientModel body) => throw null;
public virtual ClientResult Client(Models.ClientModel body, CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> ClientAsync(Models.ClientModel body, CancellationToken cancellationToken = default) => throw null;

public virtual ClientResult Language(BinaryContent content, RequestOptions options = null) => throw null;

public virtual Task<ClientResult> LanguageAsync(BinaryContent content, RequestOptions options = null) => throw null;

public virtual ClientResult Language(CSModel body) => throw null;
public virtual ClientResult Language(CSModel body, CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> LanguageAsync(CSModel body, CancellationToken cancellationToken = default) => throw null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,55 +23,55 @@ public partial class NamingClient

public virtual Task<ClientResult> ClientNameAsync(RequestOptions options) => throw null;

public virtual ClientResult ClientName() => throw null;
public virtual ClientResult ClientName(CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> ClientNameAsync(CancellationToken cancellationToken = default) => throw null;

public virtual ClientResult Parameter(string clientName, RequestOptions options) => throw null;

public virtual Task<ClientResult> ParameterAsync(string clientName, RequestOptions options) => throw null;

public virtual ClientResult Parameter(string clientName) => throw null;
public virtual ClientResult Parameter(string clientName, CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> ParameterAsync(string clientName, CancellationToken cancellationToken = default) => throw null;

public virtual ClientResult Client(BinaryContent content, RequestOptions options = null) => throw null;

public virtual Task<ClientResult> ClientAsync(BinaryContent content, RequestOptions options = null) => throw null;

public virtual ClientResult Client(ClientNameModel body) => throw null;
public virtual ClientResult Client(ClientNameModel body, CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> ClientAsync(ClientNameModel body, CancellationToken cancellationToken = default) => throw null;

public virtual ClientResult Language(BinaryContent content, RequestOptions options = null) => throw null;

public virtual Task<ClientResult> LanguageAsync(BinaryContent content, RequestOptions options = null) => throw null;

public virtual ClientResult Language(LanguageClientNameModel body) => throw null;
public virtual ClientResult Language(LanguageClientNameModel body, CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> LanguageAsync(LanguageClientNameModel body, CancellationToken cancellationToken = default) => throw null;

public virtual ClientResult CompatibleWithEncodedName(BinaryContent content, RequestOptions options = null) => throw null;

public virtual Task<ClientResult> CompatibleWithEncodedNameAsync(BinaryContent content, RequestOptions options = null) => throw null;

public virtual ClientResult CompatibleWithEncodedName(ClientNameAndJsonEncodedNameModel body) => throw null;
public virtual ClientResult CompatibleWithEncodedName(ClientNameAndJsonEncodedNameModel body, CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> CompatibleWithEncodedNameAsync(ClientNameAndJsonEncodedNameModel body, CancellationToken cancellationToken = default) => throw null;

public virtual ClientResult Request(string clientName, RequestOptions options) => throw null;

public virtual Task<ClientResult> RequestAsync(string clientName, RequestOptions options) => throw null;

public virtual ClientResult Request(string clientName) => throw null;
public virtual ClientResult Request(string clientName, CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> RequestAsync(string clientName, CancellationToken cancellationToken = default) => throw null;

public virtual ClientResult Response(RequestOptions options) => throw null;

public virtual Task<ClientResult> ResponseAsync(RequestOptions options) => throw null;

public virtual ClientResult Response() => throw null;
public virtual ClientResult Response(CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> ResponseAsync(CancellationToken cancellationToken = default) => throw null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ public partial class UnionEnum

public virtual Task<ClientResult> UnionEnumNameAsync(BinaryContent content, RequestOptions options = null) => throw null;

public virtual ClientResult UnionEnumName(ClientExtensibleEnum body) => throw null;
public virtual ClientResult UnionEnumName(ClientExtensibleEnum body, CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> UnionEnumNameAsync(ClientExtensibleEnum body, CancellationToken cancellationToken = default) => throw null;

public virtual ClientResult UnionEnumMemberName(BinaryContent content, RequestOptions options = null) => throw null;

public virtual Task<ClientResult> UnionEnumMemberNameAsync(BinaryContent content, RequestOptions options = null) => throw null;

public virtual ClientResult UnionEnumMemberName(ExtensibleEnum body) => throw null;
public virtual ClientResult UnionEnumMemberName(ExtensibleEnum body, CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> UnionEnumMemberNameAsync(ExtensibleEnum body, CancellationToken cancellationToken = default) => throw null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public partial class FirstClient

public virtual Task<ClientResult> OneAsync(RequestOptions options) => throw null;

public virtual ClientResult One() => throw null;
public virtual ClientResult One(CancellationToken cancellationToken = default) => throw null;

public virtual Task<ClientResult> OneAsync(CancellationToken cancellationToken = default) => throw null;

Expand Down
Loading

0 comments on commit 530c657

Please sign in to comment.