Skip to content

Commit

Permalink
Fix Contains on ImmutableArray (#35247)
Browse files Browse the repository at this point in the history
  • Loading branch information
cincuranet authored Dec 2, 2024
1 parent b706c02 commit af19b40
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -492,12 +492,16 @@ private Expression TryConvertCollectionContainsToQueryableContains(MethodCallExp

var sourceType = methodCallExpression.Method.DeclaringType!.GetGenericArguments()[0];

var objectExpression = methodCallExpression.Object!.Type.IsValueType
? Expression.Convert(methodCallExpression.Object!, typeof(IEnumerable<>).MakeGenericType(sourceType))
: methodCallExpression.Object!;

return VisitMethodCall(
Expression.Call(
QueryableMethods.Contains.MakeGenericMethod(sourceType),
Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(sourceType),
methodCallExpression.Object!),
objectExpression),
methodCallExpression.Arguments[0]));
}

Expand Down
13 changes: 12 additions & 1 deletion src/EFCore/Query/QueryRootProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,18 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

private Expression VisitQueryRootCandidate(Expression expression, Type elementClrType)
{
switch (expression)
var candidateExpression = expression;

// In case the collection was value type, in order to call methods like AsQueryable,
// we need to convert it to IEnumerable<T> which requires boxing.
// We do that with Convert expression which we need to unwrap here.
if (expression is UnaryExpression { NodeType: ExpressionType.Convert } convertExpression
&& convertExpression.Type.GetGenericTypeDefinition() == typeof(IEnumerable<>))
{
candidateExpression = convertExpression.Operand;
}

switch (candidateExpression)
{
// An array containing only constants is represented as a ConstantExpression with the array as the value.
// Convert that into a NewArrayExpression for use with InlineQueryRootExpression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,30 @@ WHERE ARRAY_CONTAINS(@ints, c["Int"])
"""
@ints='[10,999]'
SELECT VALUE c
FROM root c
WHERE NOT(ARRAY_CONTAINS(@ints, c["Int"]))
""");
});

public override Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
=> CosmosTestHelpers.Instance.NoSyncTest(
async, async a =>
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(a);

AssertSql(
"""
@ints='[10,999]'
SELECT VALUE c
FROM root c
WHERE ARRAY_CONTAINS(@ints, c["Int"])
""",
//
"""
@ints='[10,999]'
SELECT VALUE c
FROM root c
WHERE NOT(ARRAY_CONTAINS(@ints, c["Int"]))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Immutable;

namespace Microsoft.EntityFrameworkCore.Query;

public abstract class PrimitiveCollectionsQueryTestBase<TFixture>(TFixture fixture) : QueryTestBase<TFixture>(fixture)
Expand Down Expand Up @@ -363,6 +365,20 @@ await AssertQuery(
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !ints.Contains(c.Int)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
var ints = ImmutableArray.Create([10, 999]);

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => ints.Contains(c.Int)));
await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !ints.Contains(c.Int)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,24 @@ WHERE [p].[Int] NOT IN (10, 999)
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (10, 999)
""",
//
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (10, 999)
""");
}

public override async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
{
await base.Parameter_collection_of_ints_Contains_nullable_int(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,34 @@ FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
"""
@ints='[10,999]' (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
SELECT [i].[value]
FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
)
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
@ints='[10,999]' (Nullable = false) (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
)
""",
//
"""
@ints='[10,999]' (Nullable = false) (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,34 @@ FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
"""
@ints='[10,999]' (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
SELECT [i].[value]
FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
)
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
@ints='[10,999]' (Nullable = false) (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
)
""",
//
"""
@ints='[10,999]' (Nullable = false) (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,34 @@ FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
"""
@ints='[10,999]' (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
SELECT [i].[value]
FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
)
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
@ints='[10,999]' (Nullable = false) (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
)
""",
//
"""
@ints='[10,999]' (Nullable = false) (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,34 @@ FROM json_each(@ints) AS "i"
"""
@ints='[10,999]' (Size = 8)
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."NullableWrappedId", "p"."NullableWrappedIdWithNullableComparer", "p"."String", "p"."Strings", "p"."WrappedId"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" NOT IN (
SELECT "i"."value"
FROM json_each(@ints) AS "i"
)
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
@ints='[10,999]' (Nullable = false) (Size = 8)
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."NullableWrappedId", "p"."NullableWrappedIdWithNullableComparer", "p"."String", "p"."Strings", "p"."WrappedId"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" IN (
SELECT "i"."value"
FROM json_each(@ints) AS "i"
)
""",
//
"""
@ints='[10,999]' (Nullable = false) (Size = 8)
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."NullableWrappedId", "p"."NullableWrappedIdWithNullableComparer", "p"."String", "p"."Strings", "p"."WrappedId"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" NOT IN (
Expand Down

0 comments on commit af19b40

Please sign in to comment.