Skip to content

Commit

Permalink
Only call Free in unmanaged->managed stubs when ownership has been tr…
Browse files Browse the repository at this point in the history
…ansfered to the callee

Fixes dotnet#85795
  • Loading branch information
jkoritzinsky committed May 19, 2023
1 parent 4dd7e87 commit 616d53f
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,12 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo
marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax, isLinearCollectionMarshalling: false);

if (marshallerData.Shape.HasFlag(MarshallerShape.Free))
marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax);
{
if (context.Direction == MarshalDirection.ManagedToUnmanaged)
marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax);
else if (info.RefKind == RefKind.Ref)
marshallingStrategy = new StatelessByRefFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax);
}
}

IMarshallingGenerator marshallingGenerator = new CustomTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public IMarshallingGenerator Create(
return s_delegate;

case { MarshallingAttributeInfo: SafeHandleMarshallingInfo(_, bool isAbstract) }:
if (!context.AdditionalTemporaryStateLivesAcrossStages)
if (!context.AdditionalTemporaryStateLivesAcrossStages || context.Direction != MarshalDirection.ManagedToUnmanaged)
{
throw new MarshallingNotSupportedException(info, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,93 @@ public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo i
public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
}

internal sealed class StatelessByRefFreeMarshalling : ICustomTypeMarshallingStrategy
{
private readonly ICustomTypeMarshallingStrategy _innerMarshaller;
private readonly TypeSyntax _marshallerType;

public StatelessByRefFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller, TypeSyntax marshallerType)
{
_innerMarshaller = innerMarshaller;
_marshallerType = marshallerType;
}

public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);

public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupStatements(info, context))
{
yield return statement;
}
// if (<freeUnmanaged>)
// <marshallerType>.Free(<original>);
yield return IfStatement(
IdentifierName(context.GetAdditionalIdentifier(info, "freeUnmanaged")),
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
_marshallerType,
IdentifierName(ShapeMemberNames.Free)),
ArgumentList(SingletonSeparatedList(
Argument(IdentifierName(context.GetAdditionalIdentifier(info, "original"))))))));
}

public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context);
public IEnumerable<StatementSyntax> GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context);
public IEnumerable<StatementSyntax> GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context);
public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context);
public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
{
foreach (StatementSyntax statement in _innerMarshaller.GenerateSetupStatements(info, context))
{
yield return statement;
}

// bool <freeUnmanaged> = false;
yield return LocalDeclarationStatement(
VariableDeclaration(
PredefinedType(Token(SyntaxKind.BoolKeyword)),
SingletonSeparatedList(
VariableDeclarator(
Identifier(context.GetAdditionalIdentifier(info, "freeUnmanaged")),
null,
EqualsValueClause(
LiteralExpression(SyntaxKind.FalseLiteralExpression))))));

// <nativeType> <original> = <nativeIdentifier>;
yield return LocalDeclarationStatement(
VariableDeclaration(
AsNativeType(info).Syntax,
SingletonSeparatedList(
VariableDeclarator(
Identifier(context.GetAdditionalIdentifier(info, "original")),
null,
EqualsValueClause(
IdentifierName(context.GetIdentifiers(info).native))))));
}

public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context)
{
foreach (StatementSyntax statement in _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context))
{
yield return statement;
}

// Now that we've captured the new value to pass to the caller, we need to make sure that we free the old one.

// <freeUnmanaged> = true;
yield return ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(context.GetAdditionalIdentifier(info, "freeUnmanaged")),
LiteralExpression(SyntaxKind.TrueLiteralExpression)));
}

public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context);
public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
}

/// <summary>
/// Marshaller type that enables allocating space for marshalling a linear collection using a marshaller that implements the LinearCollection marshalling spec.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public sealed record NativeToManagedStubCodeContext : StubCodeContext
{
public override bool SingleFrameSpansNativeContext => false;

public override bool AdditionalTemporaryStateLivesAcrossStages => false;
public override bool AdditionalTemporaryStateLivesAcrossStages => true;

private readonly TargetFramework _framework;
private readonly Version _frameworkVersion;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,9 @@ public unsafe void CallBaseInterfaceMethod_EnsureQiCalledOnce()
iface.SetInt(5);
Assert.Equal(5, iface.GetInt());

// /~https://github.com/dotnet/runtime/issues/85795
//Assert.Equal("myName", iface.GetName());
//iface.SetName("updated");
//Assert.Equal("updated", iface.GetName());
Assert.Equal("myName", iface.GetName());
iface.SetName("updated");
Assert.Equal("updated", iface.GetName());

var iUnknownStrategyProperty = typeof(ComObject).GetProperty("IUnknownStrategy", BindingFlags.NonPublic | BindingFlags.Instance);

Expand All @@ -67,7 +66,7 @@ partial class DerivedImpl : IDerived
{
int data = 3;
string myName = "myName";
public void DoThingWithString([MarshalUsing(typeof(Utf16StringMarshaller))] string name) => throw new NotImplementedException();
public void DoThingWithString(string name) => throw new NotImplementedException();

public int GetInt() => data;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using System.Threading;
using ComInterfaceGenerator.Tests;
using Xunit;

Expand Down Expand Up @@ -35,6 +36,8 @@ static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation
int GetData();
[VirtualMethodIndex(1, ImplicitThisParameter = true)]
void SetData(int x);
[VirtualMethodIndex(2, ImplicitThisParameter = true)]
void ExchangeData(ref int x);
}

[NativeMarshalling(typeof(NativeObjectMarshaller))]
Expand Down Expand Up @@ -105,16 +108,21 @@ public unsafe void ValidateImplicitThisUnmanagedToManagedFunctionCallsSucceed()

void* wrapper = VTableGCHandlePair<NativeExportsNE.ImplicitThis.INativeObject>.Allocate(impl);

Assert.Equal(startingValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper));
NativeExportsNE.ImplicitThis.SetNativeObjectData(wrapper, newValue);
Assert.Equal(newValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper));
// Verify that we actually updated the managed instance.
Assert.Equal(newValue, impl.GetData());

VTableGCHandlePair<NativeExportsNE.ImplicitThis.INativeObject>.Free(wrapper);
try
{
Assert.Equal(startingValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper));
NativeExportsNE.ImplicitThis.SetNativeObjectData(wrapper, newValue);
Assert.Equal(newValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper));
// Verify that we actually updated the managed instance.
Assert.Equal(newValue, impl.GetData());
}
finally
{
VTableGCHandlePair<NativeExportsNE.ImplicitThis.INativeObject>.Free(wrapper);
}
}

class ManagedObjectImplementation : NativeExportsNE.ImplicitThis.INativeObject
sealed class ManagedObjectImplementation : NativeExportsNE.ImplicitThis.INativeObject
{
private int _data;

Expand All @@ -123,6 +131,7 @@ public ManagedObjectImplementation(int value)
_data = value;
}

public void ExchangeData(ref int x) => x = Interlocked.Exchange(ref _data, x);
public int GetData() => _data;
public void SetData(int x) => _data = x;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using SharedTypes;
using Xunit;
using static ComInterfaceGenerator.Tests.UnmanagedToManagedCustomMarshallingTests;

namespace ComInterfaceGenerator.Tests
{
internal unsafe partial class NativeExportsNE
{
internal partial class UnmanagedToManagedCustomMarshalling
{
[UnmanagedObjectUnwrapper<VTableGCHandlePair<INativeObject>>]
internal partial interface INativeObject : IUnmanagedInterfaceType
{

private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 2);
static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation
{
get
{
if (s_vtable[0] == null)
{
Native.PopulateUnmanagedVirtualMethodTable(s_vtable);
}
return s_vtable;
}
}

[VirtualMethodIndex(0, ImplicitThisParameter = true)]
[return: MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))]
IntWrapper GetData();
[VirtualMethodIndex(1, ImplicitThisParameter = true)]
void SetData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] IntWrapper x);
[VirtualMethodIndex(2, ImplicitThisParameter = true)]
void ExchangeData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] ref IntWrapper data);
}

[NativeMarshalling(typeof(NativeObjectMarshaller))]
public class NativeObject : INativeObject.Native, IUnmanagedVirtualMethodTableProvider, IDisposable
{
private readonly void* _pointer;

public NativeObject(void* pointer)
{
_pointer = pointer;
}

public VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(Type type)
{
Assert.Equal(typeof(INativeObject), type);
return new VirtualMethodTableInfo(_pointer, *(void***)_pointer);
}

public void Dispose()
{
DeleteNativeObject(_pointer);
}
}

[CustomMarshaller(typeof(NativeObject), MarshalMode.ManagedToUnmanagedOut, typeof(NativeObjectMarshaller))]
static class NativeObjectMarshaller
{
public static NativeObject ConvertToManaged(void* value) => new NativeObject(value);
}

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "new_native_object")]
public static partial NativeObject NewNativeObject();

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "delete_native_object")]
public static partial void DeleteNativeObject(void* obj);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "set_native_object_data")]
public static partial void SetNativeObjectData(void* obj, int data);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_native_object_data")]
public static partial int GetNativeObjectData(void* obj);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "exchange_native_object_data")]
public static partial int ExchangeNativeObjectData(void* obj, ref int x);
}
}
public class UnmanagedToManagedCustomMarshallingTests
{
[Fact]
public unsafe void ValidateImplicitThisUnmanagedToManagedFunctionCallsSucceed()
{
const int startingValue = 13;
const int newValue = 42;

ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue);

void* wrapper = VTableGCHandlePair<NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObject>.Allocate(impl);

try
{
int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree;
NativeExportsNE.UnmanagedToManagedCustomMarshalling.GetNativeObjectData(wrapper);

Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree);

NativeExportsNE.UnmanagedToManagedCustomMarshalling.SetNativeObjectData(wrapper, newValue);
Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree);

int finalValue = 10;

NativeExportsNE.UnmanagedToManagedCustomMarshalling.ExchangeNativeObjectData(wrapper, ref finalValue);
Assert.Equal(freeCalls + 1, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree);
}
finally
{
VTableGCHandlePair<NativeExportsNE.ImplicitThis.INativeObject>.Free(wrapper);
}
}

sealed class ManagedObjectImplementation : NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObject
{
private IntWrapper _data;

public ManagedObjectImplementation(int value)
{
_data = new() { i = value };
}

public void ExchangeData(ref IntWrapper x) => x = Interlocked.Exchange(ref _data, x);
public IntWrapper GetData() => _data;
public void SetData(IntWrapper x) => _data = x;
}


[CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(IntWrapperMarshallerToIntWithFreeCounts))]
public static unsafe class IntWrapperMarshallerToIntWithFreeCounts
{
[ThreadStatic]
public static int NumCallsToFree = 0;

public static int ConvertToUnmanaged(IntWrapper managed)
{
return managed.i;
}

public static IntWrapper ConvertToManaged(int unmanaged)
{
return new IntWrapper { i = unmanaged };
}

public static void Free(int unmanaged)
{
NumCallsToFree++;
}
}
}
}
Loading

0 comments on commit 616d53f

Please sign in to comment.