Skip to content

Commit

Permalink
[release/8.0-staging] ServiceKey comparisons use Equals for matching (#…
Browse files Browse the repository at this point in the history
…96847)

* RemoveAllKeyed use Equals for matching

This brings RemoveAllKeyed in line with service resolution (ServiceProvider.GetRequiredKeyedService)

* Use equals for TryAdd and Replace

* Add packaging info to csproj

---------

Co-authored-by: Tommy Sørbråten <tommysor@gmail.com>
Co-authored-by: Steve Harter <steveharter@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 12, 2024
1 parent c192fba commit 11a4ff1
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ public static IServiceCollection RemoveAllKeyed(this IServiceCollection collecti
for (int i = collection.Count - 1; i >= 0; i--)
{
ServiceDescriptor? descriptor = collection[i];
if (descriptor.ServiceType == serviceType && descriptor.ServiceKey == serviceKey)
if (descriptor.ServiceType == serviceType && object.Equals(descriptor.ServiceKey, serviceKey))
{
collection.RemoveAt(i);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public static void TryAdd(
for (int i = 0; i < count; i++)
{
if (collection[i].ServiceType == descriptor.ServiceType
&& collection[i].ServiceKey == descriptor.ServiceKey)
&& object.Equals(collection[i].ServiceKey, descriptor.ServiceKey))
{
// Already added
return;
Expand Down Expand Up @@ -474,7 +474,7 @@ public static void TryAddEnumerable(
ServiceDescriptor service = services[i];
if (service.ServiceType == descriptor.ServiceType &&
service.GetImplementationType() == implementationType &&
service.ServiceKey == descriptor.ServiceKey)
object.Equals(service.ServiceKey, descriptor.ServiceKey))
{
// Already added
return;
Expand Down Expand Up @@ -532,7 +532,7 @@ public static IServiceCollection Replace(
int count = collection.Count;
for (int i = 0; i < count; i++)
{
if (collection[i].ServiceType == descriptor.ServiceType && collection[i].ServiceKey == descriptor.ServiceKey)
if (collection[i].ServiceType == descriptor.ServiceType && object.Equals(collection[i].ServiceKey, descriptor.ServiceKey))
{
collection.RemoveAt(i);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
<TargetFrameworks>$(NetCoreAppCurrent);$(NetCoreAppPrevious);$(NetCoreAppMinimum);netstandard2.1;netstandard2.0;$(NetFrameworkMinimum)</TargetFrameworks>
<EnableDefaultItems>true</EnableDefaultItems>
<IsPackable>true</IsPackable>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
<ServicingVersion>1</ServicingVersion>
<PackageDescription>Abstractions for dependency injection.

Commonly Used Types:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ public static TheoryData TryAddImplementationTypeData
{ collection => collection.TryAddKeyedTransient<IFakeService, FakeService>("key-2"), serviceType, "key-2", implementationType, ServiceLifetime.Transient },
{ collection => collection.TryAddKeyedTransient<IFakeService>("key-3"), serviceType, "key-3", serviceType, ServiceLifetime.Transient },
{ collection => collection.TryAddKeyedTransient(implementationType, "key-4"), implementationType, "key-4", implementationType, ServiceLifetime.Transient },
{ collection => collection.TryAddKeyedTransient(implementationType, 9), implementationType, 9, implementationType, ServiceLifetime.Transient },

{ collection => collection.TryAddKeyedScoped(serviceType, "key-1", implementationType), serviceType, "key-1", implementationType, ServiceLifetime.Scoped },
{ collection => collection.TryAddKeyedScoped<IFakeService, FakeService>("key-2"), serviceType, "key-2", implementationType, ServiceLifetime.Scoped },
Expand Down Expand Up @@ -325,6 +326,40 @@ public void TryAddEnumerable_DoesNotAddDuplicate(
Assert.Equal(expectedLifetime, d.Lifetime);
}

[Fact]
public void TryAddEnumerable_DoesNotAddDuplicateWhenKeyIsInt()
{
// Arrange
var collection = new ServiceCollection();
var descriptor1 = ServiceDescriptor.KeyedTransient<IFakeService, FakeService>(1);
collection.TryAddEnumerable(descriptor1);
var descriptor2 = ServiceDescriptor.KeyedTransient<IFakeService, FakeService>(1);

// Act
collection.TryAddEnumerable(descriptor2);

// Assert
var d = Assert.Single(collection);
Assert.Same(descriptor1, d);
}

[Fact]
public void TryAddEnumerable_DoesNotAddDuplicateWhenKeyIsString()
{
// Arrange
var collection = new ServiceCollection();
var descriptor1 = ServiceDescriptor.KeyedTransient<IFakeService, FakeService>("service1");
collection.TryAddEnumerable(descriptor1);
var descriptor2 = ServiceDescriptor.KeyedTransient<IFakeService, FakeService>("service1");

// Act
collection.TryAddEnumerable(descriptor2);

// Assert
var d = Assert.Single(collection);
Assert.Same(descriptor1, d);
}

public static TheoryData TryAddEnumerableInvalidImplementationTypeData
{
get
Expand Down Expand Up @@ -412,6 +447,24 @@ public void Replace_ReplacesFirstServiceWithMatchingServiceType()
Assert.Equal(new[] { descriptor2, descriptor3 }, collection);
}

[Fact]
public void Replace_ReplacesFirstServiceWithMatchingServiceTypeWhenKeyIsInt()
{
// Arrange
var collection = new ServiceCollection();
var descriptor1 = new ServiceDescriptor(typeof(IFakeService), 1, typeof(FakeService), ServiceLifetime.Transient);
var descriptor2 = new ServiceDescriptor(typeof(IFakeService), 1, typeof(FakeService), ServiceLifetime.Transient);
collection.Add(descriptor1);
collection.Add(descriptor2);
var descriptor3 = new ServiceDescriptor(typeof(IFakeService), 1, typeof(FakeService), ServiceLifetime.Singleton);

// Act
collection.Replace(descriptor3);

// Assert
Assert.Equal(new[] { descriptor2, descriptor3 }, collection);
}

[Fact]
public void RemoveAll_RemovesAllServicesWithMatchingServiceType()
{
Expand All @@ -431,6 +484,44 @@ public void RemoveAll_RemovesAllServicesWithMatchingServiceType()
Assert.Equal(new[] { descriptor }, collection);
}

private enum ServiceKeyEnum { First, Second }

[Fact]
public void RemoveAll_RemovesAllMatchingServicesWhenKeyIsEnum()
{
var descriptor = new ServiceDescriptor(typeof(IFakeService), ServiceKeyEnum.First, typeof(FakeService), ServiceLifetime.Transient);
var collection = new ServiceCollection
{
descriptor,
new ServiceDescriptor(typeof(IFakeService), ServiceKeyEnum.Second, typeof(FakeService), ServiceLifetime.Transient),
new ServiceDescriptor(typeof(IFakeService), ServiceKeyEnum.Second, typeof(FakeService), ServiceLifetime.Transient),
};

// Act
collection.RemoveAllKeyed<IFakeService>(ServiceKeyEnum.Second);

// Assert
Assert.Equal(new[] { descriptor }, collection);
}

[Fact]
public void RemoveAll_RemovesAllMatchingServicesWhenKeyIsInt()
{
var descriptor = new ServiceDescriptor(typeof(IFakeService), 1, typeof(FakeService), ServiceLifetime.Transient);
var collection = new ServiceCollection
{
descriptor,
new ServiceDescriptor(typeof(IFakeService), 2, typeof(FakeService), ServiceLifetime.Transient),
new ServiceDescriptor(typeof(IFakeService), 2, typeof(FakeService), ServiceLifetime.Transient),
};

// Act
collection.RemoveAllKeyed<IFakeService>(2);

// Assert
Assert.Equal(new[] { descriptor }, collection);
}

public static TheoryData NullServiceKeyData
{
get
Expand Down

0 comments on commit 11a4ff1

Please sign in to comment.