Skip to content

Commit

Permalink
Fix TPC equality check inside subquery predicate (#35120)
Browse files Browse the repository at this point in the history
Fixes #35118
  • Loading branch information
roji authored Nov 25, 2024
1 parent 3f5dcef commit 3d0b86d
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 19 deletions.
36 changes: 17 additions & 19 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
Expand Down Expand Up @@ -1573,7 +1574,7 @@ public void ApplyPredicate(SqlExpression sqlExpression)
Left: ColumnExpression leftColumn,
Right: SqlConstantExpression { Value: string s1 }
}
when GetTable(leftColumn) is TpcTablesExpression
when TryGetTable(leftColumn, out var table, out _) && table is TpcTablesExpression
{
DiscriminatorColumn: var discriminatorColumn,
DiscriminatorValues: var discriminatorValues
Expand All @@ -1596,7 +1597,7 @@ when GetTable(leftColumn) is TpcTablesExpression
Left: SqlConstantExpression { Value: string s2 },
Right: ColumnExpression rightColumn
}
when GetTable(rightColumn) is TpcTablesExpression
when TryGetTable(rightColumn, out var table, out _) && table is TpcTablesExpression
{
DiscriminatorColumn: var discriminatorColumn,
DiscriminatorValues: var discriminatorValues
Expand All @@ -1620,7 +1621,7 @@ when GetTable(rightColumn) is TpcTablesExpression
Item: ColumnExpression itemColumn,
Values: IReadOnlyList<SqlExpression> valueExpressions
}
when GetTable(itemColumn) is TpcTablesExpression
when TryGetTable(itemColumn, out var table, out _) && table is TpcTablesExpression
{
DiscriminatorColumn: var discriminatorColumn,
DiscriminatorValues: var discriminatorValues
Expand Down Expand Up @@ -2733,31 +2734,28 @@ public TableExpressionBase GetTable(ColumnExpression column)
/// <see cref="SelectExpression" /> based on its alias.
/// </summary>
public TableExpressionBase GetTable(ColumnExpression column, out int tableIndex)
{
for (var i = 0; i < _tables.Count; i++)
{
var table = _tables[i];
if (table.UnwrapJoin().Alias == column.TableAlias)
{
tableIndex = i;
return table;
}
}

throw new InvalidOperationException($"Table not found with alias '{column.TableAlias}'");
}
=> TryGetTable(column, out var table, out tableIndex)
? table
: throw new InvalidOperationException($"Table not found with alias '{column.TableAlias}'");

private bool ContainsReferencedTable(ColumnExpression column)
=> TryGetTable(column, out _, out _);

private bool TryGetTable(ColumnExpression column, [NotNullWhen(true)] out TableExpressionBase? table, out int tableIndex)
{
foreach (var table in Tables)
for (var i = 0; i < _tables.Count; i++)
{
var unwrappedTable = table.UnwrapJoin();
if (unwrappedTable.Alias == column.TableAlias)
var t = _tables[i];
if (t.UnwrapJoin().Alias == column.TableAlias)
{
table = t;
tableIndex = i;
return true;
}
}

table = null;
tableIndex = 0;
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5317,6 +5317,14 @@ public override async Task Ternary_Null_StartsWith(bool async)
AssertSql();
}

public override async Task Column_access_inside_subquery_predicate(bool async)
{
// Uncorrelated subquery, not supported by Cosmos
await AssertTranslationFailed(() => base.Column_access_inside_subquery_predicate(async));

AssertSql();
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5890,4 +5890,11 @@ public virtual Task Ternary_Null_StartsWith(bool async)
async,
ss => ss.Set<Order>().OrderBy(x => x.OrderID).Select(x => x == null ? null : x.OrderID + ""),
x => x.StartsWith("1"));

[ConditionalTheory] // #35118
[MemberData(nameof(IsAsyncData))]
public virtual Task Column_access_inside_subquery_predicate(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => ss.Set<Order>().Where(o => c.CustomerID == "ALFKI").Any()));
}
Original file line number Diff line number Diff line change
Expand Up @@ -7516,6 +7516,21 @@ ORDER BY [o].[OrderID]
""");
}

public override async Task Column_access_inside_subquery_predicate(bool async)
{
await base.Column_access_inside_subquery_predicate(async);

AssertSql(
"""
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE EXISTS (
SELECT 1
FROM [Orders] AS [o]
WHERE [c].[CustomerID] = N'ALFKI')
""");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down

0 comments on commit 3d0b86d

Please sign in to comment.