Skip to content

Commit

Permalink
Sync with canonical Async.AwaitTaskCorrect
Browse files Browse the repository at this point in the history
  • Loading branch information
bartelink committed May 19, 2022
1 parent 6f8896d commit 4c52d74
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 44 deletions.
88 changes: 46 additions & 42 deletions src/FSharp.AWS.DynamoDB/Utils/Utils.fs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,18 @@ open System.Reflection
open System.Threading.Tasks
open System.Text

open Microsoft.FSharp.Core.LanguagePrimitives.IntrinsicFunctions
open Microsoft.FSharp.Quotations
open Microsoft.FSharp.Quotations.Patterns
open Microsoft.FSharp.Quotations.DerivedPatterns
open Microsoft.FSharp.Quotations.ExprShape

[<AutoOpen>]
module internal Utils =

let inline rlist (ts : seq<'T>) = new ResizeArray<_>(ts)
let inline rlist (ts : seq<'T>) = ResizeArray<_>(ts)

let inline keyVal k v = new KeyValuePair<_,_>(k,v)
let inline keyVal k v = KeyValuePair<_,_>(k,v)

let inline cdict (kvs : seq<KeyValuePair<'K,'V>>) =
let inline cdict (kvs : seq<KeyValuePair<'K,'V>>) =
let d = new Dictionary<'K, 'V>()
for kv in kvs do d.Add(kv.Key, kv.Value)
d
Expand All @@ -34,7 +32,7 @@ module internal Utils =
/// pair hashcode generation without tuple allocation
let inline hash2 (t : 'T) (s : 'S) =
combineHash (hash t) (hash s)

/// triple hashcode generation without tuple allocation
let inline hash3 (t : 'T) (s : 'S) (u : 'U) =
combineHash (combineHash (hash t) (hash s)) (hash u)
Expand All @@ -44,17 +42,17 @@ module internal Utils =
combineHash (combineHash (combineHash (hash t) (hash s)) (hash u)) (hash v)

let inline mkString (builder : (string -> unit) -> unit) : string =
let sb = new StringBuilder()
let sb = StringBuilder()
builder (fun s -> sb.Append s |> ignore)
sb.ToString()

let tryGetAttribute<'Attribute when 'Attribute :> System.Attribute> (attrs : seq<Attribute>) : 'Attribute option =
let tryGetAttribute<'Attribute when 'Attribute :> Attribute> (attrs : seq<Attribute>) : 'Attribute option =
attrs |> Seq.tryPick(function :? 'Attribute as a -> Some a | _ -> None)

let getAttributes<'Attribute when 'Attribute :> System.Attribute> (attrs : seq<Attribute>) : 'Attribute [] =
let getAttributes<'Attribute when 'Attribute :> Attribute> (attrs : seq<Attribute>) : 'Attribute [] =
attrs |> Seq.choose(function :? 'Attribute as a -> Some a | _ -> None) |> Seq.toArray

let containsAttribute<'Attribute when 'Attribute :> System.Attribute> (attrs : seq<Attribute>) : bool =
let containsAttribute<'Attribute when 'Attribute :> Attribute> (attrs : seq<Attribute>) : bool =
attrs |> Seq.exists(fun a -> a :? 'Attribute)

[<RequireQualifiedAccess>]
Expand All @@ -66,13 +64,13 @@ module internal Utils =
| _ :: tail -> last tail

type MemberInfo with
member m.TryGetAttribute<'Attribute when 'Attribute :> System.Attribute> () : 'Attribute option =
member m.TryGetAttribute<'Attribute when 'Attribute :> Attribute> () : 'Attribute option =
m.GetCustomAttributes(true) |> Seq.map unbox<Attribute> |> tryGetAttribute

member m.GetAttributes<'Attribute when 'Attribute :> System.Attribute> () : 'Attribute [] =
member m.GetAttributes<'Attribute when 'Attribute :> Attribute> () : 'Attribute [] =
m.GetCustomAttributes(true) |> Seq.map unbox<Attribute> |> getAttributes

member m.ContainsAttribute<'Attribute when 'Attribute :> System.Attribute> () : bool =
member m.ContainsAttribute<'Attribute when 'Attribute :> Attribute> () : bool =
m.GetCustomAttributes(true) |> Seq.map unbox<Attribute> |> containsAttribute

type MethodInfo with
Expand All @@ -85,13 +83,13 @@ module internal Utils =
let gas = dt.GetGenericArguments()
let mas = m.GetGenericArguments()

let bindingFlags =
BindingFlags.Public ||| BindingFlags.NonPublic |||
BindingFlags.Static ||| BindingFlags.Instance |||
let bindingFlags =
BindingFlags.Public ||| BindingFlags.NonPublic |||
BindingFlags.Static ||| BindingFlags.Instance |||
BindingFlags.FlattenHierarchy

let m =
gt.GetMethods(bindingFlags)
let m =
gt.GetMethods(bindingFlags)
|> Array.find (fun m' -> m.Name = m'.Name && m.MetadataToken = m'.MetadataToken)

m, gas, mas
Expand All @@ -110,9 +108,9 @@ module internal Utils =
let gt = dt.GetGenericTypeDefinition()
let gas = dt.GetGenericArguments()

let bindingFlags =
BindingFlags.Public ||| BindingFlags.NonPublic |||
BindingFlags.Static ||| BindingFlags.Instance |||
let bindingFlags =
BindingFlags.Public ||| BindingFlags.NonPublic |||
BindingFlags.Static ||| BindingFlags.Instance |||
BindingFlags.FlattenHierarchy

let gp = gt.GetProperty(p.Name, bindingFlags)
Expand All @@ -121,7 +119,7 @@ module internal Utils =
else
p, [||]

type Quotations.Expr with
type Expr with
member e.IsClosed = e.GetFreeVars() |> Seq.isEmpty
member e.Substitute(v : Var, sub : Expr) =
e.Substitute(fun w -> if v = w then Some sub else None)
Expand All @@ -135,7 +133,7 @@ module internal Utils =
/// <param name="variableName">Environment variable name.</param>
static member ResolveEnvironmentVariable(variableName : string) =
let aux found target =
if String.IsNullOrWhiteSpace found then
if String.IsNullOrWhiteSpace found then
Environment.GetEnvironmentVariable(variableName, target)
else found

Expand Down Expand Up @@ -179,7 +177,7 @@ module internal Utils =

let (|IndexGet|_|) (e : Expr) =
match e with
| SpecificCall2 <@ LanguagePrimitives.IntrinsicFunctions.GetArray @> (None,_,[t], [obj ; index]) ->
| SpecificCall2 <@ LanguagePrimitives.IntrinsicFunctions.GetArray @> (None,_,[t], [obj ; index]) ->
Some(obj, t, index)
| PropertyGet(Some obj, prop, [index]) when prop.Name = "Item" ->
Some(obj, prop.PropertyType, index)
Expand Down Expand Up @@ -218,29 +216,29 @@ module internal Utils =
else None
| _ -> None

type Task with
/// Gets the inner exception of the faulted task.
member t.InnerException =
let e = t.Exception
if e.InnerExceptions.Count = 1 then e.InnerExceptions.[0]
else
e :> exn

type Async with

/// Raise an exception
static member Raise e = Async.FromContinuations(fun (_,ec,_) -> ec e)

(* Direct copies of canonical implementation at http://www.fssnip.net/7Rc/title/AsyncAwaitTaskCorrect
pending that being officially packaged somewhere or integrated into FSharp.Core https://github.com/fsharp/fslang-suggestions/issues/840 *)

/// <summary>
/// Gets the result of given task so that in the event of exception
/// the actual user exception is raised as opposed to being wrapped
/// in a System.AggregateException.
/// </summary>
/// <param name="task">Task to be awaited.</param>
[<System.Diagnostics.DebuggerStepThrough>]
static member AwaitTaskCorrect(task : Task<'T>) : Async<'T> =
Async.FromContinuations(fun (sc,ec,cc) ->
task.ContinueWith(fun (t : Task<'T>) ->
if task.IsFaulted then ec t.InnerException
elif task.IsCanceled then cc(new OperationCanceledException())
Async.FromContinuations(fun (sc, ec, _cc) ->
task.ContinueWith(fun (t : Task<'T>) ->
if t.IsFaulted then
let e = t.Exception
if e.InnerExceptions.Count = 1 then ec e.InnerExceptions[0]
else ec e
elif t.IsCanceled then ec (TaskCanceledException())
else sc t.Result)
|> ignore)

Expand All @@ -250,12 +248,18 @@ module internal Utils =
/// in a System.AggregateException.
/// </summary>
/// <param name="task">Task to be awaited.</param>
[<System.Diagnostics.DebuggerStepThrough>]
static member AwaitTaskCorrect(task : Task) : Async<unit> =
Async.FromContinuations(fun (sc,ec,cc) ->
task.ContinueWith(fun (t : Task) ->
if task.IsFaulted then ec t.InnerException
elif task.IsCanceled then cc(new OperationCanceledException())
else sc ())
Async.FromContinuations(fun (sc, ec, _cc) ->
task.ContinueWith(fun (task : Task) ->
if task.IsFaulted then
let e = task.Exception
if e.InnerExceptions.Count = 1 then ec e.InnerExceptions[0]
else ec e
elif task.IsCanceled then
ec (TaskCanceledException())
else
sc ())
|> ignore)

[<RequireQualifiedAccess>]
Expand All @@ -270,4 +274,4 @@ module internal Utils =
let getHomePath () =
match Environment.OSVersion.Platform with
| PlatformID.Unix | PlatformID.MacOSX -> Environment.GetEnvironmentVariable "HOME"
| _ -> Environment.ExpandEnvironmentVariables "%HOMEDRIVE%%HOMEPATH%"
| _ -> Environment.ExpandEnvironmentVariables "%HOMEDRIVE%%HOMEPATH%"
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
</ItemGroup>
<ItemGroup>
<PackageReference Include="FsCheck" Version="2.16.4" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.1.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.2.0" />
<PackageReference Include="xunit" Version="2.4.1" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.3" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.5" />
</ItemGroup>
<Import Project="..\..\.paket\Paket.Restore.targets" />
</Project>

0 comments on commit 4c52d74

Please sign in to comment.