diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 5ba93fefab3f9e..43c0e2686a8c3b 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -1725,15 +1725,35 @@ struct EmboxOpConversion : public EmboxCommonConversion { } }; -static bool isDeviceAllocation(mlir::Value val) { +static bool isDeviceAllocation(mlir::Value val, mlir::Value adaptorVal) { if (auto loadOp = mlir::dyn_cast_or_null(val.getDefiningOp())) - return isDeviceAllocation(loadOp.getMemref()); + return isDeviceAllocation(loadOp.getMemref(), {}); if (auto boxAddrOp = mlir::dyn_cast_or_null(val.getDefiningOp())) - return isDeviceAllocation(boxAddrOp.getVal()); + return isDeviceAllocation(boxAddrOp.getVal(), {}); if (auto convertOp = mlir::dyn_cast_or_null(val.getDefiningOp())) - return isDeviceAllocation(convertOp.getValue()); + return isDeviceAllocation(convertOp.getValue(), {}); + if (!val.getDefiningOp() && adaptorVal) { + if (auto blockArg = llvm::cast(adaptorVal)) { + if (blockArg.getOwner() && blockArg.getOwner()->getParentOp() && + blockArg.getOwner()->isEntryBlock()) { + if (auto func = mlir::dyn_cast_or_null( + *blockArg.getOwner()->getParentOp())) { + auto argAttrs = func.getArgAttrs(blockArg.getArgNumber()); + for (auto attr : argAttrs) { + if (attr.getName().getValue().ends_with(cuf::getDataAttrName())) { + auto dataAttr = + mlir::dyn_cast(attr.getValue()); + if (dataAttr.getValue() != cuf::DataAttribute::Pinned && + dataAttr.getValue() != cuf::DataAttribute::Unified) + return true; + } + } + } + } + } + } if (auto callOp = mlir::dyn_cast_or_null(val.getDefiningOp())) if (callOp.getCallee() && (callOp.getCallee().value().getRootReference().getValue().starts_with( @@ -1928,7 +1948,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion { if (fir::isDerivedTypeWithLenParams(boxTy)) TODO(loc, "fir.embox codegen of derived with length parameters"); mlir::Value result = placeInMemoryIfNotGlobalInit( - rewriter, loc, boxTy, dest, isDeviceAllocation(xbox.getMemref())); + rewriter, loc, boxTy, dest, + isDeviceAllocation(xbox.getMemref(), adaptor.getMemref())); rewriter.replaceOp(xbox, result); return mlir::success(); } @@ -2052,9 +2073,9 @@ struct XReboxOpConversion : public EmboxCommonConversion { dest = insertStride(rewriter, loc, dest, dim, std::get<1>(iter.value())); } dest = insertBaseAddress(rewriter, loc, dest, base); - mlir::Value result = - placeInMemoryIfNotGlobalInit(rewriter, rebox.getLoc(), destBoxTy, dest, - isDeviceAllocation(rebox.getBox())); + mlir::Value result = placeInMemoryIfNotGlobalInit( + rewriter, rebox.getLoc(), destBoxTy, dest, + isDeviceAllocation(rebox.getBox(), rebox.getBox())); rewriter.replaceOp(rebox, result); return mlir::success(); } diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir index 3ad28fa7bd5179..7ac89836a3ff16 100644 --- a/flang/test/Fir/CUDA/cuda-code-gen.mlir +++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir @@ -170,3 +170,20 @@ module attributes {dlti.dl_spec = #dlti.dl_spec = dense<32> : vec // CHECK-LABEL: llvm.func @_QQmain() // CHECK-COUNT-3: llvm.call @_FortranACUFAllocDescriptor + +// ----- + +module attributes {dlti.dl_spec = #dlti.dl_spec = dense<32> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, f80 = dense<128> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i1 = dense<8> : vector<2xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 (git@github.com:clementval/llvm-project.git efc2415bcce8e8a9e73e77aa122c8aba1c1fbbd2)", llvm.target_triple = "x86_64-unknown-linux-gnu"} { + func.func @_QPouter(%arg0: !fir.ref> {cuf.data_attr = #cuf.cuda, fir.bindc_name = "a"}) { + %c0_i32 = arith.constant 0 : i32 + %c100 = arith.constant 100 : index + %0 = fir.alloca tuple>> + %1 = fir.coordinate_of %0, %c0_i32 : (!fir.ref>>>, i32) -> !fir.ref>> + %2 = fircg.ext_embox %arg0(%c100, %c100) : (!fir.ref>, index, index) -> !fir.box> + fir.store %2 to %1 : !fir.ref>> + return + } +} + +// CHECK-LABEL: llvm.func @_QPouter +// CHECK: _FortranACUFAllocDescriptor