diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs index 13f93b99bfbaec..fbfc5724b1fd82 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs @@ -119,7 +119,7 @@ private ServiceCallSite TryCreateOpenGeneric(Type serviceType, CallSiteChain cal if (serviceType.IsConstructedGenericType && _descriptorLookup.TryGetValue(serviceType.GetGenericTypeDefinition(), out var descriptor)) { - return TryCreateOpenGeneric(descriptor.Last, serviceType, callSiteChain, DefaultSlot); + return TryCreateOpenGeneric(descriptor.Last, serviceType, callSiteChain, DefaultSlot, true); } return null; @@ -165,7 +165,7 @@ private ServiceCallSite TryCreateEnumerable(Type serviceType, CallSiteChain call { var descriptor = _descriptors[i]; var callSite = TryCreateExact(descriptor, itemType, callSiteChain, slot) ?? - TryCreateOpenGeneric(descriptor, itemType, callSiteChain, slot); + TryCreateOpenGeneric(descriptor, itemType, callSiteChain, slot, false); if (callSite != null) { @@ -231,14 +231,28 @@ private ServiceCallSite TryCreateExact(ServiceDescriptor descriptor, Type servic return null; } - private ServiceCallSite TryCreateOpenGeneric(ServiceDescriptor descriptor, Type serviceType, CallSiteChain callSiteChain, int slot) + private ServiceCallSite TryCreateOpenGeneric(ServiceDescriptor descriptor, Type serviceType, CallSiteChain callSiteChain, int slot, bool throwOnConstraintViolation) { if (serviceType.IsConstructedGenericType && serviceType.GetGenericTypeDefinition() == descriptor.ServiceType) { Debug.Assert(descriptor.ImplementationType != null, "descriptor.ImplementationType != null"); var lifetime = new ResultCache(descriptor.Lifetime, serviceType, slot); - var closedType = descriptor.ImplementationType.MakeGenericType(serviceType.GenericTypeArguments); + Type closedType; + try + { + closedType = descriptor.ImplementationType.MakeGenericType(serviceType.GenericTypeArguments); + } + catch (ArgumentException ex) + { + if (throwOnConstraintViolation) + { + throw new InvalidOperationException(Resources.FormatGenericConstraintViolation(serviceType, descriptor.ImplementationType), ex); + } + + return null; + } + return CreateConstructorCallSite(lifetime, serviceType, closedType, callSiteChain); } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/DependencyInjectionSpecificationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/DependencyInjectionSpecificationTests.cs index beed0c8661e4b8..4c717eae8e7289 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/DependencyInjectionSpecificationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/DependencyInjectionSpecificationTests.cs @@ -578,6 +578,160 @@ public void OpenGenericServicesCanBeResolved() Assert.Same(singletonService, genericService.Value); } + [Fact] + public void ConstrainedOpenGenericServicesCanBeResolved() + { + // Arrange + var collection = new TestServiceCollection(); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(FakeOpenGenericService<>)); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ConstrainedFakeOpenGenericService<>)); + var poco = new PocoClass(); + collection.AddSingleton(poco); + collection.AddSingleton(); + var provider = CreateServiceProvider(collection); + // Act + var allServices = provider.GetServices>().ToList(); + var constrainedServices = provider.GetServices>().ToList(); + var singletonService = provider.GetService(); + // Assert + Assert.Equal(2, allServices.Count); + Assert.Same(poco, allServices[0].Value); + Assert.Same(poco, allServices[1].Value); + Assert.Equal(1, constrainedServices.Count); + Assert.Same(singletonService, constrainedServices[0].Value); + } + + [Fact] + public void ConstrainedOpenGenericServicesReturnsEmptyWithNoMatches() + { + // Arrange + var collection = new TestServiceCollection(); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ConstrainedFakeOpenGenericService<>)); + collection.AddSingleton(); + var provider = CreateServiceProvider(collection); + // Act + var constrainedServices = provider.GetServices>().ToList(); + // Assert + Assert.Equal(0, constrainedServices.Count); + } + + [Fact] + public void InterfaceConstrainedOpenGenericServicesCanBeResolved() + { + // Arrange + var collection = new TestServiceCollection(); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(FakeOpenGenericService<>)); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithInterfaceConstraint<>)); + var enumerableVal = new ClassImplementingIEnumerable(); + collection.AddSingleton(enumerableVal); + collection.AddSingleton(); + var provider = CreateServiceProvider(collection); + // Act + var allServices = provider.GetServices>().ToList(); + var constrainedServices = provider.GetServices>().ToList(); + var singletonService = provider.GetService(); + // Assert + Assert.Equal(2, allServices.Count); + Assert.Same(enumerableVal, allServices[0].Value); + Assert.Same(enumerableVal, allServices[1].Value); + Assert.Equal(1, constrainedServices.Count); + Assert.Same(singletonService, constrainedServices[0].Value); + } + + [Fact] + public void PublicNoArgCtorConstrainedOpenGenericServicesCanBeResolved() + { + // Arrange + var collection = new TestServiceCollection(); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithNoConstraints<>)); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithNewConstraint<>)); + var provider = CreateServiceProvider(collection); + // Act + var allServices = provider.GetServices>().ToList(); + var constrainedServices = provider.GetServices>().ToList(); + // Assert + Assert.Equal(2, allServices.Count); + Assert.Equal(1, constrainedServices.Count); + } + + [Fact] + public void ClassConstrainedOpenGenericServicesCanBeResolved() + { + // Arrange + var collection = new TestServiceCollection(); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithNoConstraints<>)); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithClassConstraint<>)); + var provider = CreateServiceProvider(collection); + // Act + var allServices = provider.GetServices>().ToList(); + var constrainedServices = provider.GetServices>().ToList(); + // Assert + Assert.Equal(2, allServices.Count); + Assert.Equal(1, constrainedServices.Count); + } + + [Fact] + public void StructConstrainedOpenGenericServicesCanBeResolved() + { + // Arrange + var collection = new TestServiceCollection(); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithNoConstraints<>)); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithStructConstraint<>)); + var provider = CreateServiceProvider(collection); + // Act + var allServices = provider.GetServices>().ToList(); + var constrainedServices = provider.GetServices>().ToList(); + // Assert + Assert.Equal(2, allServices.Count); + Assert.Equal(1, constrainedServices.Count); + } + + [Fact] + public void AbstractClassConstrainedOpenGenericServicesCanBeResolved() + { + // Arrange + var collection = new TestServiceCollection(); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(FakeOpenGenericService<>)); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithAbstractClassConstraint<>)); + var poco = new PocoClass(); + collection.AddSingleton(poco); + var classInheritingClassInheritingAbstractClass = new ClassInheritingClassInheritingAbstractClass(); + collection.AddSingleton(classInheritingClassInheritingAbstractClass); + var provider = CreateServiceProvider(collection); + // Act + var allServices = provider.GetServices>().ToList(); + var constrainedServices = provider.GetServices>().ToList(); + // Assert + Assert.Equal(2, allServices.Count); + Assert.Same(classInheritingClassInheritingAbstractClass, allServices[0].Value); + Assert.Same(classInheritingClassInheritingAbstractClass, allServices[1].Value); + Assert.Equal(1, constrainedServices.Count); + Assert.Same(poco, constrainedServices[0].Value); + } + + [Fact] + public void SelfReferencingConstrainedOpenGenericServicesCanBeResolved() + { + // Arrange + var collection = new TestServiceCollection(); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(FakeOpenGenericService<>)); + collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithSelfReferencingConstraint<>)); + var poco = new PocoClass(); + collection.AddSingleton(poco); + var selfComparable = new ClassImplementingIComparable(); + collection.AddSingleton(selfComparable); + var provider = CreateServiceProvider(collection); + // Act + var allServices = provider.GetServices>().ToList(); + var constrainedServices = provider.GetServices>().ToList(); + // Assert + Assert.Equal(2, allServices.Count); + Assert.Same(selfComparable, allServices[0].Value); + Assert.Same(selfComparable, allServices[1].Value); + Assert.Equal(1, constrainedServices.Count); + Assert.Same(poco, constrainedServices[0].Value); + } + [Fact] public void ClosedServicesPreferredOverOpenGenericServices() { diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/AbstractClass.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/AbstractClass.cs new file mode 100644 index 00000000000000..c6ccfd34870819 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/AbstractClass.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public abstract class AbstractClass + { + + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassImplementingIComparable.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassImplementingIComparable.cs new file mode 100644 index 00000000000000..02ffe2e9c7961b --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassImplementingIComparable.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class ClassImplementingIComparable : IComparable + { + public int CompareTo(ClassImplementingIComparable other) => 0; + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassImplementingIEnumerable.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassImplementingIEnumerable.cs new file mode 100644 index 00000000000000..20fe4e2fc7ce63 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassImplementingIEnumerable.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class ClassImplementingIEnumerable : IEnumerable + { + public IEnumerator GetEnumerator() => throw new NotImplementedException(); + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassInheritingAbstractClass.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassInheritingAbstractClass.cs new file mode 100644 index 00000000000000..ec37f2b103a2c5 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassInheritingAbstractClass.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class ClassInheritingAbstractClass : AbstractClass + { + + } + + public class ClassAlsoInheritingAbstractClass : AbstractClass + { + + } + + public class ClassInheritingClassInheritingAbstractClass : ClassInheritingAbstractClass + { + + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithAbstractClassConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithAbstractClassConstraint.cs new file mode 100644 index 00000000000000..e5519863348d66 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithAbstractClassConstraint.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class ClassWithAbstractClassConstraint : IFakeOpenGenericService + where T : AbstractClass + { + public ClassWithAbstractClassConstraint(T value) => Value = value; + + public T Value { get; } + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithClassConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithClassConstraint.cs new file mode 100644 index 00000000000000..b180203a688a71 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithClassConstraint.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class ClassWithClassConstraint : IFakeOpenGenericService + where T : class + { + public T Value { get; } = default; + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithInterfaceConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithInterfaceConstraint.cs new file mode 100644 index 00000000000000..efd2c9c6cef63d --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithInterfaceConstraint.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections; + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class ClassWithInterfaceConstraint : IFakeOpenGenericService + where T : IEnumerable + { + public ClassWithInterfaceConstraint(T value) => Value = value; + + public T Value { get; } + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithNewConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithNewConstraint.cs new file mode 100644 index 00000000000000..143986cfa11200 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithNewConstraint.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class ClassWithNewConstraint : IFakeOpenGenericService + where T : new() + { + public T Value { get; } = new T(); + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithNoConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithNoConstraint.cs new file mode 100644 index 00000000000000..848e7551cd6a30 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithNoConstraint.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class ClassWithNoConstraints : IFakeOpenGenericService + { + public T Value { get; } = default; + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithSelfReferencingConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithSelfReferencingConstraint.cs new file mode 100644 index 00000000000000..0e464290229407 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithSelfReferencingConstraint.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class ClassWithSelfReferencingConstraint : IFakeOpenGenericService + where T : IComparable + { + public ClassWithSelfReferencingConstraint(T value) => Value = value; + + public T Value { get; } + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithStructConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithStructConstraint.cs new file mode 100644 index 00000000000000..06355eb905ffed --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithStructConstraint.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class ClassWithStructConstraint : IFakeOpenGenericService + where T : struct + { + public T Value { get; } = default; + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ConstrainedFakeOpenGenericService.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ConstrainedFakeOpenGenericService.cs new file mode 100644 index 00000000000000..940e3621a8772b --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ConstrainedFakeOpenGenericService.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class ConstrainedFakeOpenGenericService : IFakeOpenGenericService + where TVal : PocoClass + { + public ConstrainedFakeOpenGenericService(TVal value) + { + Value = value; + } + public TVal Value { get; } + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/IFakeOpenGenericService.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/IFakeOpenGenericService.cs index 12dc7142557384..d4697a14b2e31d 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/IFakeOpenGenericService.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/IFakeOpenGenericService.cs @@ -4,7 +4,7 @@ namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes { - public interface IFakeOpenGenericService + public interface IFakeOpenGenericService { TValue Value { get; } } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/ServiceLookup/CallSiteFactoryTest.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/ServiceLookup/CallSiteFactoryTest.cs index edb64427689ce2..6d7327283c9b59 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/ServiceLookup/CallSiteFactoryTest.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/ServiceLookup/CallSiteFactoryTest.cs @@ -112,6 +112,299 @@ public void CreateCallSite_UsesNullaryConstructorIfServicesCannotBeInjectedIntoO Assert.Empty(ctorCallSite.ParameterCallSites); } + [Fact] + public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfyStructGenericConstraint() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithStructConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor); + // Act + var nonMatchingType = typeof(IFakeOpenGenericService); + // Assert + var ex = Assert.Throws(() => callSiteFactory(nonMatchingType)); + Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message); + } + + [Fact] + public void CreateCallSite_ReturnsService_IfClosedTypeSatisfiesStructGenericConstraint() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithStructConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor); + // Act + var matchingType = typeof(IFakeOpenGenericService); + var matchingCallSite = callSiteFactory(matchingType); + // Assert + Assert.NotNull(matchingCallSite); + } + + [Fact] + public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfyClassGenericConstraint() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithClassConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor); + // Act + var nonMatchingType = typeof(IFakeOpenGenericService); + // Assert + var ex = Assert.Throws(() => callSiteFactory(nonMatchingType)); + Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message); + } + + [Fact] + public void CreateCallSite_ReturnsService_IfClosedTypeSatisfiesClassGenericConstraint() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithClassConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor); + // Act + var matchingType = typeof(IFakeOpenGenericService); + var matchingCallSite = callSiteFactory(matchingType); + // Assert + Assert.NotNull(matchingCallSite); + } + + [Fact] + public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfyNewGenericConstraint() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithNewConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor); + // Act + var nonMatchingType = typeof(IFakeOpenGenericService); + // Assert + var ex = Assert.Throws(() => callSiteFactory(nonMatchingType)); + Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message); + } + + [Fact] + public void CreateCallSite_ReturnsService_IfClosedTypeSatisfiesNewGenericConstraint() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithNewConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor, new ServiceDescriptor(typeof(TypeWithParameterlessPublicConstructor), new TypeWithParameterlessPublicConstructor())); + // Act + var matchingType = typeof(IFakeOpenGenericService); + var matchingCallSite = callSiteFactory(matchingType); + // Assert + Assert.NotNull(matchingCallSite); + } + + [Fact] + public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfyInterfaceGenericConstraint() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithInterfaceConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor); + // Act + var nonMatchingType = typeof(IFakeOpenGenericService); + // Assert + var ex = Assert.Throws(() => callSiteFactory(nonMatchingType)); + Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message); + } + + [Fact] + public void CreateCallSite_ReturnsService_IfClosedTypeSatisfiesInterfaceGenericConstraint() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithInterfaceConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor, new ServiceDescriptor(typeof(string), "")); + // Act + var matchingType = typeof(IFakeOpenGenericService); + var matchingCallSite = callSiteFactory(matchingType); + // Assert + Assert.NotNull(matchingCallSite); + } + + [Fact] + public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfyAbstractClassGenericConstraint() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithAbstractClassConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor); + // Act + var nonMatchingType = typeof(IFakeOpenGenericService); + // Assert + var ex = Assert.Throws(() => callSiteFactory(nonMatchingType)); + Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message); + } + + [Fact] + public void CreateCallSite_ReturnsService_IfClosedTypeSatisfiesAbstractClassGenericConstraint() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithAbstractClassConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor, new ServiceDescriptor(typeof(ClassInheritingAbstractClass), new ClassInheritingAbstractClass())); + // Act + var matchingType = typeof(IFakeOpenGenericService); + var matchingCallSite = callSiteFactory(matchingType); + // Assert + Assert.NotNull(matchingCallSite); + } + + [Fact] + public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfySelfReferencingConstraint() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithSelfReferencingConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor); + // Act + var nonMatchingType = typeof(IFakeOpenGenericService); + // Assert + var ex = Assert.Throws(() => callSiteFactory(nonMatchingType)); + Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message); + } + + [Fact] + public void CreateCallSite_Throws_IfComplexClosedTypeDoesNotSatisfySelfReferencingConstraint() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithSelfReferencingConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor); + // Act + var nonMatchingType = typeof(IFakeOpenGenericService); + // Assert + var ex = Assert.Throws(() => callSiteFactory(nonMatchingType)); + Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message); + } + + [Fact] + public void CreateCallSite_ReturnsService_IfClosedTypeSatisfiesSelfReferencing() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var implementationType = typeof(ClassWithSelfReferencingConstraint<>); + var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory(descriptor, new ServiceDescriptor(typeof(string), "")); + // Act + var matchingType = typeof(IFakeOpenGenericService); + var matchingCallSite = callSiteFactory(matchingType); + // Assert + Assert.NotNull(matchingCallSite); + } + + [Fact] + public void CreateCallSite_ReturnsEmpty_IfClosedTypeSatisfiesBaseClassConstraintButRegisteredTypeNotExactMatch() + { + // Arrange + var classInheritingAbstractClassImplementationType = typeof(ClassWithAbstractClassConstraint); + var classInheritingAbstractClassDescriptor = new ServiceDescriptor(typeof(IFakeOpenGenericService), classInheritingAbstractClassImplementationType, ServiceLifetime.Transient); + var classAlsoInheritingAbstractClassImplementationType = typeof(ClassWithAbstractClassConstraint); + var classAlsoInheritingAbstractClassDescriptor = new ServiceDescriptor(typeof(IFakeOpenGenericService), classAlsoInheritingAbstractClassImplementationType, ServiceLifetime.Transient); + var classInheritingClassInheritingAbstractClassImplementationType = typeof(ClassWithAbstractClassConstraint); + var classInheritingClassInheritingAbstractClassDescriptor = new ServiceDescriptor(typeof(IFakeOpenGenericService), classInheritingClassInheritingAbstractClassImplementationType, ServiceLifetime.Transient); + var notMatchingServiceType = typeof(IFakeOpenGenericService); + var notMatchingType = typeof(FakeService); + var notMatchingDescriptor = new ServiceDescriptor(notMatchingServiceType, notMatchingType, ServiceLifetime.Transient); + + var callSiteFactory = GetCallSiteFactory(classInheritingAbstractClassDescriptor, classAlsoInheritingAbstractClassDescriptor, classInheritingClassInheritingAbstractClassDescriptor, notMatchingDescriptor); + // Act + var matchingType = typeof(IEnumerable>); + var matchingCallSite = callSiteFactory(matchingType); + // Assert + var enumerableCall = Assert.IsType(matchingCallSite); + + Assert.Empty(enumerableCall.ServiceCallSites); + } + + [Fact] + public void CreateCallSite_ReturnsMatchingTypes_IfClosedTypeSatisfiesBaseClassConstraintAndRegisteredType() + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService); + var classInheritingAbstractClassImplementationType = typeof(ClassWithAbstractClassConstraint); + var classInheritingAbstractClassDescriptor = new ServiceDescriptor(serviceType, classInheritingAbstractClassImplementationType, ServiceLifetime.Transient); + var classAlsoInheritingAbstractClassImplementationType = typeof(ClassWithAbstractClassConstraint); + var classAlsoInheritingAbstractClassDescriptor = new ServiceDescriptor(serviceType, classAlsoInheritingAbstractClassImplementationType, ServiceLifetime.Transient); + var classInheritingClassInheritingAbstractClassImplementationType = typeof(ClassWithAbstractClassConstraint); + var classInheritingClassInheritingAbstractClassDescriptor = new ServiceDescriptor(serviceType, classInheritingClassInheritingAbstractClassImplementationType, ServiceLifetime.Transient); + var notMatchingServiceType = typeof(IFakeOpenGenericService); + var notMatchingType = typeof(FakeService); + var notMatchingDescriptor = new ServiceDescriptor(notMatchingServiceType, notMatchingType, ServiceLifetime.Transient); + + var descriptors = new[] + { + classInheritingAbstractClassDescriptor, + new ServiceDescriptor(typeof(ClassInheritingAbstractClass), new ClassInheritingAbstractClass()), + classAlsoInheritingAbstractClassDescriptor, + new ServiceDescriptor(typeof(ClassAlsoInheritingAbstractClass), new ClassAlsoInheritingAbstractClass()), + classInheritingClassInheritingAbstractClassDescriptor, + new ServiceDescriptor(typeof(ClassInheritingClassInheritingAbstractClass), new ClassInheritingClassInheritingAbstractClass()), + notMatchingDescriptor + }; + var callSiteFactory = GetCallSiteFactory(descriptors); + // Act + var matchingType = typeof(IEnumerable<>).MakeGenericType(serviceType); + var matchingCallSite = callSiteFactory(matchingType); + // Assert + var enumerableCall = Assert.IsType(matchingCallSite); + + var matchingTypes = new[] + { + classInheritingAbstractClassImplementationType, + classAlsoInheritingAbstractClassImplementationType, + classInheritingClassInheritingAbstractClassImplementationType + }; + Assert.Equal(matchingTypes.Length, enumerableCall.ServiceCallSites.Length); + Assert.Equal(matchingTypes, enumerableCall.ServiceCallSites.Select(scs => scs.ImplementationType).ToArray()); + } + + [Theory] + [InlineData(typeof(IFakeOpenGenericService), default(int), new[] { typeof(FakeOpenGenericService), typeof(ClassWithStructConstraint), typeof(ClassWithNewConstraint), typeof(ClassWithSelfReferencingConstraint) })] + [InlineData(typeof(IFakeOpenGenericService), "", new[] { typeof(FakeOpenGenericService), typeof(ClassWithClassConstraint), typeof(ClassWithInterfaceConstraint), typeof(ClassWithSelfReferencingConstraint) })] + [InlineData(typeof(IFakeOpenGenericService), new[] { 1, 2, 3 }, new[] { typeof(FakeOpenGenericService), typeof(ClassWithClassConstraint), typeof(ClassWithInterfaceConstraint) })] + public void CreateCallSite_ReturnsMatchingTypesThatMatchCorrectConstraints(Type closedServiceType, object value, Type[] matchingImplementationTypes) + { + // Arrange + var serviceType = typeof(IFakeOpenGenericService<>); + var noConstraintImplementationType = typeof(FakeOpenGenericService<>); + var noConstraintDescriptor = new ServiceDescriptor(serviceType, noConstraintImplementationType, ServiceLifetime.Transient); + var structImplementationType = typeof(ClassWithStructConstraint<>); + var structDescriptor = new ServiceDescriptor(serviceType, structImplementationType, ServiceLifetime.Transient); + var classImplementationType = typeof(ClassWithClassConstraint<>); + var classDescriptor = new ServiceDescriptor(serviceType, classImplementationType, ServiceLifetime.Transient); + var newImplementationType = typeof(ClassWithNewConstraint<>); + var newDescriptor = new ServiceDescriptor(serviceType, newImplementationType, ServiceLifetime.Transient); + var interfaceImplementationType = typeof(ClassWithInterfaceConstraint<>); + var interfaceDescriptor = new ServiceDescriptor(serviceType, interfaceImplementationType, ServiceLifetime.Transient); + var selfConstraintImplementationType = typeof(ClassWithSelfReferencingConstraint<>); + var selfConstraintDescriptor = new ServiceDescriptor(serviceType, selfConstraintImplementationType, ServiceLifetime.Transient); + var serviceValueType = closedServiceType.GenericTypeArguments[0]; + var serviceValueDescriptor = new ServiceDescriptor(serviceValueType, value); + var callSiteFactory = GetCallSiteFactory(noConstraintDescriptor, structDescriptor, classDescriptor, newDescriptor, interfaceDescriptor, selfConstraintDescriptor, serviceValueDescriptor); + var collectionType = typeof(IEnumerable<>).MakeGenericType(closedServiceType); + // Act + var callSite = callSiteFactory(collectionType); + // Assert + var enumerableCall = Assert.IsType(callSite); + Assert.Equal(matchingImplementationTypes.Length, enumerableCall.ServiceCallSites.Length); + Assert.Equal(matchingImplementationTypes, enumerableCall.ServiceCallSites.Select(scs => scs.ImplementationType).ToArray()); + } + public static TheoryData CreateCallSite_PicksConstructorWithTheMostNumberOfResolvedParametersData => new TheoryData, Type[]> {