diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs index 6a42373ccc77d4..3ba30724d97f2b 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs @@ -36,7 +36,7 @@ public static class ActivatorUtilities #endif private static readonly MethodInfo GetServiceInfo = - GetMethodInfo>((sp, t, r, c) => GetService(sp, t, r, c)); + GetMethodInfo>((sp, t, r, c, k) => GetService(sp, t, r, c, k)); /// /// Instantiate a type with constructor arguments provided directly and/or from an . @@ -324,9 +324,9 @@ private static MethodInfo GetMethodInfo(Expression expr) return mc.Method; } - private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue) + private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue, object? key) { - object? service = sp.GetService(type); + object? service = key == null ? sp.GetService(type) : GetKeyedService(sp, type, key); if (service is null && !hasDefaultValue) { ThrowHelperUnableToResolveService(type, requiredBy); @@ -361,10 +361,12 @@ private static BlockExpression BuildFactoryExpression( } else { + var keyAttribute = (FromKeyedServicesAttribute?) Attribute.GetCustomAttribute(constructorParameter, typeof(FromKeyedServicesAttribute), inherit: false); var parameterTypeExpression = new Expression[] { serviceProvider, Expression.Constant(parameterType, typeof(Type)), Expression.Constant(constructor.DeclaringType, typeof(Type)), - Expression.Constant(hasDefaultValue) }; + Expression.Constant(hasDefaultValue), + Expression.Constant(keyAttribute?.Key) }; constructorArguments[i] = Expression.Call(GetServiceInfo, parameterTypeExpression); } @@ -435,10 +437,10 @@ private static ObjectFactory CreateFactoryReflection( if (matchedArgCount == 0) { // All injected; use a fast path. - Type[] types = GetParameterTypes(); + FactoryParameterContext[] parameters = GetFactoryParameterContext(); return useFixedValues ? - (serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, types, declaringType, serviceProvider) : - (serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, types, declaringType, serviceProvider); + (serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, parameters, declaringType, serviceProvider) : + (serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, parameters, declaringType, serviceProvider); } if (matchedArgCount == constructorParameters.Length) @@ -456,16 +458,6 @@ ObjectFactory InvokeCanonical() (serviceProvider, arguments) => ReflectionFactoryCanonicalFixed(invoker, parameters, declaringType, serviceProvider, arguments) : (serviceProvider, arguments) => ReflectionFactoryCanonicalSpan(invoker, parameters, declaringType, serviceProvider, arguments); } - - Type[] GetParameterTypes() - { - Type[] types = new Type[constructorParameters.Length]; - for (int i = 0; i < constructorParameters.Length; i++) - { - types[i] = constructorParameters[i].ParameterType; - } - return types; - } #else ParameterInfo[] constructorParameters = constructor.GetParameters(); if (constructorParameters.Length == 0) @@ -484,8 +476,15 @@ FactoryParameterContext[] GetFactoryParameterContext() for (int i = 0; i < constructorParameters.Length; i++) { ParameterInfo constructorParameter = constructorParameters[i]; + FromKeyedServicesAttribute? attr = (FromKeyedServicesAttribute?) + Attribute.GetCustomAttribute(constructorParameter, typeof(FromKeyedServicesAttribute), inherit: false); bool hasDefaultValue = ParameterDefaultValue.TryGetDefaultValue(constructorParameter, out object? defaultValue); - parameters[i] = new FactoryParameterContext(constructorParameter.ParameterType, hasDefaultValue, defaultValue, parameterMap[i] ?? -1); + parameters[i] = new FactoryParameterContext( + constructorParameter.ParameterType, + hasDefaultValue, + defaultValue, + parameterMap[i] ?? -1, + attr?.Key); } return parameters; @@ -495,18 +494,20 @@ FactoryParameterContext[] GetFactoryParameterContext() private readonly struct FactoryParameterContext { - public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex) + public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex, object? serviceKey) { ParameterType = parameterType; HasDefaultValue = hasDefaultValue; DefaultValue = defaultValue; ArgumentIndex = argumentIndex; + ServiceKey = serviceKey; } public Type ParameterType { get; } public bool HasDefaultValue { get; } public object? DefaultValue { get; } public int ArgumentIndex { get; } + public object? ServiceKey { get; } } private static void FindApplicableConstructor( @@ -825,39 +826,39 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments() #if NET8_0_OR_GREATER // Use the faster ConstructorInvoker which also has alloc-free APIs when <= 4 parameters. private static object ReflectionFactoryServiceOnlyFixed( ConstructorInvoker invoker, - Type[] parameterTypes, + FactoryParameterContext[] parameters, Type declaringType, IServiceProvider serviceProvider) { - Debug.Assert(parameterTypes.Length >= 1 && parameterTypes.Length <= FixedArgumentThreshold); + Debug.Assert(parameters.Length >= 1 && parameters.Length <= FixedArgumentThreshold); Debug.Assert(FixedArgumentThreshold == 4); if (serviceProvider is null) ThrowHelperArgumentNullExceptionServiceProvider(); - switch (parameterTypes.Length) + switch (parameters.Length) { case 1: return invoker.Invoke( - GetService(serviceProvider, parameterTypes[0], declaringType, false)); + GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey)); case 2: return invoker.Invoke( - GetService(serviceProvider, parameterTypes[0], declaringType, false), - GetService(serviceProvider, parameterTypes[1], declaringType, false)); + GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey), + GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey)); case 3: return invoker.Invoke( - GetService(serviceProvider, parameterTypes[0], declaringType, false), - GetService(serviceProvider, parameterTypes[1], declaringType, false), - GetService(serviceProvider, parameterTypes[2], declaringType, false)); + GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey), + GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey), + GetService(serviceProvider, parameters[2].ParameterType, declaringType, false, parameters[2].ServiceKey)); case 4: return invoker.Invoke( - GetService(serviceProvider, parameterTypes[0], declaringType, false), - GetService(serviceProvider, parameterTypes[1], declaringType, false), - GetService(serviceProvider, parameterTypes[2], declaringType, false), - GetService(serviceProvider, parameterTypes[3], declaringType, false)); + GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey), + GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey), + GetService(serviceProvider, parameters[2].ParameterType, declaringType, false, parameters[2].ServiceKey), + GetService(serviceProvider, parameters[3].ParameterType, declaringType, false, parameters[3].ServiceKey)); } return null!; @@ -865,17 +866,17 @@ private static object ReflectionFactoryServiceOnlyFixed( private static object ReflectionFactoryServiceOnlySpan( ConstructorInvoker invoker, - Type[] parameterTypes, + FactoryParameterContext[] parameters, Type declaringType, IServiceProvider serviceProvider) { if (serviceProvider is null) ThrowHelperArgumentNullExceptionServiceProvider(); - object?[] arguments = new object?[parameterTypes.Length]; - for (int i = 0; i < parameterTypes.Length; i++) + object?[] arguments = new object?[parameters.Length]; + for (int i = 0; i < parameters.Length; i++) { - arguments[i] = GetService(serviceProvider, parameterTypes[i], declaringType, false); + arguments[i] = GetService(serviceProvider, parameters[i].ParameterType, declaringType, false, parameters[i].ServiceKey); } return invoker.Invoke(arguments.AsSpan()); @@ -907,7 +908,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter1.ParameterType, declaringType, - parameter1.HasDefaultValue)) ?? parameter1.DefaultValue); + parameter1.HasDefaultValue, + parameter1.ServiceKey)) ?? parameter1.DefaultValue); case 2: { ref FactoryParameterContext parameter2 = ref parameters[1]; @@ -920,7 +922,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter1.ParameterType, declaringType, - parameter1.HasDefaultValue)) ?? parameter1.DefaultValue, + parameter1.HasDefaultValue, + parameter1.ServiceKey)) ?? parameter1.DefaultValue, ((parameter2.ArgumentIndex != -1) // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. ? arguments![parameter2.ArgumentIndex] @@ -928,7 +931,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter2.ParameterType, declaringType, - parameter2.HasDefaultValue)) ?? parameter2.DefaultValue); + parameter2.HasDefaultValue, + parameter2.ServiceKey)) ?? parameter2.DefaultValue); } case 3: { @@ -943,7 +947,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter1.ParameterType, declaringType, - parameter1.HasDefaultValue)) ?? parameter1.DefaultValue, + parameter1.HasDefaultValue, + parameter1.ServiceKey)) ?? parameter1.DefaultValue, ((parameter2.ArgumentIndex != -1) // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. ? arguments![parameter2.ArgumentIndex] @@ -951,7 +956,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter2.ParameterType, declaringType, - parameter2.HasDefaultValue)) ?? parameter2.DefaultValue, + parameter2.HasDefaultValue, + parameter2.ServiceKey)) ?? parameter2.DefaultValue, ((parameter3.ArgumentIndex != -1) // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. ? arguments![parameter3.ArgumentIndex] @@ -959,7 +965,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter3.ParameterType, declaringType, - parameter3.HasDefaultValue)) ?? parameter3.DefaultValue); + parameter3.HasDefaultValue, + parameter3.ServiceKey)) ?? parameter3.DefaultValue); } case 4: { @@ -975,7 +982,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter1.ParameterType, declaringType, - parameter1.HasDefaultValue)) ?? parameter1.DefaultValue, + parameter1.HasDefaultValue, + parameter1.ServiceKey)) ?? parameter1.DefaultValue, ((parameter2.ArgumentIndex != -1) // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. ? arguments![parameter2.ArgumentIndex] @@ -983,7 +991,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter2.ParameterType, declaringType, - parameter2.HasDefaultValue)) ?? parameter2.DefaultValue, + parameter2.HasDefaultValue, + parameter2.ServiceKey)) ?? parameter2.DefaultValue, ((parameter3.ArgumentIndex != -1) // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. ? arguments![parameter3.ArgumentIndex] @@ -991,7 +1000,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter3.ParameterType, declaringType, - parameter3.HasDefaultValue)) ?? parameter3.DefaultValue, + parameter3.HasDefaultValue, + parameter3.ServiceKey)) ?? parameter3.DefaultValue, ((parameter4.ArgumentIndex != -1) // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. ? arguments![parameter4.ArgumentIndex] @@ -999,7 +1009,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter4.ParameterType, declaringType, - parameter4.HasDefaultValue)) ?? parameter4.DefaultValue); + parameter4.HasDefaultValue, + parameter4.ServiceKey)) ?? parameter4.DefaultValue); } } @@ -1028,7 +1039,8 @@ private static object ReflectionFactoryCanonicalSpan( serviceProvider, parameter.ParameterType, declaringType, - parameter.HasDefaultValue)) ?? parameter.DefaultValue; + parameter.HasDefaultValue, + parameter.ServiceKey)) ?? parameter.DefaultValue; } return invoker.Invoke(constructorArguments.AsSpan()); @@ -1078,7 +1090,8 @@ private static object ReflectionFactoryCanonical( serviceProvider, parameter.ParameterType, declaringType, - parameter.HasDefaultValue)) ?? parameter.DefaultValue; + parameter.HasDefaultValue, + parameter.ServiceKey)) ?? parameter.DefaultValue; } return constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, constructorArguments, culture: null); @@ -1099,5 +1112,17 @@ public static void ClearCache(Type[]? _) } } #endif + + private static object? GetKeyedService(IServiceProvider provider, Type type, object? serviceKey) + { + ThrowHelper.ThrowIfNull(provider); + + if (provider is IKeyedServiceProvider keyedServiceProvider) + { + return keyedServiceProvider.GetKeyedService(type, serviceKey); + } + + throw new InvalidOperationException(SR.KeyedServicesNotSupported); + } } } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs index 32112c3e5f7691..a0dc73d58f820b 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs @@ -476,5 +476,41 @@ public ServiceProviderAccessor(IServiceProvider serviceProvider) public IServiceProvider ServiceProvider { get; } } + + [Fact] + public void SimpleServiceKeyedResolution() + { + // Arrange + var services = new ServiceCollection(); + services.AddKeyedTransient("simple"); + services.AddKeyedTransient("another"); + services.AddTransient(); + var provider = CreateServiceProvider(services); + var sut = provider.GetService(); + + // Act + var result = sut!.GetService("simple"); + + // Assert + Assert.True(result.GetType() == typeof(SimpleService)); + } + + public class SimpleParentWithDynamicKeyedService + { + private readonly IServiceProvider _serviceProvider; + + public SimpleParentWithDynamicKeyedService(IServiceProvider serviceProvider) + { + _serviceProvider = serviceProvider; + } + + public ISimpleService GetService(string name) => _serviceProvider.GetKeyedService(name)!; + } + + public interface ISimpleService { } + + public class SimpleService : ISimpleService { } + + public class AnotherSimpleService : ISimpleService { } } } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs index dda3cafa442eb0..f6e7c2f3a8eb7c 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs @@ -240,6 +240,100 @@ public void CreateFactory_CreatesFactoryMethod_5Types_5Injected() Assert.NotNull(item.Z); } + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(true)] +#if NETCOREAPP + [InlineData(false)] +#endif + public void CreateFactory_CreatesFactoryMethod_KeyedParams(bool useDynamicCode) + { + var options = new RemoteInvokeOptions(); + if (!useDynamicCode) + { + DisableDynamicCode(options); + } + + using var remoteHandle = RemoteExecutor.Invoke(static () => + { + var factory = ActivatorUtilities.CreateFactory(Type.EmptyTypes); + + var services = new ServiceCollection(); + services.AddSingleton(new A()); + services.AddKeyedSingleton("b", new B()); + services.AddKeyedSingleton("c", new C()); + using var provider = services.BuildServiceProvider(); + ClassWithAKeyedBKeyedC item = factory(provider, null); + + Assert.IsType>(factory); + Assert.NotNull(item.A); + Assert.NotNull(item.B); + Assert.NotNull(item.C); + }, options); + } + + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(true)] +#if NETCOREAPP + [InlineData(false)] +#endif + public void CreateFactory_CreatesFactoryMethod_KeyedParams_5Types(bool useDynamicCode) + { + var options = new RemoteInvokeOptions(); + if (!useDynamicCode) + { + DisableDynamicCode(options); + } + + using var remoteHandle = RemoteExecutor.Invoke(static () => + { + var factory = ActivatorUtilities.CreateFactory(Type.EmptyTypes); + + var services = new ServiceCollection(); + services.AddSingleton(new A()); + services.AddKeyedSingleton("b", new B()); + services.AddKeyedSingleton("c", new C()); + services.AddSingleton(new S()); + services.AddSingleton(new Z()); + using var provider = services.BuildServiceProvider(); + ClassWithAKeyedBKeyedCSZ item = factory(provider, null); + + Assert.IsType>(factory); + Assert.NotNull(item.A); + Assert.NotNull(item.B); + Assert.NotNull(item.C); + }, options); + } + + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(true)] +#if NETCOREAPP + [InlineData(false)] +#endif + public void CreateFactory_CreatesFactoryMethod_KeyedParams_1Injected(bool useDynamicCode) + { + var options = new RemoteInvokeOptions(); + if (!useDynamicCode) + { + DisableDynamicCode(options); + } + + using var remoteHandle = RemoteExecutor.Invoke(static () => + { + var factory = ActivatorUtilities.CreateFactory(new Type[] { typeof(A) }); + + var services = new ServiceCollection(); + services.AddKeyedSingleton("b", new B()); + services.AddKeyedSingleton("c", new C()); + using var provider = services.BuildServiceProvider(); + ClassWithAKeyedBKeyedC item = factory(provider, new object?[] { new A() }); + + Assert.IsType>(factory); + Assert.NotNull(item.A); + Assert.NotNull(item.B); + Assert.NotNull(item.C); + }, options); + } + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] [InlineData(true)] #if NETCOREAPP @@ -527,6 +621,13 @@ internal class C { } internal class S { } internal class Z { } + internal class ClassWithAKeyedBKeyedC : ClassWithABC + { + public ClassWithAKeyedBKeyedC(A a, [FromKeyedServices("b")] B b, [FromKeyedServices("c")] C c) + : base(a, b, c) + { } + } + internal class ClassWithABCS : ClassWithABC { public S S { get; } @@ -540,6 +641,13 @@ internal class ClassWithABCSZ : ClassWithABCS public ClassWithABCSZ(A a, B b, C c, S s, Z z) : base(a, b, c, s) { Z = z; } } + internal class ClassWithAKeyedBKeyedCSZ : ClassWithABCSZ + { + public ClassWithAKeyedBKeyedCSZ(A a, [FromKeyedServices("b")] B b, [FromKeyedServices("c")] C c, S s, Z z) + : base(a, b, c, s, z) + { } + } + internal class ClassWithABC_FirstConstructorWithAttribute : ClassWithABC { [ActivatorUtilitiesConstructor]