diff --git a/src/LightInject.Tests/DecoratorTests.cs b/src/LightInject.Tests/DecoratorTests.cs index 407b9c4c..410934ca 100644 --- a/src/LightInject.Tests/DecoratorTests.cs +++ b/src/LightInject.Tests/DecoratorTests.cs @@ -202,9 +202,11 @@ public void GetInstance_DeferredDecorator_ReturnsDecoratedInstance() { var container = CreateContainer(); container.Register(); - var registration = new DecoratorRegistration(); - registration.CanDecorate = serviceRegistration => true; - registration.ImplementingTypeFactory = (factory, serviceRegistration) => typeof(FooDecorator); + var registration = new DecoratorRegistration + { + CanDecorate = serviceRegistration => true, + ImplementingTypeFactory = (factory, serviceRegistration) => typeof(FooDecorator) + }; container.Decorate(registration); var instance = container.GetInstance(); @@ -550,11 +552,18 @@ public void GetInstance_HalfClosedDecoratorWithMissingGenericArgument_DotNotAppl container.Decorate(typeof(IFoo<,>), typeof(HalfClosedOpenGenericFooDecorator<,>)); var instance = container.GetInstance>(); Assert.IsType>(instance); - } - - + } + [Fact] + public void GetInstance_DecoratorWithBaseGenericConstraint_AppliesDecorator() + { + var container = CreateContainer(); + container.Register, Foo>(); + container.Decorate(typeof(IFoo<>), typeof(FooDecoratorWithBarBaseConstraint<>)); + var instance = container.GetInstance>(); + Assert.IsType>(instance); + } private IFoo CreateFooWithDependency(IServiceFactory factory) { diff --git a/src/LightInject.Tests/SampleServices/Foo.cs b/src/LightInject.Tests/SampleServices/Foo.cs index ce073d20..780c38c5 100644 --- a/src/LightInject.Tests/SampleServices/Foo.cs +++ b/src/LightInject.Tests/SampleServices/Foo.cs @@ -575,6 +575,24 @@ public HalfClosedOpenGenericFooDecorator(IFoo foo) } } + public class BarBase + { + + } + + public class InheritedBar :BarBase + { + + } + + public class FooDecoratorWithBarBaseConstraint : IFoo where T:BarBase + { + public FooDecoratorWithBarBaseConstraint(IFoo foo) + { + + } + } + public class Foo { } diff --git a/src/LightInject.Tests/ServiceContainerTests.cs b/src/LightInject.Tests/ServiceContainerTests.cs index baa9208b..07d585d9 100644 --- a/src/LightInject.Tests/ServiceContainerTests.cs +++ b/src/LightInject.Tests/ServiceContainerTests.cs @@ -305,7 +305,7 @@ public void GetInstance_OneNamedService_ReturnsDefaultService() } [Fact] - public void issue_231() + public void Issue_231() { var container = CreateContainer(); container.Register("foo", new PerContainerLifetime()); @@ -315,7 +315,7 @@ public void issue_231() } [Fact] - public void issue_168() + public void Issue_168() { var serviceContainer = new ServiceContainer(); serviceContainer.Register("bar"); diff --git a/src/LightInject/LightInject.cs b/src/LightInject/LightInject.cs index 160ce965..d5214279 100644 --- a/src/LightInject/LightInject.cs +++ b/src/LightInject/LightInject.cs @@ -1095,8 +1095,7 @@ public static class RuntimeArgumentsLoader /// An array containing the runtime arguments supplied when resolving the service. public static object[] Load(object[] constants) { - object[] arguments = constants[constants.Length - 1] as object[]; - if (arguments == null) + if (!(constants[constants.Length - 1] is object[] arguments)) { return new object[] { }; } @@ -3183,15 +3182,13 @@ private TypeConstructionInfoBuilder CreateTypeConstructionInfoBuilder() private Delegate GetConstructorDependencyDelegate(Type type, string serviceName) { - Delegate dependencyDelegate; - GetConstructorDependencyFactories(type).TryGetValue(serviceName, out dependencyDelegate); + GetConstructorDependencyFactories(type).TryGetValue(serviceName, out Delegate dependencyDelegate); return dependencyDelegate; } private Delegate GetPropertyDependencyExpression(Type type, string serviceName) { - Delegate dependencyDelegate; - GetPropertyDependencyFactories(type).TryGetValue(serviceName, out dependencyDelegate); + GetPropertyDependencyFactories(type).TryGetValue(serviceName, out Delegate dependencyDelegate); return dependencyDelegate; } @@ -3293,9 +3290,8 @@ private Action CreateEmitMethodWrapper(Action emitter, Type private Action GetRegisteredEmitMethod(Type serviceType, string serviceName) { - Action emitMethod; var registrations = GetEmitMethods(serviceType); - registrations.TryGetValue(serviceName, out emitMethod); + registrations.TryGetValue(serviceName, out Action emitMethod); return emitMethod ?? CreateEmitMethodForUnknownService(serviceType, serviceName); } @@ -3864,8 +3860,7 @@ private ServiceRegistration GetOpenGenericServiceRegistration(Type openGenericSe return null; } - ServiceRegistration openGenericServiceRegistration; - services.TryGetValue(serviceName, out openGenericServiceRegistration); + services.TryGetValue(serviceName, out ServiceRegistration openGenericServiceRegistration); if (openGenericServiceRegistration == null && string.IsNullOrEmpty(serviceName) && services.Count == 1) { return services.First().Value; @@ -5165,8 +5160,10 @@ public ConstructionInfo Execute(Registration registration) } var implementingType = registration.ImplementingType; - var constructionInfo = new ConstructionInfo(); - constructionInfo.ImplementingType = implementingType; + var constructionInfo = new ConstructionInfo + { + ImplementingType = implementingType + }; constructionInfo.PropertyDependencies.AddRange(GetPropertyDependencies(implementingType)); constructionInfo.Constructor = constructorSelector.Execute(implementingType); constructionInfo.ConstructorDependencies.AddRange(GetConstructorDependencies(constructionInfo.Constructor)); @@ -5361,8 +5358,7 @@ public override int GetHashCode() /// The to compare with the current . 2 public override bool Equals(object obj) { - var other = obj as ServiceRegistration; - if (other == null) + if (!(obj is ServiceRegistration other)) { return false; } @@ -5632,8 +5628,7 @@ public object GetInstance(Func createInstance, Scope scope) /// public void Dispose() { - var disposable = singleton as IDisposable; - if (disposable != null) + if (singleton is IDisposable disposable) { disposable.Dispose(); } @@ -5654,8 +5649,7 @@ public class PerRequestLifeTime : ILifetime public object GetInstance(Func createInstance, Scope scope) { var instance = createInstance(); - var disposable = instance as IDisposable; - if (disposable != null) + if (instance is IDisposable disposable) { TrackInstance(scope, disposable); } @@ -5704,8 +5698,7 @@ public object GetInstance(Func createInstance, Scope scope) private static void RegisterForDisposal(Scope scope, object instance) { - var disposable = instance as IDisposable; - if (disposable != null) + if (instance is IDisposable disposable) { scope.TrackInstance(disposable); } @@ -5724,8 +5717,7 @@ private void OnScopeCompleted(object sender, EventArgs e) { var scope = (Scope)sender; scope.Completed -= OnScopeCompleted; - object removedInstance; - instances.TryRemove(scope, out removedInstance); + instances.TryRemove(scope, out object removedInstance); } }