diff --git a/binding/Binding/HandleDictionary.cs b/binding/Binding/HandleDictionary.cs new file mode 100644 index 0000000000..274e4100ad --- /dev/null +++ b/binding/Binding/HandleDictionary.cs @@ -0,0 +1,215 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Reflection; +using System.Threading; + +namespace SkiaSharp +{ + internal static class HandleDictionary + { + private static readonly Type IntPtrType = typeof (IntPtr); + private static readonly Type BoolType = typeof (bool); + +#if THROW_OBJECT_EXCEPTIONS + internal static readonly ConcurrentBag exceptions = new ConcurrentBag (); +#endif + + internal static readonly ConcurrentDictionary constructors = new ConcurrentDictionary (); + internal static readonly Dictionary instances = new Dictionary (); + + internal static readonly ReaderWriterLockSlim instancesLock = new ReaderWriterLockSlim (); + + /// + /// Retrieve the living instance if there is one, or null if not. + /// + /// The instance if it is alive, or null if there is none. + internal static bool GetInstance (IntPtr handle, out TSkiaObject instance) + where TSkiaObject : SKObject + { + if (handle == IntPtr.Zero) { + instance = null; + return false; + } + + instancesLock.EnterReadLock (); + try { + return GetInstanceNoLocks (handle, out instance); + } finally { + instancesLock.ExitReadLock (); + } + } + + /// + /// Retrieve or create an instance for the native handle. + /// + /// The instance, or null if the handle was null. + internal static TSkiaObject GetObject (IntPtr handle, bool owns = true, bool unrefExisting = true, bool refNew = false) + where TSkiaObject : SKObject + where TSkiaImplementation : SKObject, TSkiaObject + { + if (handle == IntPtr.Zero) + return null; + + instancesLock.EnterUpgradeableReadLock (); + try { + if (GetInstanceNoLocks (handle, out var instance)) { + // some object get automatically referenced on the native side, + // but managed code just has the same reference + if (unrefExisting && instance is ISKReferenceCounted refcnt) { +#if THROW_OBJECT_EXCEPTIONS + if (refcnt.GetReferenceCount () == 1) + throw new InvalidOperationException ( + $"About to unreference an object that has no references. " + + $"H: {handle.ToString ("x")} Type: {instance.GetType ()}"); +#endif + refcnt.SafeUnRef (); + } + + return instance; + } + + var type = typeof (TSkiaImplementation); + var constructor = constructors.GetOrAdd (type, t => GetConstructor (t)); + + // we don't need to go into a writable here as the object will do it in the Handle property + var obj = (TSkiaObject)constructor.Invoke (new object[] { handle, owns }); + if (refNew && obj is ISKReferenceCounted toRef) + toRef.SafeRef (); + return obj; + } finally { + instancesLock.ExitUpgradeableReadLock (); + } + + static ConstructorInfo GetConstructor (Type type) + { + var ctors = type.GetTypeInfo ().DeclaredConstructors; + + foreach (var ctor in ctors) { + var param = ctor.GetParameters (); + if (param.Length == 2 && param[0].ParameterType == IntPtrType && param[1].ParameterType == BoolType) + return ctor; + } + + throw new MissingMethodException ($"No constructor found for {type.FullName}.ctor(System.IntPtr, System.Boolean)"); + } + } + + /// + /// Retrieve the living instance if there is one, or null if not. This does not use locks. + /// + /// The instance if it is alive, or null if there is none. + private static bool GetInstanceNoLocks (IntPtr handle, out TSkiaObject instance) + where TSkiaObject : SKObject + { + if (instances.TryGetValue (handle, out var weak) && weak.IsAlive) { + if (weak.Target is TSkiaObject match) { + if (!match.IsDisposed) { + instance = match; + return true; + } +#if THROW_OBJECT_EXCEPTIONS + } else if (weak.Target is SKObject obj) { + if (!obj.IsDisposed && obj.OwnsHandle) { + throw new InvalidOperationException ( + $"A managed object exists for the handle, but is not the expected type. " + + $"H: {handle.ToString ("x")} Type: ({obj.GetType ()}, {typeof (TSkiaObject)})"); + } + } else if (weak.Target is object o) { + throw new InvalidOperationException ( + $"An unknown object exists for the handle when trying to fetch an instance. " + + $"H: {handle.ToString ("x")} Type: ({o.GetType ()}, {typeof (TSkiaObject)})"); +#endif + } + } + + instance = null; + return false; + } + + /// + /// Registers the specified instance with the dictionary. + /// + internal static void RegisterHandle (IntPtr handle, SKObject instance) + { + if (handle == IntPtr.Zero || instance == null) + return; + + SKObject objectToDispose = null; + + instancesLock.EnterWriteLock (); + try { + if (instances.TryGetValue (handle, out var oldValue) && oldValue.Target is SKObject obj && !obj.IsDisposed) { +#if THROW_OBJECT_EXCEPTIONS + if (obj.OwnsHandle) { + // a mostly recoverable error + // if there is a managed object, then maybe something happened and the native object is dead + throw new InvalidOperationException ( + $"A managed object already exists for the specified native object. " + + $"H: {handle.ToString ("x")} Type: ({obj.GetType ()}, {instance.GetType ()})"); + } +#endif + // this means the ownership was handed off to a native object, so clean up the managed side + objectToDispose = obj; + } + + instances[handle] = new WeakReference (instance); + } finally { + instancesLock.ExitWriteLock (); + } + + // dispose the object we just replaced + objectToDispose?.DisposeInternal (); + } + + /// + /// Removes the registered instance from the dictionary. + /// + internal static void DeregisterHandle (IntPtr handle, SKObject instance) + { + if (handle == IntPtr.Zero) + return; + + instancesLock.EnterWriteLock (); + try { + var existed = instances.TryGetValue (handle, out var weak); + if (existed && (!weak.IsAlive || weak.Target == instance)) { + instances.Remove (handle); + } else { +#if THROW_OBJECT_EXCEPTIONS + InvalidOperationException ex = null; + if (!existed) { + // the object may have been replaced + + if (!instance.IsDisposed) { + // recoverable error + // there was no object there, but we are still alive + ex = new InvalidOperationException ( + $"A managed object did not exist for the specified native object. " + + $"H: {handle.ToString ("x")} Type: {instance.GetType ()}"); + } + } else if (weak.Target is SKObject o && o != instance) { + // there was an object in the dictionary, but it was NOT this object + + if (!instance.IsDisposed) { + // recoverable error + // there was a new living object there, but we are still alive + ex = new InvalidOperationException ( + $"Trying to remove a different object with the same native handle. " + + $"H: {handle.ToString ("x")} Type: ({o.GetType ()}, {instance.GetType ()})"); + } + } + if (ex != null) { + if (instance.fromFinalizer) + exceptions.Add (ex); + else + throw ex; + } +#endif + } + } finally { + instancesLock.ExitWriteLock (); + } + } + } +} diff --git a/binding/Binding/SKBitmap.cs b/binding/Binding/SKBitmap.cs index 0627e16f6c..4832df501f 100644 --- a/binding/Binding/SKBitmap.cs +++ b/binding/Binding/SKBitmap.cs @@ -718,7 +718,10 @@ public bool PeekPixels (SKPixmap pixmap) if (pixmap == null) { throw new ArgumentNullException (nameof (pixmap)); } - return SkiaApi.sk_bitmap_peek_pixels (Handle, pixmap.Handle); + var result = SkiaApi.sk_bitmap_peek_pixels (Handle, pixmap.Handle); + if (result) + pixmap.pixelSource = this; + return result; } // Resize diff --git a/binding/Binding/SKColorSpace.cs b/binding/Binding/SKColorSpace.cs index 5942a55d9c..f61407ee48 100644 --- a/binding/Binding/SKColorSpace.cs +++ b/binding/Binding/SKColorSpace.cs @@ -164,7 +164,7 @@ public bool GetNumericalTransferFunction (out SKColorSpaceTransferFn fn) // *XyzD50 public SKMatrix44 ToXyzD50 () => - GetObject (SkiaApi.sk_colorspace_as_to_xyzd50 (Handle), false); + OwnedBy (GetObject (SkiaApi.sk_colorspace_as_to_xyzd50 (Handle), false), this); public bool ToXyzD50 (SKMatrix44 toXyzD50) { @@ -175,7 +175,7 @@ public bool ToXyzD50 (SKMatrix44 toXyzD50) } public SKMatrix44 FromXyzD50 () => - GetObject (SkiaApi.sk_colorspace_as_from_xyzd50 (Handle), false); + OwnedBy (GetObject (SkiaApi.sk_colorspace_as_from_xyzd50 (Handle), false), this); // diff --git a/binding/Binding/SKData.cs b/binding/Binding/SKData.cs index 422de86d64..395b39eb6b 100644 --- a/binding/Binding/SKData.cs +++ b/binding/Binding/SKData.cs @@ -251,6 +251,7 @@ public void SaveTo (Stream target) } finally { ArrayPool.Shared.Return (buffer); } + GC.KeepAlive (this); } // diff --git a/binding/Binding/SKDocument.cs b/binding/Binding/SKDocument.cs index 6dd2813ea6..ed0233d7a4 100644 --- a/binding/Binding/SKDocument.cs +++ b/binding/Binding/SKDocument.cs @@ -21,10 +21,10 @@ public void Abort () => SkiaApi.sk_document_abort (Handle); public SKCanvas BeginPage (float width, float height) => - GetObject (SkiaApi.sk_document_begin_page (Handle, width, height, null), false); + OwnedBy (GetObject (SkiaApi.sk_document_begin_page (Handle, width, height, null), false), this); public SKCanvas BeginPage (float width, float height, SKRect content) => - GetObject (SkiaApi.sk_document_begin_page (Handle, width, height, &content), false); + OwnedBy (GetObject (SkiaApi.sk_document_begin_page (Handle, width, height, &content), false), this); public void EndPage () => SkiaApi.sk_document_end_page (Handle); diff --git a/binding/Binding/SKFontStyle.cs b/binding/Binding/SKFontStyle.cs index 8b553d3685..d4b278a15f 100644 --- a/binding/Binding/SKFontStyle.cs +++ b/binding/Binding/SKFontStyle.cs @@ -4,6 +4,19 @@ namespace SkiaSharp { public class SKFontStyle : SKObject { + private static readonly Lazy normal; + private static readonly Lazy bold; + private static readonly Lazy italic; + private static readonly Lazy boldItalic; + + static SKFontStyle() + { + normal = new Lazy (() => new SKFontStyleStatic (SKFontStyleWeight.Normal, SKFontStyleWidth.Normal, SKFontStyleSlant.Upright)); + bold = new Lazy (() => new SKFontStyleStatic (SKFontStyleWeight.Bold, SKFontStyleWidth.Normal, SKFontStyleSlant.Upright)); + italic = new Lazy (() => new SKFontStyleStatic (SKFontStyleWeight.Normal, SKFontStyleWidth.Normal, SKFontStyleSlant.Italic)); + boldItalic = new Lazy (() => new SKFontStyleStatic (SKFontStyleWeight.Bold, SKFontStyleWidth.Normal, SKFontStyleSlant.Italic)); + } + [Preserve] internal SKFontStyle (IntPtr handle, bool owns) : base (handle, owns) @@ -37,12 +50,26 @@ protected override void DisposeNative () => public SKFontStyleSlant Slant => SkiaApi.sk_fontstyle_get_slant (Handle); - public static SKFontStyle Normal => new SKFontStyle (SKFontStyleWeight.Normal, SKFontStyleWidth.Normal, SKFontStyleSlant.Upright); + public static SKFontStyle Normal => normal.Value; - public static SKFontStyle Bold => new SKFontStyle (SKFontStyleWeight.Bold, SKFontStyleWidth.Normal, SKFontStyleSlant.Upright); + public static SKFontStyle Bold => bold.Value; - public static SKFontStyle Italic => new SKFontStyle (SKFontStyleWeight.Normal, SKFontStyleWidth.Normal, SKFontStyleSlant.Italic); + public static SKFontStyle Italic => italic.Value; - public static SKFontStyle BoldItalic => new SKFontStyle (SKFontStyleWeight.Bold, SKFontStyleWidth.Normal, SKFontStyleSlant.Italic); + public static SKFontStyle BoldItalic => boldItalic.Value; + + private sealed class SKFontStyleStatic : SKFontStyle + { + internal SKFontStyleStatic (SKFontStyleWeight weight, SKFontStyleWidth width, SKFontStyleSlant slant) + : base (weight, width, slant) + { + IgnorePublicDispose = true; + } + + protected override void Dispose (bool disposing) + { + // do not dispose + } + } } } diff --git a/binding/Binding/SKImage.cs b/binding/Binding/SKImage.cs index 3103d720c2..8cb70c11d2 100644 --- a/binding/Binding/SKImage.cs +++ b/binding/Binding/SKImage.cs @@ -529,7 +529,11 @@ public bool PeekPixels (SKPixmap pixmap) { if (pixmap == null) throw new ArgumentNullException (nameof (pixmap)); - return SkiaApi.sk_image_peek_pixels (Handle, pixmap.Handle); + + var result = SkiaApi.sk_image_peek_pixels (Handle, pixmap.Handle); + if (result) + pixmap.pixelSource = this; + return result; } public SKPixmap PeekPixels () diff --git a/binding/Binding/SKObject.cs b/binding/Binding/SKObject.cs index 7eb4048d93..8f7a01a1c6 100644 --- a/binding/Binding/SKObject.cs +++ b/binding/Binding/SKObject.cs @@ -1,7 +1,5 @@ using System; using System.Collections.Concurrent; -using System.Linq; -using System.Reflection; using System.Runtime.InteropServices; using System.Threading; @@ -9,21 +7,35 @@ namespace SkiaSharp { public abstract class SKObject : SKNativeObject { -#if THROW_OBJECT_EXCEPTIONS - internal static readonly ConcurrentBag exceptions = new ConcurrentBag (); -#endif + private readonly object locker = new object (); - internal static readonly ConcurrentDictionary constructors; - internal static readonly ConcurrentDictionary instances; + private ConcurrentDictionary ownedObjects; + private ConcurrentDictionary keepAliveObjects; - internal readonly ConcurrentDictionary ownedObjects = new ConcurrentDictionary (); - internal readonly ConcurrentDictionary keepAliveObjects = new ConcurrentDictionary (); + internal ConcurrentDictionary OwnedObjects { + get { + if (ownedObjects == null) { + lock (locker) { + ownedObjects ??= new ConcurrentDictionary (); + } + } + return ownedObjects; + } + } + + internal ConcurrentDictionary KeepAliveObjects { + get { + if (keepAliveObjects == null) { + lock (locker) { + keepAliveObjects ??= new ConcurrentDictionary (); + } + } + return keepAliveObjects; + } + } static SKObject () { - constructors = new ConcurrentDictionary (); - instances = new ConcurrentDictionary (); - SKColorSpace.EnsureStaticInstanceAreInitialized (); SKData.EnsureStaticInstanceAreInitialized (); SKFontManager.EnsureStaticInstanceAreInitialized (); @@ -54,11 +66,13 @@ protected set { protected override void DisposeManaged () { - foreach (var child in ownedObjects) { - child.Value.DisposeInternal (); + if (ownedObjects is ConcurrentDictionary dic) { + foreach (var child in dic) { + child.Value.DisposeInternal (); + } + dic.Clear (); } - ownedObjects.Clear (); - keepAliveObjects.Clear (); + KeepAliveObjects?.Clear (); } protected override void DisposeNative () @@ -70,7 +84,10 @@ protected override void DisposeNative () internal static TSkiaObject GetObject (IntPtr handle, bool owns = true, bool unrefExisting = true, bool refNew = false) where TSkiaObject : SKObject { - return GetObject (handle, owns, unrefExisting, refNew); + if (handle == IntPtr.Zero) + return null; + + return HandleDictionary.GetObject (handle, owns, unrefExisting, refNew); } internal static TSkiaObject GetObject (IntPtr handle, bool owns = true, bool unrefExisting = true, bool refNew = false) @@ -80,38 +97,7 @@ internal static TSkiaObject GetObject (IntPtr if (handle == IntPtr.Zero) return null; - if (GetInstance (handle, out var instance)) { - // some object get automatically referenced on the native side, - // but managed code just has the same reference - if (unrefExisting && instance is ISKReferenceCounted refcnt) { -#if THROW_OBJECT_EXCEPTIONS - if (refcnt.GetReferenceCount () == 1) - throw new InvalidOperationException ( - $"About to unreference an object that has no references. " + - $"H: {handle.ToString ("x")} Type: {instance.GetType ()}"); -#endif - refcnt.SafeUnRef (); - } - - return instance; - } - - var type = typeof (TSkiaImplementation); - var constructor = constructors.GetOrAdd (type, t => { - var ctor = type.GetTypeInfo ().DeclaredConstructors.FirstOrDefault (c => { - var parameters = c.GetParameters (); - return - parameters.Length == 2 && - parameters[0].ParameterType == typeof (IntPtr) && - parameters[1].ParameterType == typeof (bool); - }); - return ctor ?? throw new MissingMethodException ($"No constructor found for {type.FullName}.ctor(System.IntPtr, System.Boolean)"); - }); - - var obj = (TSkiaObject)constructor.Invoke (new object[] { handle, owns }); - if (refNew && obj is ISKReferenceCounted toRef) - toRef.SafeRef (); - return obj; + return HandleDictionary.GetObject (handle, owns, unrefExisting, refNew); } internal static void RegisterHandle (IntPtr handle, SKObject instance) @@ -119,24 +105,7 @@ internal static void RegisterHandle (IntPtr handle, SKObject instance) if (handle == IntPtr.Zero || instance == null) return; - var weak = new WeakReference (instance); - instances.AddOrUpdate (handle, weak, Update); - - WeakReference Update (IntPtr key, WeakReference oldValue) - { - if (oldValue.Target is SKObject obj && !obj.IsDisposed) { -#if THROW_OBJECT_EXCEPTIONS - if (obj.OwnsHandle) - throw new InvalidOperationException ( - $"A managed object already exists for the specified native object. " + - $"H: {handle.ToString ("x")} Type: ({obj.GetType ()}, {instance.GetType ()})"); -#endif - - obj.DisposeInternal (); - } - - return weak; - } + HandleDictionary.RegisterHandle (handle, instance); } internal static void DeregisterHandle (IntPtr handle, SKObject instance) @@ -144,60 +113,18 @@ internal static void DeregisterHandle (IntPtr handle, SKObject instance) if (handle == IntPtr.Zero) return; - var existed = instances.TryRemove (handle, out var weak); - if (existed && weak.Target is SKObject obj && !obj.IsDisposed) { - // Existing object is not disposed, so re-register it - RegisterHandle (handle, obj); - } - -#if THROW_OBJECT_EXCEPTIONS - InvalidOperationException ex = null; - if (!existed) { - // there was no object for the handle - ex = new InvalidOperationException ( - $"A managed object did not exist for the specified native object. " + - $"H: {handle.ToString ("x")} Type: {instance.GetType ()}"); - } else if (weak.Target is SKObject o && o != instance && !instance.IsDisposed) { - // there was a new living object there, but we are still alive - ex = new InvalidOperationException ( - $"Trying to remove a different object with the same native handle. " + - $"H: {handle.ToString ("x")} Type: ({o.GetType ()}, {instance.GetType ()})"); - } - if (ex != null) { - if (instance.fromFinalizer) - exceptions.Add (ex); - else - throw ex; - } -#endif + HandleDictionary.DeregisterHandle (handle, instance); } internal static bool GetInstance (IntPtr handle, out TSkiaObject instance) where TSkiaObject : SKObject { - if (instances.TryGetValue (handle, out var weak)) { - if (weak.Target is TSkiaObject match) { - if (!match.IsDisposed) { - instance = match; - return true; - } -#if THROW_OBJECT_EXCEPTIONS - } else if (weak.Target is SKObject obj) { - if (!obj.IsDisposed) { - throw new InvalidOperationException ( - $"A managed object exists for the handle, but is not the expected type. " + - $"H: {handle.ToString ("x")} Type: ({obj.GetType ()}, {typeof (TSkiaObject)})"); - } - } else if (weak.Target is object o) { - throw new InvalidOperationException ( - $"An unknown object exists for the handle when trying to fetch an instance. " + - $"H: {handle.ToString ("x")} Type: ({o.GetType ()}, {typeof (TSkiaObject)})"); -#endif - } + if (handle == IntPtr.Zero) { + instance = null; + return false; } - instance = null; - return false; + return HandleDictionary.GetInstance (handle, out instance); } // indicate that the ownership of this object is now in the hands of @@ -210,7 +137,19 @@ internal void RevokeOwnership (SKObject newOwner) if (newOwner == null) DisposeInternal (); else - newOwner.ownedObjects[Handle] = this; + newOwner.OwnedObjects[Handle] = this; + } + + // indicate that the child is controlled by the native code and + // the managed wrapper should be disposed when the owner is + internal static T OwnedBy (T child, SKObject owner) + where T : SKObject + { + if (child != null) { + owner.OwnedObjects[child.Handle] = child; + } + + return child; } // indicate that the child was created by the managed code and @@ -220,7 +159,7 @@ internal static T Owned (T owner, SKObject child) { if (child != null) { if (owner != null) - owner.ownedObjects[child.Handle] = child; + owner.OwnedObjects[child.Handle] = child; else child.Dispose (); } @@ -234,7 +173,7 @@ internal static T Referenced (T owner, SKObject child) where T : SKObject { if (child != null && owner != null) - owner.keepAliveObjects[child.Handle] = child; + owner.KeepAliveObjects[child.Handle] = child; return owner; } @@ -274,9 +213,7 @@ internal static T PtrToStructure (IntPtr intPtr, int index) public abstract class SKNativeObject : IDisposable { -#if THROW_OBJECT_EXCEPTIONS internal bool fromFinalizer = false; -#endif private int isDisposed = 0; @@ -293,9 +230,7 @@ internal SKNativeObject (IntPtr handle, bool ownsHandle) ~SKNativeObject () { -#if THROW_OBJECT_EXCEPTIONS fromFinalizer = true; -#endif Dispose (false); } @@ -340,7 +275,7 @@ public void Dispose () DisposeInternal (); } - protected void DisposeInternal () + protected internal void DisposeInternal () { Dispose (true); GC.SuppressFinalize (this); diff --git a/binding/Binding/SKPictureRecorder.cs b/binding/Binding/SKPictureRecorder.cs index f2f3662633..794896965b 100644 --- a/binding/Binding/SKPictureRecorder.cs +++ b/binding/Binding/SKPictureRecorder.cs @@ -26,7 +26,7 @@ protected override void DisposeNative () => public SKCanvas BeginRecording (SKRect cullRect) { - return GetObject (SkiaApi.sk_picture_recorder_begin_recording (Handle, &cullRect), false); + return OwnedBy (GetObject (SkiaApi.sk_picture_recorder_begin_recording (Handle, &cullRect), false), this); } public SKPicture EndRecording () @@ -39,6 +39,7 @@ public SKDrawable EndRecordingAsDrawable () return GetObject (SkiaApi.sk_picture_recorder_end_recording_as_drawable (Handle)); } - public SKCanvas RecordingCanvas => GetObject (SkiaApi.sk_picture_get_recording_canvas (Handle), false); + public SKCanvas RecordingCanvas => + OwnedBy (GetObject (SkiaApi.sk_picture_get_recording_canvas (Handle), false), this); } } diff --git a/binding/Binding/SKPixmap.cs b/binding/Binding/SKPixmap.cs index a8cdc1cf5d..8e4810b6db 100644 --- a/binding/Binding/SKPixmap.cs +++ b/binding/Binding/SKPixmap.cs @@ -8,6 +8,9 @@ public unsafe class SKPixmap : SKObject { private const string UnableToCreateInstanceMessage = "Unable to create a new SKPixmap instance."; + // this is not meant to be anything but a GC reference to keep the actual pixel data alive + internal SKObject pixelSource; + [Preserve] internal SKPixmap (IntPtr handle, bool owns) : base (handle, owns) @@ -50,11 +53,19 @@ protected override void Dispose (bool disposing) => protected override void DisposeNative () => SkiaApi.sk_pixmap_destructor (Handle); + protected override void DisposeManaged () + { + base.DisposeManaged (); + + pixelSource = null; + } + // Reset public void Reset () { SkiaApi.sk_pixmap_reset (Handle); + pixelSource = null; } [EditorBrowsable (EditorBrowsableState.Never)] @@ -68,6 +79,7 @@ public void Reset (SKImageInfo info, IntPtr addr, int rowBytes) { var cinfo = SKImageInfoNative.FromManaged (ref info); SkiaApi.sk_pixmap_reset_with_params (Handle, &cinfo, (void*)addr, (IntPtr)rowBytes); + pixelSource = null; } // properties diff --git a/binding/Binding/SKSurface.cs b/binding/Binding/SKSurface.cs index 4604e9f1e4..ccc04d1108 100644 --- a/binding/Binding/SKSurface.cs +++ b/binding/Binding/SKSurface.cs @@ -293,7 +293,7 @@ public static SKSurface CreateNull (int width, int height) => // public SKCanvas Canvas => - GetObject (SkiaApi.sk_surface_get_canvas (Handle), false, unrefExisting: false); + OwnedBy (GetObject (SkiaApi.sk_surface_get_canvas (Handle), false, unrefExisting: false), this); [EditorBrowsable (EditorBrowsableState.Never)] [Obsolete ("Use SurfaceProperties instead.")] @@ -308,7 +308,7 @@ public SKSurfaceProps SurfaceProps { } public SKSurfaceProperties SurfaceProperties => - GetObject (SkiaApi.sk_surface_get_props (Handle), false); + OwnedBy (GetObject (SkiaApi.sk_surface_get_props (Handle), false), this); public SKImage Snapshot () => GetObject (SkiaApi.sk_surface_new_image_snapshot (Handle)); @@ -338,7 +338,10 @@ public bool PeekPixels (SKPixmap pixmap) if (pixmap == null) throw new ArgumentNullException (nameof (pixmap)); - return SkiaApi.sk_surface_peek_pixels (Handle, pixmap.Handle); + var result = SkiaApi.sk_surface_peek_pixels (Handle, pixmap.Handle); + if (result) + pixmap.pixelSource = this; + return result; } public bool ReadPixels (SKImageInfo dstInfo, IntPtr dstPixels, int dstRowBytes, int srcX, int srcY) diff --git a/scripts/azure-pipelines.yml b/scripts/azure-pipelines.yml index e72ff27eac..c7a77e1c49 100644 --- a/scripts/azure-pipelines.yml +++ b/scripts/azure-pipelines.yml @@ -391,7 +391,6 @@ stages: name: tests_windows displayName: Tests (Windows) vmImage: $(VM_IMAGE_WINDOWS) - retryCount: 3 target: tests additionalArgs: --skipExternals="all" shouldPublish: false @@ -419,7 +418,6 @@ stages: name: tests_macos displayName: Tests (macOS) vmImage: $(VM_IMAGE_MAC) - retryCount: 3 target: tests additionalArgs: --skipExternals="all" shouldPublish: false @@ -446,7 +444,6 @@ stages: displayName: Tests (Linux) vmImage: $(VM_IMAGE_LINUX) packages: $(MANAGED_LINUX_PACKAGES) - retryCount: 3 target: tests additionalArgs: --skipExternals="all" shouldPublish: false diff --git a/tests/SkiaSharp.Desktop.Tests/SkiaSharp.Desktop.Tests.csproj b/tests/SkiaSharp.Desktop.Tests/SkiaSharp.Desktop.Tests.csproj index 8653d437de..893d701f59 100644 --- a/tests/SkiaSharp.Desktop.Tests/SkiaSharp.Desktop.Tests.csproj +++ b/tests/SkiaSharp.Desktop.Tests/SkiaSharp.Desktop.Tests.csproj @@ -75,7 +75,6 @@ - diff --git a/tests/SkiaSharp.NetCore.Tests/SkiaSharp.NetCore.Tests.csproj b/tests/SkiaSharp.NetCore.Tests/SkiaSharp.NetCore.Tests.csproj index c6e43d1e51..2bcbba8524 100644 --- a/tests/SkiaSharp.NetCore.Tests/SkiaSharp.NetCore.Tests.csproj +++ b/tests/SkiaSharp.NetCore.Tests/SkiaSharp.NetCore.Tests.csproj @@ -12,7 +12,6 @@ - @@ -20,10 +19,6 @@ - - - - diff --git a/tests/Tests/GarbageCleanupFixture.cs b/tests/Tests/GarbageCleanupFixture.cs index 4356d1707a..d4ab479a7d 100644 --- a/tests/Tests/GarbageCleanupFixture.cs +++ b/tests/Tests/GarbageCleanupFixture.cs @@ -1,8 +1,11 @@ using System; -using System.Collections.Concurrent; +using System.Collections.Generic; using System.Linq; +using SkiaSharp.Tests; using Xunit; +[assembly: AssemblyFixture(typeof(GarbageCleanupFixture))] + namespace SkiaSharp.Tests { public class GarbageCleanupFixture : IDisposable @@ -10,16 +13,17 @@ public class GarbageCleanupFixture : IDisposable private static readonly string[] StaticTypes = new[] { "SkiaSharp.SKData+SKDataStatic", "SkiaSharp.SKFontManager+SKFontManagerStatic", + "SkiaSharp.SKFontStyle+SKFontStyleStatic", "SkiaSharp.SKTypeface+SKTypefaceStatic", "SkiaSharp.SKColorSpace+SKColorSpaceStatic", }; public GarbageCleanupFixture() { - Assert.Empty(SKObject.constructors); - var aliveObjects = SKObject.instances.Values + Assert.Empty(HandleDictionary.constructors); + var aliveObjects = HandleDictionary.instances.Values .Select(o => o.Target) - .Where(IsExpectedToBeDead) + .Where(o => IsExpectedToBeDead(o, null)) .ToList(); Assert.Empty(aliveObjects); } @@ -29,16 +33,28 @@ public void Dispose() GC.Collect(); GC.WaitForPendingFinalizers(); + var staticObjects = HandleDictionary.instances.Values + .Select(o => o.Target) + .Where(o => !IsExpectedToBeDead(o, null)) + .Cast() + .ToList(); + var staticChildren = staticObjects + .SelectMany(o => o.OwnedObjects.Values) + .ToList(); + // make sure nothing is alive - var aliveObjects = SKObject.instances.Values + var aliveObjects = HandleDictionary.instances.Values .Select(o => o.Target) - .Where(IsExpectedToBeDead) + .Where(o => IsExpectedToBeDead(o, staticChildren)) + .Cast() .ToList(); + foreach (var o in staticChildren) + aliveObjects.Remove(o); Assert.Empty(aliveObjects); #if THROW_OBJECT_EXCEPTIONS // make sure all the exceptions are accounted for - var exceptions = SKObject.exceptions + var exceptions = HandleDictionary.exceptions .ToList(); Assert.Empty(exceptions); @@ -50,7 +66,7 @@ public void Dispose() #endif } - private bool IsExpectedToBeDead(object instance) + private bool IsExpectedToBeDead(object instance, IEnumerable exceptions) { if (instance == null) return false; diff --git a/tests/Tests/Properties/AssemblyInfo.cs b/tests/Tests/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..3080802efa --- /dev/null +++ b/tests/Tests/Properties/AssemblyInfo.cs @@ -0,0 +1,4 @@ +using SkiaSharp.Tests; +using Xunit; + +[assembly: TestFramework("SkiaSharp.Tests." + nameof(CustomTestFramework), "SkiaSharp.Tests")] diff --git a/tests/Tests/SKBitmapTest.cs b/tests/Tests/SKBitmapTest.cs index 75883fdde3..ff8047504f 100644 --- a/tests/Tests/SKBitmapTest.cs +++ b/tests/Tests/SKBitmapTest.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Runtime.InteropServices; @@ -494,50 +495,67 @@ public void AlphaMaskIsApplied() mask.FreeImage(); } - [Obsolete] - [SkippableFact(Skip = "This test takes a long time (~3mins), so ignore this most of the time.")] - public static void ImageScalingMultipleThreadsTest() + [SkippableTheory] + [InlineData(100, 1000)] + public static void ImageScalingMultipleThreadsTest(int numThreads, int numIterationsPerThread) { - const int numThreads = 100; - const int numIterationsPerThread = 1000; - var referenceFile = Path.Combine(PathToImages, "baboon.jpg"); var tasks = new List(); + var complete = false; + var exceptions = new ConcurrentBag(); + for (int i = 0; i < numThreads; i++) { var task = Task.Run(() => { - for (int j = 0; j < numIterationsPerThread; j++) + try + { + for (int j = 0; j < numIterationsPerThread && exceptions.IsEmpty; j++) + { + var imageData = ComputeThumbnail(referenceFile); + Assert.NotEmpty(imageData); + } + } + catch (Exception ex) { - var imageData = ComputeThumbnail(referenceFile); + exceptions.Add(ex); } }); tasks.Add(task); } + Task.Run(async () => + { + while (!complete && exceptions.IsEmpty) + { + GC.Collect(); + await Task.Delay(500); + } + }); + Task.WaitAll(tasks.ToArray()); - Console.WriteLine($"Test completed for {numThreads} tasks, {numIterationsPerThread} each."); - } + complete = true; - [Obsolete] - private static byte[] ComputeThumbnail(string fileName) - { - using (var ms = new MemoryStream()) - using (var bitmap = SKBitmap.Decode(fileName)) - using (var scaledBitmap = new SKBitmap(60, 40, bitmap.ColorType, bitmap.AlphaType)) + if (!exceptions.IsEmpty) + throw new AggregateException(exceptions); + + static byte[] ComputeThumbnail(string fileName) { - SKBitmap.Resize(scaledBitmap, bitmap, SKBitmapResizeMethod.Hamming); + var ms = new MemoryStream(); + var bitmap = SKBitmap.Decode(fileName); + var scaledBitmap = new SKBitmap(60, 40, bitmap.ColorType, bitmap.AlphaType); - using (var image = SKImage.FromBitmap(scaledBitmap)) - using (var data = image.Encode(SKEncodedImageFormat.Png, 80)) - { - data.SaveTo(ms); + bitmap.ScalePixels(scaledBitmap, SKFilterQuality.High); - return ms.ToArray(); - } + var image = SKImage.FromBitmap(scaledBitmap); + var data = image.Encode(SKEncodedImageFormat.Png, 80); + + data.SaveTo(ms); + + return ms.ToArray(); } } diff --git a/tests/Tests/SKCodecTest.cs b/tests/Tests/SKCodecTest.cs index 4880c97042..91d64d6b95 100644 --- a/tests/Tests/SKCodecTest.cs +++ b/tests/Tests/SKCodecTest.cs @@ -393,5 +393,46 @@ public void ReadOnlyStream () using (var bitmap = SKBitmap.Decode (nonSeekable)) Assert.NotNull (bitmap); } + + [SkippableFact] + public void CanReadManagedStream() + { + using (var stream = File.OpenRead(Path.Combine(PathToImages, "baboon.png"))) + using (var codec = SKCodec.Create(stream)) + Assert.NotNull(codec); + } + + [SkippableTheory] + [InlineData("CMYK.jpg")] + [InlineData("baboon.png")] + [InlineData("color-wheel.png")] + public void CanDecodePath(string image) + { + var path = Path.Combine(PathToImages, image); + + using var codec = SKCodec.Create(path); + Assert.NotNull(codec); + + Assert.Equal(SKCodecResult.Success, codec.GetPixels(out var pixels)); + Assert.NotEmpty(pixels); + } + + [SkippableTheory] + [InlineData("CMYK.jpg")] + [InlineData("baboon.png")] + [InlineData("color-wheel.png")] + public void CanDecodeData(string image) + { + var path = Path.Combine(PathToImages, image); + + using var data = SKData.Create(path); + Assert.NotNull(data); + + using var codec = SKCodec.Create(data); + Assert.NotNull(codec); + + Assert.Equal(SKCodecResult.Success, codec.GetPixels(out var pixels)); + Assert.NotEmpty(pixels); + } } } diff --git a/tests/Tests/SKObjectTest.cs b/tests/Tests/SKObjectTest.cs index f2af08dff1..77a55315e0 100644 --- a/tests/Tests/SKObjectTest.cs +++ b/tests/Tests/SKObjectTest.cs @@ -1,5 +1,8 @@ using System; +using System.Collections.Concurrent; using System.IO; +using System.Linq; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -7,20 +10,25 @@ namespace SkiaSharp.Tests { public class SKObjectTest : SKTest { + private static int nextPtr = 1000; + + private static IntPtr GetNextPtr() => + (IntPtr)Interlocked.Increment(ref nextPtr); + [SkippableFact] public void ConstructorsAreCached() { - var handle = (IntPtr)123; + var handle = GetNextPtr(); SKObject.GetObject(handle); - Assert.True(SKObject.constructors.ContainsKey(typeof(LifecycleObject))); + Assert.True(HandleDictionary.constructors.ContainsKey(typeof(LifecycleObject))); } [SkippableFact] public void CanInstantiateAbstractClassesWithImplementation() { - var handle = (IntPtr)444; + var handle = GetNextPtr(); Assert.Throws(() => SKObject.GetObject(handle)); @@ -51,30 +59,35 @@ public void SameHandleReturnsSameReferenceAndReleasesObject() { VerifyImmediateFinalizers(); - var handle = (IntPtr)234; + var handle = GetNextPtr(); TestConstruction(handle); CollectGarbage(); + // there should be nothing if the GC ran Assert.False(SKObject.GetInstance(handle, out var inst)); Assert.Null(inst); - void TestConstruction(IntPtr h) + static void TestConstruction(IntPtr h) { - LifecycleObject i = null; - - Assert.False(SKObject.GetInstance(h, out i)); + // make sure there is nothing + Assert.False(SKObject.GetInstance(h, out LifecycleObject i)); Assert.Null(i); + // get/create the object var first = SKObject.GetObject(h); + // get the same one Assert.True(SKObject.GetInstance(h, out i)); Assert.NotNull(i); + // compare Assert.Same(first, i); + // get/create the object var second = SKObject.GetObject(h); + // compare Assert.Same(first, second); } } @@ -84,19 +97,22 @@ public void ObjectsWithTheSameHandleButDoNotOwnTheirHandlesAreCreatedAndCollecte { VerifyImmediateFinalizers(); - var handle = (IntPtr)566; + var handle = GetNextPtr(); - Construct(); + Construct(handle); CollectGarbage(); + // they should be gone Assert.False(SKObject.GetInstance(handle, out _)); - void Construct() + static void Construct(IntPtr handle) { + // create two objects with the same handle var inst1 = new LifecycleObject(handle, false); var inst2 = new LifecycleObject(handle, false); + // they should never be the same Assert.NotSame(inst1, inst2); } } @@ -104,22 +120,29 @@ void Construct() [SkippableFact] public void ObjectsWithTheSameHandleButDoNotOwnTheirHandlesAreCreatedAndDisposedCorrectly() { - var handle = (IntPtr)567; + var handle = GetNextPtr(); - var inst = Construct(); + var inst = Construct(handle); CollectGarbage(); + // the second object is still alive Assert.True(SKObject.GetInstance(handle, out var obj)); Assert.Equal(2, obj.Value); Assert.Same(inst, obj); - LifecycleObject Construct() + static LifecycleObject Construct(IntPtr handle) { + // create two objects var inst1 = new LifecycleObject(handle, false) { Value = 1 }; var inst2 = new LifecycleObject(handle, false) { Value = 2 }; + // make sure thy are different and the first is disposed Assert.NotSame(inst1, inst2); + Assert.True(inst1.DestroyedManaged); + + // because the object does not own the handle, the native is untouched + Assert.False(inst1.DestroyedNative); return inst2; } @@ -128,13 +151,13 @@ LifecycleObject Construct() [SkippableFact] public void ObjectsWithTheSameHandleAndOwnTheirHandlesThrowInDebugBuildsButNotRelease() { - var handle = (IntPtr)568; + var handle = GetNextPtr(); var inst1 = new LifecycleObject(handle, true) { Value = 1 }; #if THROW_OBJECT_EXCEPTIONS var ex = Assert.Throws(() => new LifecycleObject(handle, true) { Value = 2 }); - Assert.Contains("H: " + handle.ToString("x") + " ", ex.Message); + Assert.Contains($"H: {handle.ToString("x")} ", ex.Message); #else var inst2 = new LifecycleObject(handle, true) { Value = 2 }; Assert.True(inst1.DestroyedNative); @@ -147,7 +170,7 @@ public void ObjectsWithTheSameHandleAndOwnTheirHandlesThrowInDebugBuildsButNotRe [SkippableFact] public void DisposeInvalidatesObject() { - var handle = (IntPtr)345; + var handle = GetNextPtr(); var obj = SKObject.GetObject(handle); @@ -163,7 +186,7 @@ public void DisposeInvalidatesObject() [SkippableFact] public void DisposeDoesNotInvalidateObjectIfItIsNotOwned() { - var handle = (IntPtr)345; + var handle = GetNextPtr(); var obj = SKObject.GetObject(handle, false); @@ -267,19 +290,21 @@ public async Task EnsureMultithreadingDoesNotThrow(int iterations) [SkippableFact] public void EnsureConcurrencyResultsInCorrectDeregistration() { - var handle = (IntPtr)446; + var handle = GetNextPtr(); var obj = new ImmediateRecreationObject(handle, true); Assert.Null(obj.NewInstance); - Assert.Equal(obj, SKObject.instances[handle]?.Target); + Assert.Equal(obj, HandleDictionary.instances[handle]?.Target); obj.Dispose(); Assert.True(SKObject.GetInstance(handle, out _)); var newObj = obj.NewInstance; - Assert.NotEqual(obj, SKObject.instances[handle]?.Target); - Assert.Equal(newObj, SKObject.instances[handle]?.Target); + var weakReference = HandleDictionary.instances[handle]; + Assert.True(weakReference.IsAlive); + Assert.NotEqual(obj, weakReference.Target); + Assert.Equal(newObj, weakReference.Target); newObj.Dispose(); Assert.False(SKObject.GetInstance(handle, out _)); @@ -305,5 +330,163 @@ protected override void DisposeNative() NewInstance = new ImmediateRecreationObject(Handle, false); } } + + [SkippableFact] + public async Task DelayedConstructionDoesNotCreateInvalidState() + { + var handle = GetNextPtr(); + + DelayedConstructionObject objFast = null; + DelayedConstructionObject objSlow = null; + + var order = new ConcurrentQueue(); + + var objFastStart = new AutoResetEvent(false); + var objFastDelay = new AutoResetEvent(false); + + var fast = Task.Run(() => + { + order.Enqueue(1); + + DelayedConstructionObject.ConstructionStartedEvent = objFastStart; + DelayedConstructionObject.ConstructionDelayEvent = objFastDelay; + objFast = SKObject.GetObject(handle); + order.Enqueue(4); + }); + + var slow = Task.Run(() => + { + order.Enqueue(1); + + objFastStart.WaitOne(); + order.Enqueue(2); + + var timer = new Timer(state => objFastDelay.Set(), null, 1000, Timeout.Infinite); + order.Enqueue(3); + + objSlow = SKObject.GetObject(handle); + order.Enqueue(5); + + timer.Dispose(objFastDelay); + }); + + await Task.WhenAll(new[] { fast, slow }); + + // make sure it was the right order + Assert.Equal(new[] { 1, 1, 2, 3, 4, 5 }, order); + + // make sure both were "created" and they are the same object + Assert.NotNull(objFast); + Assert.NotNull(objSlow); + Assert.Same(objFast, objSlow); + } + + [SkippableFact] + public async Task DelayedDestructionDoesNotCreateInvalidState() + { + var handle = GetNextPtr(); + + DelayedDestructionObject objFast = null; + DelayedDestructionObject objSlow = null; + + using var secondThreadStarter = new AutoResetEvent(false); + + var order = new ConcurrentQueue(); + + var fast = Task.Run(() => + { + order.Enqueue(1); + + objFast = SKObject.GetObject(handle); + objFast.DisposeDelayEvent = new AutoResetEvent(false); + + Assert.True(SKObject.GetInstance(handle, out var beforeDispose)); + Assert.Same(objFast, beforeDispose); + + order.Enqueue(2); + // start thread 2 + secondThreadStarter.Set(); + + objFast.Dispose(); + order.Enqueue(7); + }); + + var slow = Task.Run(() => + { + // wait for thread 1 + secondThreadStarter.WaitOne(); + + order.Enqueue(3); + // wait for the disposal to start + objFast.DisposeStartedEvent.WaitOne(); + order.Enqueue(4); + + Assert.False(SKObject.GetInstance(handle, out var beforeCreate)); + Assert.Null(beforeCreate); + + var directRef = HandleDictionary.instances[handle]; + Assert.Same(objFast, directRef.Target); + + order.Enqueue(5); + objSlow = SKObject.GetObject(handle); + order.Enqueue(6); + + // finish the disposal + objFast.DisposeDelayEvent.Set(); + }); + + await Task.WhenAll(new[] { fast, slow }); + + // make sure it was the right order + Assert.Equal(new[] { 1, 2, 3, 4, 5, 6, 7 }, order); + + // make sure both were "created" and they are NOT the same object + Assert.NotNull(objFast); + Assert.NotNull(objSlow); + Assert.NotSame(objFast, objSlow); + Assert.True(SKObject.GetInstance(handle, out var final)); + Assert.Same(objSlow, final); + } + + private class DelayedConstructionObject : SKObject + { + public static AutoResetEvent ConstructionStartedEvent; + public static AutoResetEvent ConstructionDelayEvent; + + public DelayedConstructionObject(IntPtr handle, bool owns) + : base(GetHandle(handle), owns) + { + } + + private static IntPtr GetHandle(IntPtr handle) + { + var started = Interlocked.Exchange(ref ConstructionStartedEvent, null); + var delay = Interlocked.Exchange(ref ConstructionDelayEvent, null); + + started?.Set(); + delay?.WaitOne(); + + return handle; + } + } + + private class DelayedDestructionObject : SKObject + { + public AutoResetEvent DisposeStartedEvent = new AutoResetEvent(false); + public AutoResetEvent DisposeDelayEvent; + + public DelayedDestructionObject(IntPtr handle, bool owns) + : base(handle, owns) + { + } + + protected override void DisposeManaged() + { + DisposeStartedEvent.Set(); + DisposeDelayEvent?.WaitOne(); + + base.DisposeManaged(); + } + } } } diff --git a/tests/Tests/SKTest.cs b/tests/Tests/SKTest.cs index cf4a8e4e3a..e811e73cd9 100644 --- a/tests/Tests/SKTest.cs +++ b/tests/Tests/SKTest.cs @@ -5,7 +5,7 @@ namespace SkiaSharp.Tests { - public abstract class SKTest : BaseTest, IAssemblyFixture + public abstract class SKTest : BaseTest { protected const float EPSILON = 0.0001f; protected const int PRECISION = 4; diff --git a/tests/Tests/Xunit/AssemblyFixtureAttribute.cs b/tests/Tests/Xunit/AssemblyFixtureAttribute.cs new file mode 100644 index 0000000000..22224646fb --- /dev/null +++ b/tests/Tests/Xunit/AssemblyFixtureAttribute.cs @@ -0,0 +1,15 @@ +using System; + +namespace SkiaSharp.Tests +{ + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public class AssemblyFixtureAttribute : Attribute + { + public AssemblyFixtureAttribute(Type fixtureType) + { + FixtureType = fixtureType; + } + + public Type FixtureType { get; private set; } + } +} diff --git a/tests/Tests/Xunit/CustomTestFramework.cs b/tests/Tests/Xunit/CustomTestFramework.cs new file mode 100644 index 0000000000..7aed1b74a3 --- /dev/null +++ b/tests/Tests/Xunit/CustomTestFramework.cs @@ -0,0 +1,152 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace SkiaSharp.Tests +{ + public class CustomTestFramework : XunitTestFramework + { + public CustomTestFramework(IMessageSink messageSink) + : base(messageSink) + { + } + + protected override ITestFrameworkExecutor CreateExecutor(AssemblyName assemblyName) => + new Executor(assemblyName, SourceInformationProvider, DiagnosticMessageSink); + + public class Executor : XunitTestFrameworkExecutor + { + public Executor( + AssemblyName assemblyName, + ISourceInformationProvider sourceInformationProvider, + IMessageSink diagnosticMessageSink) + : base(assemblyName, sourceInformationProvider, diagnosticMessageSink) + { + } + + protected override async void RunTestCases( + IEnumerable testCases, + IMessageSink executionMessageSink, + ITestFrameworkExecutionOptions executionOptions) + { + using var assemblyRunner = new AssemblyRunner( + TestAssembly, + testCases, + DiagnosticMessageSink, + executionMessageSink, + executionOptions); + await assemblyRunner.RunAsync(); + } + } + + public class AssemblyRunner : XunitTestAssemblyRunner + { + private readonly Dictionary assemblyFixtureMappings = new Dictionary(); + + public AssemblyRunner( + ITestAssembly testAssembly, + IEnumerable testCases, + IMessageSink diagnosticMessageSink, + IMessageSink executionMessageSink, + ITestFrameworkExecutionOptions executionOptions) + : base(testAssembly, testCases, diagnosticMessageSink, executionMessageSink, executionOptions) + { + } + + protected override async Task AfterTestAssemblyStartingAsync() + { + // Let everything initialize + await base.AfterTestAssemblyStartingAsync(); + + // Go find all the AssemblyFixtureAttributes adorned on the test assembly + Aggregator.Run(() => + { + var fixturesAttrs = ((IReflectionAssemblyInfo)TestAssembly.Assembly) + .Assembly + .GetCustomAttributes(typeof(AssemblyFixtureAttribute), false) + .Cast() + .ToList(); + + // Instantiate all the fixtures + foreach (var fixtureAttr in fixturesAttrs) + assemblyFixtureMappings[fixtureAttr.FixtureType] = Activator.CreateInstance(fixtureAttr.FixtureType); + }); + } + + protected override Task BeforeTestAssemblyFinishedAsync() + { + // Make sure we clean up everybody who is disposable, and use Aggregator.Run to isolate Dispose failures + foreach (var disposable in assemblyFixtureMappings.Values.OfType()) + Aggregator.Run(disposable.Dispose); + + return base.BeforeTestAssemblyFinishedAsync(); + } + + protected override Task RunTestCollectionAsync( + IMessageBus messageBus, + ITestCollection testCollection, + IEnumerable testCases, + CancellationTokenSource cancellationTokenSource) + { + var fixture = new CollectionRunner( + assemblyFixtureMappings, + testCollection, + testCases, + DiagnosticMessageSink, + messageBus, + TestCaseOrderer, + new ExceptionAggregator(Aggregator), + cancellationTokenSource); + return fixture.RunAsync(); + } + } + + public class CollectionRunner : XunitTestCollectionRunner + { + private readonly Dictionary assemblyFixtureMappings; + private readonly IMessageSink diagnosticMessageSink; + + public CollectionRunner( + Dictionary assemblyFixtureMappings, + ITestCollection testCollection, + IEnumerable testCases, + IMessageSink diagnosticMessageSink, + IMessageBus messageBus, + ITestCaseOrderer testCaseOrderer, + ExceptionAggregator aggregator, + CancellationTokenSource cancellationTokenSource) + : base(testCollection, testCases, diagnosticMessageSink, messageBus, testCaseOrderer, aggregator, cancellationTokenSource) + { + this.assemblyFixtureMappings = assemblyFixtureMappings; + this.diagnosticMessageSink = diagnosticMessageSink; + } + + protected override Task RunTestClassAsync(ITestClass testClass, IReflectionTypeInfo @class, IEnumerable testCases) + { + // Don't want to use .Concat + .ToDictionary because of the possibility of overriding types, + // so instead we'll just let collection fixtures override assembly fixtures. + var combinedFixtures = new Dictionary(assemblyFixtureMappings); + foreach (var kvp in CollectionFixtureMappings) + combinedFixtures[kvp.Key] = kvp.Value; + + // We've done everything we need, so let the built-in types do the rest of the heavy lifting + var runner = new XunitTestClassRunner( + testClass, + @class, + testCases, + diagnosticMessageSink, + MessageBus, + TestCaseOrderer, + new ExceptionAggregator(Aggregator), + CancellationTokenSource, + combinedFixtures); + return runner.RunAsync(); + } + } + } +}