From 3d0b86d07b3f1350e95422865b81ee3c7829a719 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Mon, 25 Nov 2024 10:48:02 +0100 Subject: [PATCH] Fix TPC equality check inside subquery predicate (#35120) Fixes #35118 --- .../Query/SqlExpressions/SelectExpression.cs | 36 +++++++++---------- .../NorthwindMiscellaneousQueryCosmosTest.cs | 8 +++++ .../NorthwindMiscellaneousQueryTestBase.cs | 7 ++++ ...orthwindMiscellaneousQuerySqlServerTest.cs | 15 ++++++++ 4 files changed, 47 insertions(+), 19 deletions(-) diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index d3931c36b7b..073a8c1b190 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query.Internal; @@ -1573,7 +1574,7 @@ public void ApplyPredicate(SqlExpression sqlExpression) Left: ColumnExpression leftColumn, Right: SqlConstantExpression { Value: string s1 } } - when GetTable(leftColumn) is TpcTablesExpression + when TryGetTable(leftColumn, out var table, out _) && table is TpcTablesExpression { DiscriminatorColumn: var discriminatorColumn, DiscriminatorValues: var discriminatorValues @@ -1596,7 +1597,7 @@ when GetTable(leftColumn) is TpcTablesExpression Left: SqlConstantExpression { Value: string s2 }, Right: ColumnExpression rightColumn } - when GetTable(rightColumn) is TpcTablesExpression + when TryGetTable(rightColumn, out var table, out _) && table is TpcTablesExpression { DiscriminatorColumn: var discriminatorColumn, DiscriminatorValues: var discriminatorValues @@ -1620,7 +1621,7 @@ when GetTable(rightColumn) is TpcTablesExpression Item: ColumnExpression itemColumn, Values: IReadOnlyList valueExpressions } - when GetTable(itemColumn) is TpcTablesExpression + when TryGetTable(itemColumn, out var table, out _) && table is TpcTablesExpression { DiscriminatorColumn: var discriminatorColumn, DiscriminatorValues: var discriminatorValues @@ -2733,31 +2734,28 @@ public TableExpressionBase GetTable(ColumnExpression column) /// based on its alias. /// public TableExpressionBase GetTable(ColumnExpression column, out int tableIndex) - { - for (var i = 0; i < _tables.Count; i++) - { - var table = _tables[i]; - if (table.UnwrapJoin().Alias == column.TableAlias) - { - tableIndex = i; - return table; - } - } - - throw new InvalidOperationException($"Table not found with alias '{column.TableAlias}'"); - } + => TryGetTable(column, out var table, out tableIndex) + ? table + : throw new InvalidOperationException($"Table not found with alias '{column.TableAlias}'"); private bool ContainsReferencedTable(ColumnExpression column) + => TryGetTable(column, out _, out _); + + private bool TryGetTable(ColumnExpression column, [NotNullWhen(true)] out TableExpressionBase? table, out int tableIndex) { - foreach (var table in Tables) + for (var i = 0; i < _tables.Count; i++) { - var unwrappedTable = table.UnwrapJoin(); - if (unwrappedTable.Alias == column.TableAlias) + var t = _tables[i]; + if (t.UnwrapJoin().Alias == column.TableAlias) { + table = t; + tableIndex = i; return true; } } + table = null; + tableIndex = 0; return false; } diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindMiscellaneousQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindMiscellaneousQueryCosmosTest.cs index 2967ea165f3..5107d20efa5 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindMiscellaneousQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindMiscellaneousQueryCosmosTest.cs @@ -5317,6 +5317,14 @@ public override async Task Ternary_Null_StartsWith(bool async) AssertSql(); } + public override async Task Column_access_inside_subquery_predicate(bool async) + { + // Uncorrelated subquery, not supported by Cosmos + await AssertTranslationFailed(() => base.Column_access_inside_subquery_predicate(async)); + + AssertSql(); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); diff --git a/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs index 44b9d0f51ce..25c8f9d5cfc 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs @@ -5890,4 +5890,11 @@ public virtual Task Ternary_Null_StartsWith(bool async) async, ss => ss.Set().OrderBy(x => x.OrderID).Select(x => x == null ? null : x.OrderID + ""), x => x.StartsWith("1")); + + [ConditionalTheory] // #35118 + [MemberData(nameof(IsAsyncData))] + public virtual Task Column_access_inside_subquery_predicate(bool async) + => AssertQuery( + async, + ss => ss.Set().Where(c => ss.Set().Where(o => c.CustomerID == "ALFKI").Any())); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs index d3b7b4a45bc..4eb5fa068e0 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs @@ -7516,6 +7516,21 @@ ORDER BY [o].[OrderID] """); } + public override async Task Column_access_inside_subquery_predicate(bool async) + { + await base.Column_access_inside_subquery_predicate(async); + + AssertSql( + """ +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE EXISTS ( + SELECT 1 + FROM [Orders] AS [o] + WHERE [c].[CustomerID] = N'ALFKI') +"""); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);