Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release/8.0] Fix support of FromKeyedServicesAttribute in ActivatorUtilities.CreateFactory #92144

Merged
merged 2 commits into from
Sep 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public static class ActivatorUtilities
#endif

private static readonly MethodInfo GetServiceInfo =
GetMethodInfo<Func<IServiceProvider, Type, Type, bool, object?>>((sp, t, r, c) => GetService(sp, t, r, c));
GetMethodInfo<Func<IServiceProvider, Type, Type, bool, object?, object?>>((sp, t, r, c, k) => GetService(sp, t, r, c, k));

/// <summary>
/// Instantiate a type with constructor arguments provided directly and/or from an <see cref="IServiceProvider"/>.
Expand Down Expand Up @@ -324,9 +324,9 @@ private static MethodInfo GetMethodInfo<T>(Expression<T> expr)
return mc.Method;
}

private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue)
private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue, object? key)
{
object? service = sp.GetService(type);
object? service = key == null ? sp.GetService(type) : GetKeyedService(sp, type, key);
if (service is null && !hasDefaultValue)
{
ThrowHelperUnableToResolveService(type, requiredBy);
Expand Down Expand Up @@ -361,10 +361,12 @@ private static BlockExpression BuildFactoryExpression(
}
else
{
var keyAttribute = (FromKeyedServicesAttribute?) Attribute.GetCustomAttribute(constructorParameter, typeof(FromKeyedServicesAttribute), inherit: false);
var parameterTypeExpression = new Expression[] { serviceProvider,
Expression.Constant(parameterType, typeof(Type)),
Expression.Constant(constructor.DeclaringType, typeof(Type)),
Expression.Constant(hasDefaultValue) };
Expression.Constant(hasDefaultValue),
Expression.Constant(keyAttribute?.Key) };
constructorArguments[i] = Expression.Call(GetServiceInfo, parameterTypeExpression);
}

Expand Down Expand Up @@ -435,10 +437,10 @@ private static ObjectFactory CreateFactoryReflection(
if (matchedArgCount == 0)
{
// All injected; use a fast path.
Type[] types = GetParameterTypes();
FactoryParameterContext[] parameters = GetFactoryParameterContext();
return useFixedValues ?
(serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, types, declaringType, serviceProvider) :
(serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, types, declaringType, serviceProvider);
(serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, parameters, declaringType, serviceProvider) :
(serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, parameters, declaringType, serviceProvider);
}

if (matchedArgCount == constructorParameters.Length)
Expand All @@ -456,16 +458,6 @@ ObjectFactory InvokeCanonical()
(serviceProvider, arguments) => ReflectionFactoryCanonicalFixed(invoker, parameters, declaringType, serviceProvider, arguments) :
(serviceProvider, arguments) => ReflectionFactoryCanonicalSpan(invoker, parameters, declaringType, serviceProvider, arguments);
}

Type[] GetParameterTypes()
{
Type[] types = new Type[constructorParameters.Length];
for (int i = 0; i < constructorParameters.Length; i++)
{
types[i] = constructorParameters[i].ParameterType;
}
return types;
}
#else
ParameterInfo[] constructorParameters = constructor.GetParameters();
if (constructorParameters.Length == 0)
Expand All @@ -484,8 +476,15 @@ FactoryParameterContext[] GetFactoryParameterContext()
for (int i = 0; i < constructorParameters.Length; i++)
{
ParameterInfo constructorParameter = constructorParameters[i];
FromKeyedServicesAttribute? attr = (FromKeyedServicesAttribute?)
Attribute.GetCustomAttribute(constructorParameter, typeof(FromKeyedServicesAttribute), inherit: false);
bool hasDefaultValue = ParameterDefaultValue.TryGetDefaultValue(constructorParameter, out object? defaultValue);
parameters[i] = new FactoryParameterContext(constructorParameter.ParameterType, hasDefaultValue, defaultValue, parameterMap[i] ?? -1);
parameters[i] = new FactoryParameterContext(
constructorParameter.ParameterType,
hasDefaultValue,
defaultValue,
parameterMap[i] ?? -1,
attr?.Key);
}

return parameters;
Expand All @@ -495,18 +494,20 @@ FactoryParameterContext[] GetFactoryParameterContext()

private readonly struct FactoryParameterContext
{
public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex)
public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex, object? serviceKey)
{
ParameterType = parameterType;
HasDefaultValue = hasDefaultValue;
DefaultValue = defaultValue;
ArgumentIndex = argumentIndex;
ServiceKey = serviceKey;
}

public Type ParameterType { get; }
public bool HasDefaultValue { get; }
public object? DefaultValue { get; }
public int ArgumentIndex { get; }
public object? ServiceKey { get; }
}

private static void FindApplicableConstructor(
Expand Down Expand Up @@ -825,57 +826,57 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
#if NET8_0_OR_GREATER // Use the faster ConstructorInvoker which also has alloc-free APIs when <= 4 parameters.
private static object ReflectionFactoryServiceOnlyFixed(
ConstructorInvoker invoker,
Type[] parameterTypes,
FactoryParameterContext[] parameters,
Type declaringType,
IServiceProvider serviceProvider)
{
Debug.Assert(parameterTypes.Length >= 1 && parameterTypes.Length <= FixedArgumentThreshold);
Debug.Assert(parameters.Length >= 1 && parameters.Length <= FixedArgumentThreshold);
Debug.Assert(FixedArgumentThreshold == 4);

if (serviceProvider is null)
ThrowHelperArgumentNullExceptionServiceProvider();

switch (parameterTypes.Length)
switch (parameters.Length)
{
case 1:
return invoker.Invoke(
GetService(serviceProvider, parameterTypes[0], declaringType, false));
GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey));

case 2:
return invoker.Invoke(
GetService(serviceProvider, parameterTypes[0], declaringType, false),
GetService(serviceProvider, parameterTypes[1], declaringType, false));
GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey),
GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey));

case 3:
return invoker.Invoke(
GetService(serviceProvider, parameterTypes[0], declaringType, false),
GetService(serviceProvider, parameterTypes[1], declaringType, false),
GetService(serviceProvider, parameterTypes[2], declaringType, false));
GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey),
GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey),
GetService(serviceProvider, parameters[2].ParameterType, declaringType, false, parameters[2].ServiceKey));

case 4:
return invoker.Invoke(
GetService(serviceProvider, parameterTypes[0], declaringType, false),
GetService(serviceProvider, parameterTypes[1], declaringType, false),
GetService(serviceProvider, parameterTypes[2], declaringType, false),
GetService(serviceProvider, parameterTypes[3], declaringType, false));
GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey),
GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey),
GetService(serviceProvider, parameters[2].ParameterType, declaringType, false, parameters[2].ServiceKey),
GetService(serviceProvider, parameters[3].ParameterType, declaringType, false, parameters[3].ServiceKey));
}

return null!;
}

private static object ReflectionFactoryServiceOnlySpan(
ConstructorInvoker invoker,
Type[] parameterTypes,
FactoryParameterContext[] parameters,
Type declaringType,
IServiceProvider serviceProvider)
{
if (serviceProvider is null)
ThrowHelperArgumentNullExceptionServiceProvider();

object?[] arguments = new object?[parameterTypes.Length];
for (int i = 0; i < parameterTypes.Length; i++)
object?[] arguments = new object?[parameters.Length];
for (int i = 0; i < parameters.Length; i++)
{
arguments[i] = GetService(serviceProvider, parameterTypes[i], declaringType, false);
arguments[i] = GetService(serviceProvider, parameters[i].ParameterType, declaringType, false, parameters[i].ServiceKey);
}

return invoker.Invoke(arguments.AsSpan());
Expand Down Expand Up @@ -907,7 +908,8 @@ private static object ReflectionFactoryCanonicalFixed(
serviceProvider,
parameter1.ParameterType,
declaringType,
parameter1.HasDefaultValue)) ?? parameter1.DefaultValue);
parameter1.HasDefaultValue,
parameter1.ServiceKey)) ?? parameter1.DefaultValue);
case 2:
{
ref FactoryParameterContext parameter2 = ref parameters[1];
Expand All @@ -920,15 +922,17 @@ private static object ReflectionFactoryCanonicalFixed(
serviceProvider,
parameter1.ParameterType,
declaringType,
parameter1.HasDefaultValue)) ?? parameter1.DefaultValue,
parameter1.HasDefaultValue,
parameter1.ServiceKey)) ?? parameter1.DefaultValue,
((parameter2.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter2.ArgumentIndex]
: GetService(
serviceProvider,
parameter2.ParameterType,
declaringType,
parameter2.HasDefaultValue)) ?? parameter2.DefaultValue);
parameter2.HasDefaultValue,
parameter2.ServiceKey)) ?? parameter2.DefaultValue);
}
case 3:
{
Expand All @@ -943,23 +947,26 @@ private static object ReflectionFactoryCanonicalFixed(
serviceProvider,
parameter1.ParameterType,
declaringType,
parameter1.HasDefaultValue)) ?? parameter1.DefaultValue,
parameter1.HasDefaultValue,
parameter1.ServiceKey)) ?? parameter1.DefaultValue,
((parameter2.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter2.ArgumentIndex]
: GetService(
serviceProvider,
parameter2.ParameterType,
declaringType,
parameter2.HasDefaultValue)) ?? parameter2.DefaultValue,
parameter2.HasDefaultValue,
parameter2.ServiceKey)) ?? parameter2.DefaultValue,
((parameter3.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter3.ArgumentIndex]
: GetService(
serviceProvider,
parameter3.ParameterType,
declaringType,
parameter3.HasDefaultValue)) ?? parameter3.DefaultValue);
parameter3.HasDefaultValue,
parameter3.ServiceKey)) ?? parameter3.DefaultValue);
}
case 4:
{
Expand All @@ -975,31 +982,35 @@ private static object ReflectionFactoryCanonicalFixed(
serviceProvider,
parameter1.ParameterType,
declaringType,
parameter1.HasDefaultValue)) ?? parameter1.DefaultValue,
parameter1.HasDefaultValue,
parameter1.ServiceKey)) ?? parameter1.DefaultValue,
((parameter2.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter2.ArgumentIndex]
: GetService(
serviceProvider,
parameter2.ParameterType,
declaringType,
parameter2.HasDefaultValue)) ?? parameter2.DefaultValue,
parameter2.HasDefaultValue,
parameter2.ServiceKey)) ?? parameter2.DefaultValue,
((parameter3.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter3.ArgumentIndex]
: GetService(
serviceProvider,
parameter3.ParameterType,
declaringType,
parameter3.HasDefaultValue)) ?? parameter3.DefaultValue,
parameter3.HasDefaultValue,
parameter3.ServiceKey)) ?? parameter3.DefaultValue,
((parameter4.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter4.ArgumentIndex]
: GetService(
serviceProvider,
parameter4.ParameterType,
declaringType,
parameter4.HasDefaultValue)) ?? parameter4.DefaultValue);
parameter4.HasDefaultValue,
parameter4.ServiceKey)) ?? parameter4.DefaultValue);
}

}
Expand Down Expand Up @@ -1028,7 +1039,8 @@ private static object ReflectionFactoryCanonicalSpan(
serviceProvider,
parameter.ParameterType,
declaringType,
parameter.HasDefaultValue)) ?? parameter.DefaultValue;
parameter.HasDefaultValue,
parameter.ServiceKey)) ?? parameter.DefaultValue;
}

return invoker.Invoke(constructorArguments.AsSpan());
Expand Down Expand Up @@ -1078,7 +1090,8 @@ private static object ReflectionFactoryCanonical(
serviceProvider,
parameter.ParameterType,
declaringType,
parameter.HasDefaultValue)) ?? parameter.DefaultValue;
parameter.HasDefaultValue,
parameter.ServiceKey)) ?? parameter.DefaultValue;
}

return constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, constructorArguments, culture: null);
Expand All @@ -1099,5 +1112,17 @@ public static void ClearCache(Type[]? _)
}
}
#endif

private static object? GetKeyedService(IServiceProvider provider, Type type, object? serviceKey)
{
ThrowHelper.ThrowIfNull(provider);

if (provider is IKeyedServiceProvider keyedServiceProvider)
{
return keyedServiceProvider.GetKeyedService(type, serviceKey);
}

throw new InvalidOperationException(SR.KeyedServicesNotSupported);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -476,5 +476,41 @@ public ServiceProviderAccessor(IServiceProvider serviceProvider)

public IServiceProvider ServiceProvider { get; }
}

[Fact]
public void SimpleServiceKeyedResolution()
{
// Arrange
var services = new ServiceCollection();
services.AddKeyedTransient<ISimpleService, SimpleService>("simple");
services.AddKeyedTransient<ISimpleService, AnotherSimpleService>("another");
services.AddTransient<SimpleParentWithDynamicKeyedService>();
var provider = CreateServiceProvider(services);
var sut = provider.GetService<SimpleParentWithDynamicKeyedService>();

// Act
var result = sut!.GetService("simple");

// Assert
Assert.True(result.GetType() == typeof(SimpleService));
}

public class SimpleParentWithDynamicKeyedService
{
private readonly IServiceProvider _serviceProvider;

public SimpleParentWithDynamicKeyedService(IServiceProvider serviceProvider)
{
_serviceProvider = serviceProvider;
}

public ISimpleService GetService(string name) => _serviceProvider.GetKeyedService<ISimpleService>(name)!;
}

public interface ISimpleService { }

public class SimpleService : ISimpleService { }

public class AnotherSimpleService : ISimpleService { }
}
}
Loading