Skip to content

Commit

Permalink
Lambda and local function suppressions (#2689)
Browse files Browse the repository at this point in the history
This adds support for RequiresUnreferencedCode and UnconditionalSuppressMessage on lambdas and local functions, relying on heuristics and knowledge of compiler implementation details to detect lambdas and local functions.

This approach scans the code for IL references to lambdas and local functions, which has some limitations.

- Unused local functions aren't referenced by the containing method, so warnings from these will not be suppressed by suppressions on the containing method. Lambdas don't seem to have this problem, because they contain a reference to the generated method as part of the delegate conversion.
- The IL doesn't in general contain enough information to determine the nesting of the scopes of lambdas and local functions, so we make no attempt to do this. We only try to determine to which user method a lambda or local function belongs. So suppressions on a lambda or local function will not silence warnings from a nested lambda or local function in the same scope.

This also adds warnings for reflection access to compiler-generated state machine members, and to lambdas or local functions. For these, the analyzer makes no attempt to determine what the actual IL corresponding to the user code will be, so it produces fewer warnings. The linker will warn for what is actually in IL.
  • Loading branch information
sbomer authored Mar 30, 2022
1 parent da3c743 commit cb11422
Show file tree
Hide file tree
Showing 9 changed files with 1,033 additions and 137 deletions.
9 changes: 7 additions & 2 deletions src/ILLink.RoslynAnalyzer/RequiresISymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,13 @@ private static bool IsInRequiresScope (this ISymbol member, string requiresAttri
if (member is ITypeSymbol)
return false;

if (member.HasAttribute (requiresAttribute) && !member.IsStaticConstructor ())
return true;
while (true) {
if (member.HasAttribute (requiresAttribute) && !member.IsStaticConstructor ())
return true;
if (member.ContainingSymbol is not IMethodSymbol method)
break;
member = method;
}

if (member.ContainingType is ITypeSymbol containingType && containingType.HasAttribute (requiresAttribute))
return true;
Expand Down
14 changes: 8 additions & 6 deletions src/linker/Linker.Steps/MarkStep.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2901,13 +2901,15 @@ internal bool ShouldSuppressAnalysisWarningsForRequiresUnreferencedCode ()
if (originMember is not IMemberDefinition member)
return false;

MethodDefinition? userDefinedMethod = Context.CompilerGeneratedState.GetUserDefinedMethodForCompilerGeneratedMember (member);
if (userDefinedMethod == null)
return false;

Debug.Assert (userDefinedMethod != originMember);
MethodDefinition? owningMethod;
while (Context.CompilerGeneratedState.TryGetOwningMethodForCompilerGeneratedMember (member, out owningMethod)) {
Debug.Assert (owningMethod != member);
if (Annotations.IsMethodInRequiresUnreferencedCodeScope (owningMethod))
return true;
member = owningMethod;
}

return Annotations.IsMethodInRequiresUnreferencedCodeScope (userDefinedMethod);
return false;
}

internal void CheckAndReportRequiresUnreferencedCode (MethodDefinition method)
Expand Down
23 changes: 13 additions & 10 deletions src/linker/Linker/Annotations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -593,18 +593,21 @@ public bool TryGetLinkerAttribute<T> (IMemberDefinition member, [NotNullWhen (re
/// </summary>
/// <remarks>Unlike <see cref="IsMethodInRequiresUnreferencedCodeScope(MethodDefinition)"/> only static methods
/// and .ctors are reported as requiring unreferenced code when the declaring type has RUC on it.</remarks>
internal bool DoesMethodRequireUnreferencedCode (MethodDefinition method, [NotNullWhen (returnValue: true)] out RequiresUnreferencedCodeAttribute? attribute)
internal bool DoesMethodRequireUnreferencedCode (MethodDefinition originalMethod, [NotNullWhen (returnValue: true)] out RequiresUnreferencedCodeAttribute? attribute)
{
if (method.IsStaticConstructor ()) {
attribute = null;
return false;
}
if (TryGetLinkerAttribute (method, out attribute))
return true;
MethodDefinition? method = originalMethod;
do {
if (method.IsStaticConstructor ()) {
attribute = null;
return false;
}
if (TryGetLinkerAttribute (method, out attribute))
return true;

if ((method.IsStatic || method.IsConstructor) && method.DeclaringType is not null &&
TryGetLinkerAttribute (method.DeclaringType, out attribute))
return true;
if ((method.IsStatic || method.IsConstructor) && method.DeclaringType is not null &&
TryGetLinkerAttribute (method.DeclaringType, out attribute))
return true;
} while (context.CompilerGeneratedState.TryGetOwningMethodForCompilerGeneratedMember (method, out method));

return false;
}
Expand Down
43 changes: 43 additions & 0 deletions src/linker/Linker/CallGraph.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using Mono.Cecil;

namespace Mono.Linker
{
class CallGraph
{
readonly Dictionary<MethodDefinition, HashSet<MethodDefinition>> callGraph;

public CallGraph () => callGraph = new Dictionary<MethodDefinition, HashSet<MethodDefinition>> ();

public void TrackCall (MethodDefinition fromMethod, MethodDefinition toMethod)
{
if (!callGraph.TryGetValue (fromMethod, out HashSet<MethodDefinition>? toMethods)) {
toMethods = new HashSet<MethodDefinition> ();
callGraph.Add (fromMethod, toMethods);
}
toMethods.Add (toMethod);
}

public IEnumerable<MethodDefinition> GetReachableMethods (MethodDefinition start)
{
Queue<MethodDefinition> queue = new ();
HashSet<MethodDefinition> visited = new ();
visited.Add (start);
queue.Enqueue (start);
while (queue.TryDequeue (out MethodDefinition? method)) {
if (!callGraph.TryGetValue (method, out HashSet<MethodDefinition>? callees))
continue;

foreach (var callee in callees) {
if (visited.Add (callee)) {
queue.Enqueue (callee);
yield return callee;
}
}
}
}
}
}
56 changes: 56 additions & 0 deletions src/linker/Linker/CompilerGeneratedNames.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Mono.Linker
{
class CompilerGeneratedNames
{
internal static bool IsGeneratedMemberName (string memberName)
{
return memberName.Length > 0 && memberName[0] == '<';
}

internal static bool IsLambdaDisplayClass (string className)
{
if (!IsGeneratedMemberName (className))
return false;

// This is true for static lambdas (which are emitted into a class like <>c)
// and for instance lambdas (which are emitted into a class like <>c__DisplayClass1_0)
return className.StartsWith ("<>c");
}

internal static bool IsLambdaOrLocalFunction (string methodName) => IsLambdaMethod (methodName) || IsLocalFunction (methodName);

// Lambda methods have generated names like "<UserMethod>b__0_1" where "UserMethod" is the name
// of the original user code that contains the lambda method declaration.
internal static bool IsLambdaMethod (string methodName)
{
if (!IsGeneratedMemberName (methodName))
return false;

int i = methodName.IndexOf ('>', 1);
if (i == -1)
return false;

// Ignore the method ordinal/generation and lambda ordinal/generation.
return methodName[i + 1] == 'b';
}

// Local functions have generated names like "<UserMethod>g__LocalFunction|0_1" where "UserMethod" is the name
// of the original user code that contains the lambda method declaration, and "LocalFunction" is the name of
// the local function.
internal static bool IsLocalFunction (string methodName)
{
if (!IsGeneratedMemberName (methodName))
return false;

int i = methodName.IndexOf ('>', 1);
if (i == -1)
return false;

// Ignore the method ordinal/generation and local function ordinal/generation.
return methodName[i + 1] == 'g';
}
}
}
155 changes: 131 additions & 24 deletions src/linker/Linker/CompilerGeneratedState.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using ILLink.Shared;
using Mono.Cecil;
using Mono.Cecil.Cil;

namespace Mono.Linker
{
Expand All @@ -12,27 +15,68 @@ public class CompilerGeneratedState
{
readonly LinkContext _context;
readonly Dictionary<TypeDefinition, MethodDefinition> _compilerGeneratedTypeToUserCodeMethod;
readonly Dictionary<MethodDefinition, MethodDefinition> _compilerGeneratedMethodToUserCodeMethod;
readonly HashSet<TypeDefinition> _typesWithPopulatedCache;

public CompilerGeneratedState (LinkContext context)
{
_context = context;
_compilerGeneratedTypeToUserCodeMethod = new Dictionary<TypeDefinition, MethodDefinition> ();
_compilerGeneratedMethodToUserCodeMethod = new Dictionary<MethodDefinition, MethodDefinition> ();
_typesWithPopulatedCache = new HashSet<TypeDefinition> ();
}

static bool HasRoslynCompilerGeneratedName (TypeDefinition type) =>
type.Name.Contains ('<') || (type.DeclaringType != null && HasRoslynCompilerGeneratedName (type.DeclaringType));
static IEnumerable<TypeDefinition> GetCompilerGeneratedNestedTypes (TypeDefinition type)
{
foreach (var nestedType in type.NestedTypes) {
if (!CompilerGeneratedNames.IsGeneratedMemberName (nestedType.Name))
continue;

yield return nestedType;

foreach (var recursiveNestedType in GetCompilerGeneratedNestedTypes (nestedType))
yield return recursiveNestedType;
}
}

void PopulateCacheForType (TypeDefinition type)
{
// Avoid repeat scans of the same type
if (!_typesWithPopulatedCache.Add (type))
return;

foreach (MethodDefinition method in type.Methods) {
var callGraph = new CallGraph ();
var callingMethods = new HashSet<MethodDefinition> ();

void ProcessMethod (MethodDefinition method)
{
if (!CompilerGeneratedNames.IsLambdaOrLocalFunction (method.Name)) {
// If it's not a nested function, track as an entry point to the call graph.
var added = callingMethods.Add (method);
Debug.Assert (added);
}

// Discover calls or references to lambdas or local functions. This includes
// calls to local functions, and lambda assignments (which use ldftn).
if (method.Body != null) {
foreach (var instruction in method.Body.Instructions) {
if (instruction.OpCode.OperandType != OperandType.InlineMethod)
continue;

MethodDefinition? lambdaOrLocalFunction = _context.TryResolve ((MethodReference) instruction.Operand);
if (lambdaOrLocalFunction == null)
continue;

if (!CompilerGeneratedNames.IsLambdaOrLocalFunction (lambdaOrLocalFunction.Name))
continue;

callGraph.TrackCall (method, lambdaOrLocalFunction);
}
}

// Discover state machine methods.
if (!method.HasCustomAttributes)
continue;
return;

foreach (var attribute in method.CustomAttributes) {
if (attribute.AttributeType.Namespace != "System.Runtime.CompilerServices")
Expand All @@ -43,17 +87,53 @@ void PopulateCacheForType (TypeDefinition type)
case "AsyncStateMachineAttribute":
case "IteratorStateMachineAttribute":
TypeDefinition? stateMachineType = GetFirstConstructorArgumentAsType (attribute);
if (stateMachineType != null) {
if (!_compilerGeneratedTypeToUserCodeMethod.TryAdd (stateMachineType, method)) {
var alreadyAssociatedMethod = _compilerGeneratedTypeToUserCodeMethod[stateMachineType];
_context.LogWarning (new MessageOrigin (method), DiagnosticId.MethodsAreAssociatedWithStateMachine, method.GetDisplayName (), alreadyAssociatedMethod.GetDisplayName (), stateMachineType.GetDisplayName ());
}
if (stateMachineType == null)
break;
Debug.Assert (stateMachineType.DeclaringType == type ||
(CompilerGeneratedNames.IsGeneratedMemberName (stateMachineType.DeclaringType.Name) &&
stateMachineType.DeclaringType.DeclaringType == type));
if (!_compilerGeneratedTypeToUserCodeMethod.TryAdd (stateMachineType, method)) {
var alreadyAssociatedMethod = _compilerGeneratedTypeToUserCodeMethod[stateMachineType];
_context.LogWarning (new MessageOrigin (method), DiagnosticId.MethodsAreAssociatedWithStateMachine, method.GetDisplayName (), alreadyAssociatedMethod.GetDisplayName (), stateMachineType.GetDisplayName ());
}

break;
}
}
}

// Look for state machine methods, and methods which call local functions.
foreach (MethodDefinition method in type.Methods)
ProcessMethod (method);

// Also scan compiler-generated state machine methods (in case they have calls to nested functions),
// and nested functions inside compiler-generated closures (in case they call other nested functions).

// State machines can be emitted into lambda display classes, so we need to go down at least two
// levels to find calls from iterator nested functions to other nested functions. We just recurse into
// all compiler-generated nested types to avoid depending on implementation details.

foreach (var nestedType in GetCompilerGeneratedNestedTypes (type)) {
foreach (var method in nestedType.Methods)
ProcessMethod (method);
}

// Now we've discovered the call graphs for calls to nested functions.
// Use this to map back from nested functions to the declaring user methods.

// Note: This maps all nested functions back to the user code, not to the immediately
// declaring local function. The IL doesn't contain enough information in general for
// us to determine the nesting of local functions and lambdas.

// Note: this only discovers nested functions which are referenced from the user
// code or its referenced nested functions. There is no reliable way to determine from
// IL which user code an unused nested function belongs to.
foreach (var userDefinedMethod in callingMethods) {
foreach (var nestedFunction in callGraph.GetReachableMethods (userDefinedMethod)) {
Debug.Assert (CompilerGeneratedNames.IsLambdaOrLocalFunction (nestedFunction.Name));
_compilerGeneratedMethodToUserCodeMethod.Add (nestedFunction, userDefinedMethod);
}
}
}

static TypeDefinition? GetFirstConstructorArgumentAsType (CustomAttribute attribute)
Expand All @@ -64,27 +144,54 @@ void PopulateCacheForType (TypeDefinition type)
return attribute.ConstructorArguments[0].Value as TypeDefinition;
}

public MethodDefinition? GetUserDefinedMethodForCompilerGeneratedMember (IMemberDefinition sourceMember)
// For state machine types/members, maps back to the state machine method.
// For local functions and lambdas, maps back to the owning method in user code (not the declaring
// lambda or local function, because the IL doesn't contain enough information to figure this out).
public bool TryGetOwningMethodForCompilerGeneratedMember (IMemberDefinition sourceMember, [NotNullWhen (true)] out MethodDefinition? owningMethod)
{
owningMethod = null;
if (sourceMember == null)
return null;
return false;

TypeDefinition compilerGeneratedType = (sourceMember as TypeDefinition) ?? sourceMember.DeclaringType;
if (_compilerGeneratedTypeToUserCodeMethod.TryGetValue (compilerGeneratedType, out MethodDefinition? userDefinedMethod))
return userDefinedMethod;
MethodDefinition? compilerGeneratedMethod = sourceMember as MethodDefinition;
if (compilerGeneratedMethod != null) {
if (_compilerGeneratedMethodToUserCodeMethod.TryGetValue (compilerGeneratedMethod, out owningMethod))
return true;
}

// Only handle async or iterator state machine
// So go to the declaring type and check if it's compiler generated (as a perf optimization)
if (!HasRoslynCompilerGeneratedName (compilerGeneratedType) || compilerGeneratedType.DeclaringType == null)
return null;
TypeDefinition sourceType = (sourceMember as TypeDefinition) ?? sourceMember.DeclaringType;

if (_compilerGeneratedTypeToUserCodeMethod.TryGetValue (sourceType, out owningMethod))
return true;

if (!CompilerGeneratedNames.IsGeneratedMemberName (sourceMember.Name) && !CompilerGeneratedNames.IsGeneratedMemberName (sourceType.Name))
return false;

// sourceType is a state machine type, or the type containing a lambda or local function.
var typeToCache = sourceType;

// Now go to its declaring type and search all methods to find the one which points to the type as its
// Look in the declaring type if this is a compiler-generated type (state machine or display class).
// State machines can be emitted into display classes, so we may also need to go one more level up.
// To avoid depending on implementation details, we go up until we see a non-compiler-generated type.
// This is the counterpart to GetCompilerGeneratedNestedTypes.
while (typeToCache != null && CompilerGeneratedNames.IsGeneratedMemberName (typeToCache.Name))
typeToCache = typeToCache.DeclaringType;

if (typeToCache == null)
return false;

// Search all methods to find the one which points to the type as its
// state machine implementation.
PopulateCacheForType (compilerGeneratedType.DeclaringType);
if (_compilerGeneratedTypeToUserCodeMethod.TryGetValue (compilerGeneratedType, out userDefinedMethod))
return userDefinedMethod;
PopulateCacheForType (typeToCache);
if (compilerGeneratedMethod != null) {
if (_compilerGeneratedMethodToUserCodeMethod.TryGetValue (compilerGeneratedMethod, out owningMethod))
return true;
}

if (_compilerGeneratedTypeToUserCodeMethod.TryGetValue (sourceType, out owningMethod))
return true;

return null;
return false;
}
}
}
Loading

0 comments on commit cb11422

Please sign in to comment.