From 616d53fca1be7ca03c658e5c07ad6201e0537cc9 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Wed, 17 May 2023 16:39:47 -0700 Subject: [PATCH] Only call Free in unmanaged->managed stubs when ownership has been transfered to the callee Fixes #85795 --- ...ributedMarshallingModelGeneratorFactory.cs | 7 +- .../MarshalAsMarshallingGeneratorFactory.cs | 2 +- .../StatelessMarshallingStrategy.cs | 87 ++++++++++ .../NativeToManagedStubCodeContext.cs | 2 +- .../IDerivedTests.cs | 9 +- .../ImplicitThisTests.cs | 25 ++- ...nmanagedToManagedCustomMarshallingTests.cs | 163 ++++++++++++++++++ .../NativeExports/VirtualMethodTables.cs | 19 ++ 8 files changed, 298 insertions(+), 16 deletions(-) create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index 03b90462ea3013..2266edee6d10b4 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -245,7 +245,12 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax, isLinearCollectionMarshalling: false); if (marshallerData.Shape.HasFlag(MarshallerShape.Free)) - marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); + { + if (context.Direction == MarshalDirection.ManagedToUnmanaged) + marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); + else if (info.RefKind == RefKind.Ref) + marshallingStrategy = new StatelessByRefFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); + } } IMarshallingGenerator marshallingGenerator = new CustomTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs index 6bf87839e5846d..1790e588a8e790 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs @@ -88,7 +88,7 @@ public IMarshallingGenerator Create( return s_delegate; case { MarshallingAttributeInfo: SafeHandleMarshallingInfo(_, bool isAbstract) }: - if (!context.AdditionalTemporaryStateLivesAcrossStages) + if (!context.AdditionalTemporaryStateLivesAcrossStages || context.Direction != MarshalDirection.ManagedToUnmanaged) { throw new MarshallingNotSupportedException(info, context); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs index 94ea82a31b6d0a..000cba627ef811 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs @@ -295,6 +295,93 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); } + internal sealed class StatelessByRefFreeMarshalling : ICustomTypeMarshallingStrategy + { + private readonly ICustomTypeMarshallingStrategy _innerMarshaller; + private readonly TypeSyntax _marshallerType; + + public StatelessByRefFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller, TypeSyntax marshallerType) + { + _innerMarshaller = innerMarshaller; + _marshallerType = marshallerType; + } + + public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupStatements(info, context)) + { + yield return statement; + } + // if () + // .Free(); + yield return IfStatement( + IdentifierName(context.GetAdditionalIdentifier(info, "freeUnmanaged")), + ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + _marshallerType, + IdentifierName(ShapeMemberNames.Free)), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(context.GetAdditionalIdentifier(info, "original")))))))); + } + + public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context); + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context); + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context); + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (StatementSyntax statement in _innerMarshaller.GenerateSetupStatements(info, context)) + { + yield return statement; + } + + // bool = false; + yield return LocalDeclarationStatement( + VariableDeclaration( + PredefinedType(Token(SyntaxKind.BoolKeyword)), + SingletonSeparatedList( + VariableDeclarator( + Identifier(context.GetAdditionalIdentifier(info, "freeUnmanaged")), + null, + EqualsValueClause( + LiteralExpression(SyntaxKind.FalseLiteralExpression)))))); + + // = ; + yield return LocalDeclarationStatement( + VariableDeclaration( + AsNativeType(info).Syntax, + SingletonSeparatedList( + VariableDeclarator( + Identifier(context.GetAdditionalIdentifier(info, "original")), + null, + EqualsValueClause( + IdentifierName(context.GetIdentifiers(info).native)))))); + } + + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (StatementSyntax statement in _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context)) + { + yield return statement; + } + + // Now that we've captured the new value to pass to the caller, we need to make sure that we free the old one. + + // = true; + yield return ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(context.GetAdditionalIdentifier(info, "freeUnmanaged")), + LiteralExpression(SyntaxKind.TrueLiteralExpression))); + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context); + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); + } + /// /// Marshaller type that enables allocating space for marshalling a linear collection using a marshaller that implements the LinearCollection marshalling spec. /// diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs index 3dbd182406643d..7940f6a15fd233 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs @@ -11,7 +11,7 @@ public sealed record NativeToManagedStubCodeContext : StubCodeContext { public override bool SingleFrameSpansNativeContext => false; - public override bool AdditionalTemporaryStateLivesAcrossStages => false; + public override bool AdditionalTemporaryStateLivesAcrossStages => true; private readonly TargetFramework _framework; private readonly Version _frameworkVersion; diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs index 937ba3de524ef2..68c7420d0b3cbc 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs @@ -48,10 +48,9 @@ public unsafe void CallBaseInterfaceMethod_EnsureQiCalledOnce() iface.SetInt(5); Assert.Equal(5, iface.GetInt()); - // /~https://github.com/dotnet/runtime/issues/85795 - //Assert.Equal("myName", iface.GetName()); - //iface.SetName("updated"); - //Assert.Equal("updated", iface.GetName()); + Assert.Equal("myName", iface.GetName()); + iface.SetName("updated"); + Assert.Equal("updated", iface.GetName()); var iUnknownStrategyProperty = typeof(ComObject).GetProperty("IUnknownStrategy", BindingFlags.NonPublic | BindingFlags.Instance); @@ -67,7 +66,7 @@ partial class DerivedImpl : IDerived { int data = 3; string myName = "myName"; - public void DoThingWithString([MarshalUsing(typeof(Utf16StringMarshaller))] string name) => throw new NotImplementedException(); + public void DoThingWithString(string name) => throw new NotImplementedException(); public int GetInt() => data; diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs index e6214e759450e0..d1995087f9279f 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs @@ -5,6 +5,7 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; +using System.Threading; using ComInterfaceGenerator.Tests; using Xunit; @@ -35,6 +36,8 @@ static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation int GetData(); [VirtualMethodIndex(1, ImplicitThisParameter = true)] void SetData(int x); + [VirtualMethodIndex(2, ImplicitThisParameter = true)] + void ExchangeData(ref int x); } [NativeMarshalling(typeof(NativeObjectMarshaller))] @@ -105,16 +108,21 @@ public unsafe void ValidateImplicitThisUnmanagedToManagedFunctionCallsSucceed() void* wrapper = VTableGCHandlePair.Allocate(impl); - Assert.Equal(startingValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper)); - NativeExportsNE.ImplicitThis.SetNativeObjectData(wrapper, newValue); - Assert.Equal(newValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper)); - // Verify that we actually updated the managed instance. - Assert.Equal(newValue, impl.GetData()); - - VTableGCHandlePair.Free(wrapper); + try + { + Assert.Equal(startingValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper)); + NativeExportsNE.ImplicitThis.SetNativeObjectData(wrapper, newValue); + Assert.Equal(newValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper)); + // Verify that we actually updated the managed instance. + Assert.Equal(newValue, impl.GetData()); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } } - class ManagedObjectImplementation : NativeExportsNE.ImplicitThis.INativeObject + sealed class ManagedObjectImplementation : NativeExportsNE.ImplicitThis.INativeObject { private int _data; @@ -123,6 +131,7 @@ public ManagedObjectImplementation(int value) _data = value; } + public void ExchangeData(ref int x) => x = Interlocked.Exchange(ref _data, x); public int GetData() => _data; public void SetData(int x) => _data = x; } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs new file mode 100644 index 00000000000000..d13c0544dd5fe9 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs @@ -0,0 +1,163 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using SharedTypes; +using Xunit; +using static ComInterfaceGenerator.Tests.UnmanagedToManagedCustomMarshallingTests; + +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal partial class UnmanagedToManagedCustomMarshalling + { + [UnmanagedObjectUnwrapper>] + internal partial interface INativeObject : IUnmanagedInterfaceType + { + + private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 2); + static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation + { + get + { + if (s_vtable[0] == null) + { + Native.PopulateUnmanagedVirtualMethodTable(s_vtable); + } + return s_vtable; + } + } + + [VirtualMethodIndex(0, ImplicitThisParameter = true)] + [return: MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] + IntWrapper GetData(); + [VirtualMethodIndex(1, ImplicitThisParameter = true)] + void SetData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] IntWrapper x); + [VirtualMethodIndex(2, ImplicitThisParameter = true)] + void ExchangeData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] ref IntWrapper data); + } + + [NativeMarshalling(typeof(NativeObjectMarshaller))] + public class NativeObject : INativeObject.Native, IUnmanagedVirtualMethodTableProvider, IDisposable + { + private readonly void* _pointer; + + public NativeObject(void* pointer) + { + _pointer = pointer; + } + + public VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(Type type) + { + Assert.Equal(typeof(INativeObject), type); + return new VirtualMethodTableInfo(_pointer, *(void***)_pointer); + } + + public void Dispose() + { + DeleteNativeObject(_pointer); + } + } + + [CustomMarshaller(typeof(NativeObject), MarshalMode.ManagedToUnmanagedOut, typeof(NativeObjectMarshaller))] + static class NativeObjectMarshaller + { + public static NativeObject ConvertToManaged(void* value) => new NativeObject(value); + } + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "new_native_object")] + public static partial NativeObject NewNativeObject(); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "delete_native_object")] + public static partial void DeleteNativeObject(void* obj); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "set_native_object_data")] + public static partial void SetNativeObjectData(void* obj, int data); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_native_object_data")] + public static partial int GetNativeObjectData(void* obj); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "exchange_native_object_data")] + public static partial int ExchangeNativeObjectData(void* obj, ref int x); + } + } + public class UnmanagedToManagedCustomMarshallingTests + { + [Fact] + public unsafe void ValidateImplicitThisUnmanagedToManagedFunctionCallsSucceed() + { + const int startingValue = 13; + const int newValue = 42; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + NativeExportsNE.UnmanagedToManagedCustomMarshalling.GetNativeObjectData(wrapper); + + Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.SetNativeObjectData(wrapper, newValue); + Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + + int finalValue = 10; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.ExchangeNativeObjectData(wrapper, ref finalValue); + Assert.Equal(freeCalls + 1, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + + sealed class ManagedObjectImplementation : NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObject + { + private IntWrapper _data; + + public ManagedObjectImplementation(int value) + { + _data = new() { i = value }; + } + + public void ExchangeData(ref IntWrapper x) => x = Interlocked.Exchange(ref _data, x); + public IntWrapper GetData() => _data; + public void SetData(IntWrapper x) => _data = x; + } + + + [CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(IntWrapperMarshallerToIntWithFreeCounts))] + public static unsafe class IntWrapperMarshallerToIntWithFreeCounts + { + [ThreadStatic] + public static int NumCallsToFree = 0; + + public static int ConvertToUnmanaged(IntWrapper managed) + { + return managed.i; + } + + public static IntWrapper ConvertToManaged(int unmanaged) + { + return new IntWrapper { i = unmanaged }; + } + + public static void Free(int unmanaged) + { + NumCallsToFree++; + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs index a286dd4de36147..f091146b7ed737 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs @@ -8,6 +8,7 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; +using System.Threading; using System.Threading.Tasks; namespace NativeExports @@ -48,6 +49,7 @@ public struct VirtualFunctionTable { public delegate* unmanaged getData; public delegate* unmanaged setData; + public delegate* unmanaged exchangeData; } public readonly VirtualFunctionTable* VTable; @@ -66,12 +68,14 @@ public struct VirtualFunctionTable // The order of functions here should match NativeObjectInterface.VirtualFunctionTable's members. public delegate* unmanaged getData; public delegate* unmanaged setData; + public delegate* unmanaged exchangeData; } static NativeObject() { VTablePointer = (VirtualFunctionTable*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(NativeObject), sizeof(VirtualFunctionTable)); VTablePointer->getData = &GetData; VTablePointer->setData = &SetData; + VTablePointer->exchangeData = &ExchangeData; } private static readonly VirtualFunctionTable* VTablePointer; @@ -95,6 +99,14 @@ private static void SetData(NativeObject* obj, int value) { obj->Data = value; } + + [UnmanagedCallersOnly] + private static void ExchangeData(NativeObject* obj, int* value) + { + var temp = obj->Data; + obj->Data = *value; + *value = temp; + } } [UnmanagedCallersOnly(EntryPoint = "new_native_object")] @@ -127,5 +139,12 @@ public static int GetNativeObjectData([DNNE.C99Type("struct INativeObject*")] Na { return obj->VTable->getData(obj); } + + [UnmanagedCallersOnly(EntryPoint = "exchange_native_object_data")] + [DNNE.C99DeclCode("struct INativeObject;")] + public static void ExchangeNativeObjectData([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int* x) + { + obj->VTable->exchangeData(obj, x); + } } }