Skip to content

Commit

Permalink
[Backport to 15] [OpaquePointers] Adjust builtin variable tracking to…
Browse files Browse the repository at this point in the history
… support i8 geps (KhronosGroup#2061)

The existing logic for the replacement of builtin variables with calls to
functions relies on relatively brittle tracking that is broken when opaque
pointers is turned on, and will be even more thoroughly broken if/when typed
geps are replaced with i8 geps or ptradd. This patch replaces that logic with a
less brittle variant that is able to handle any sequence of bitcast, gep, or
addrspacecast instructions between the global variable and the ultimate load
instruction.

It still will error out if the variable is used in too insane of a fashion (say,
trying to load an i32 out of the i64, or a misaligned vector type).

Co-authored-by: Joshua Cranmer <joshua.cranmer@intel.com>
  • Loading branch information
mateuszchudyk and jcranmer-intel authored Jul 3, 2023
1 parent d04ccf4 commit 2289037
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 141 deletions.
2 changes: 1 addition & 1 deletion lib/SPIRV/SPIRVInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1056,7 +1056,7 @@ std::string decodeSPIRVTypeName(StringRef Name,
SmallVectorImpl<std::string> &Strs);

// Copy attributes from function to call site.
void setAttrByCalledFunc(CallInst *Call);
CallInst *setAttrByCalledFunc(CallInst *Call);
bool isSPIRVBuiltinVariable(GlobalVariable *GV, SPIRVBuiltinVariableKind *Kind);
// Transform builtin variable from GlobalVariable to builtin call.
// e.g.
Expand Down
182 changes: 73 additions & 109 deletions lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1858,14 +1858,15 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
return true;
}

void setAttrByCalledFunc(CallInst *Call) {
CallInst *setAttrByCalledFunc(CallInst *Call) {
Function *F = Call->getCalledFunction();
assert(F);
if (F->isIntrinsic()) {
return;
return Call;
}
Call->setCallingConv(F->getCallingConv());
Call->setAttributes(F->getAttributes());
return Call;
}

bool isSPIRVBuiltinVariable(GlobalVariable *GV,
Expand Down Expand Up @@ -1915,6 +1916,75 @@ bool isSPIRVBuiltinVariable(GlobalVariable *GV,
// %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
// %5 = insertelement <3 x i64> %3, i64 %4, i32 2
// %6 = extractelement <3 x i64> %5, i32 0

/// Recursively look through the uses of a global variable, including casts or
/// gep offsets, to find all loads of the variable. Gep offsets that are non-0
/// are accumulated in the AccumulatedOffset parameter, which will eventually be
/// used to figure out which index of a variable is being used.
static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
Function *ReplacementFunc) {
const DataLayout &DL = ReplacementFunc->getParent()->getDataLayout();
SmallVector<Instruction *, 4> InstsToRemove;
for (User *U : V->users()) {
if (auto *Cast = dyn_cast<CastInst>(U)) {
replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc);
InstsToRemove.push_back(Cast);
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
APInt NewOffset = AccumulatedOffset.sextOrTrunc(
DL.getIndexSizeInBits(GEP->getPointerAddressSpace()));
if (!GEP->accumulateConstantOffset(DL, NewOffset))
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc);
InstsToRemove.push_back(GEP);
} else if (auto *Load = dyn_cast<LoadInst>(U)) {
// Figure out which index the accumulated offset corresponds to. If we
// have a weird offset (e.g., trying to load byte 7), bail out.
Type *ScalarTy = ReplacementFunc->getReturnType();
APInt Index;
uint64_t Remainder;
APInt::udivrem(AccumulatedOffset, ScalarTy->getScalarSizeInBits() / 8,
Index, Remainder);
if (Remainder != 0)
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");

IRBuilder<> Builder(Load);
Value *Replacement;
if (ReplacementFunc->getFunctionType()->getNumParams() == 0) {
if (Load->getType() != ScalarTy)
llvm_unreachable("Illegal use of a SPIR-V builtin variable");
Replacement =
setAttrByCalledFunc(Builder.CreateCall(ReplacementFunc, {}));
} else {
// The function has an index parameter.
if (auto *VecTy = dyn_cast<FixedVectorType>(Load->getType())) {
if (!Index.isZero())
llvm_unreachable("Illegal use of a SPIR-V builtin variable");
Replacement = UndefValue::get(VecTy);
for (unsigned I = 0; I < VecTy->getNumElements(); I++) {
Replacement = Builder.CreateInsertElement(
Replacement,
setAttrByCalledFunc(
Builder.CreateCall(ReplacementFunc, {Builder.getInt32(I)})),
Builder.getInt32(I));
}
} else if (Load->getType() == ScalarTy) {
Replacement = setAttrByCalledFunc(Builder.CreateCall(
ReplacementFunc, {Builder.getInt32(Index.getZExtValue())}));
} else {
llvm_unreachable("Illegal load type of a SPIR-V builtin variable");
}
}
Load->replaceAllUsesWith(Replacement);
InstsToRemove.push_back(Load);
} else {
llvm_unreachable("Illegal use of a SPIR-V builtin variable");
}
}

for (Instruction *I : InstsToRemove)
I->eraseFromParent();
}

bool lowerBuiltinVariableToCall(GlobalVariable *GV,
SPIRVBuiltinVariableKind Kind) {
// There might be dead constant users of GV (for example, SPIRVLowerConstExpr
Expand Down Expand Up @@ -1950,113 +2020,7 @@ bool lowerBuiltinVariableToCall(GlobalVariable *GV,
Func->setDoesNotAccessMemory();
}

// Collect instructions in these containers to remove them later.
std::vector<Instruction *> Loads;
std::vector<Instruction *> Casts;
std::vector<Instruction *> GEPs;

auto Replace = [&](std::vector<Value *> Arg, Instruction *I) {
auto *Call = CallInst::Create(Func, Arg, "", I);
Call->takeName(I);
setAttrByCalledFunc(Call);
SPIRVDBG(dbgs() << "[lowerBuiltinVariableToCall] " << *I << " -> " << *Call
<< '\n';)
I->replaceAllUsesWith(Call);
};

// If HasIndexArg is true, we create 3 built-in calls and insertelement to
// get 3-element vector filled with ids and replace uses of Load instruction
// with this vector.
// If HasIndexArg is false, the result of the Load instruction is the value
// which should be replaced with the Func.
// Returns true if Load was replaced, false otherwise.
auto ReplaceIfLoad = [&](User *I) {
auto *LD = dyn_cast<LoadInst>(I);
if (!LD)
return false;
std::vector<Value *> Vectors;
Loads.push_back(LD);
if (HasIndexArg) {
auto *VecTy = cast<FixedVectorType>(GVTy);
Value *EmptyVec = UndefValue::get(VecTy);
Vectors.push_back(EmptyVec);
const DebugLoc &DLoc = LD->getDebugLoc();
for (unsigned I = 0; I < VecTy->getNumElements(); ++I) {
auto *Idx = ConstantInt::get(Type::getInt32Ty(C), I);
auto *Call = CallInst::Create(Func, {Idx}, "", LD);
if (DLoc)
Call->setDebugLoc(DLoc);
setAttrByCalledFunc(Call);
auto *Insert = InsertElementInst::Create(Vectors.back(), Call, Idx);
if (DLoc)
Insert->setDebugLoc(DLoc);
Insert->insertAfter(Call);
Vectors.push_back(Insert);
}

Value *Ptr = LD->getPointerOperand();

if (isa<FixedVectorType>(LD->getType())) {
LD->replaceAllUsesWith(Vectors.back());
} else {
auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
assert(GEP && "Unexpected pattern!");
assert(GEP->getNumIndices() == 2 && "Unexpected pattern!");
Value *Idx = GEP->getOperand(2);
Value *Vec = Vectors.back();
auto *NewExtract = ExtractElementInst::Create(Vec, Idx);
NewExtract->insertAfter(cast<Instruction>(Vec));
LD->replaceAllUsesWith(NewExtract);
}

} else {
Replace({}, LD);
}

return true;
};

// Go over the GV users, find Load and ExtractElement instructions and
// replace them with the corresponding function call.
for (auto *UI : GV->users()) {
// There might or might not be an addrspacecast instruction.
if (auto *ASCast = dyn_cast<AddrSpaceCastInst>(UI)) {
Casts.push_back(ASCast);
for (auto *CastUser : ASCast->users()) {
if (ReplaceIfLoad(CastUser))
continue;
if (auto *GEP = dyn_cast<GetElementPtrInst>(CastUser)) {
GEPs.push_back(GEP);
for (auto *GEPUser : GEP->users()) {
if (!ReplaceIfLoad(GEPUser))
llvm_unreachable("Unexpected pattern!");
}
} else {
llvm_unreachable("Unexpected pattern!");
}
}
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(UI)) {
GEPs.push_back(GEP);
for (auto *GEPUser : GEP->users()) {
if (!ReplaceIfLoad(GEPUser))
llvm_unreachable("Unexpected pattern!");
}
} else if (!ReplaceIfLoad(UI)) {
llvm_unreachable("Unexpected pattern!");
}
}

auto Erase = [](std::vector<Instruction *> &ToErase) {
for (Instruction *I : ToErase) {
assert(I->hasNUses(0));
I->eraseFromParent();
}
};
// Order of erasing is important.
Erase(Loads);
Erase(GEPs);
Erase(Casts);

replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func);
return true;
}

Expand Down
42 changes: 23 additions & 19 deletions test/builtin-vars-gep.ll
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,32 @@ target triple = "spir64"
define spir_func void @foo() {
entry:
%GroupID = alloca [3 x i64], align 8
%0 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)*
%1 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %0, i64 0, i64 0
%0 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
%1 = getelementptr <3 x i64>, ptr addrspace(4) %0, i64 0, i64 0
; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
; CHECK-LLVM: %[[Ins0:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize0]], i32 0
; CHECK-LLVM: %[[GLocalSize1:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
; CHECK-LLVM: %[[Ins1:[0-9]+]] = insertelement <3 x i64> %[[Ins0]], i64 %[[GLocalSize1]], i32 1
%2 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
%3 = getelementptr <3 x i64>, ptr addrspace(4) %2, i64 0, i64 2
%4 = load i64, ptr addrspace(4) %1, align 32
%5 = load i64, ptr addrspace(4) %3, align 8
; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
; CHECK-LLVM: %[[Ins2:[0-9]+]] = insertelement <3 x i64> %[[Ins1]], i64 %[[GLocalSize2]], i32 2
; CHECK-LLVM: %[[Extract:[0-9]+]] = extractelement <3 x i64> %[[Ins2]], i64 0
%2 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)*
%3 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %2, i64 0, i64 2
%4 = load i64, i64 addrspace(4)* %1, align 32
%5 = load i64, i64 addrspace(4)* %3, align 8
; CHECK-LLVM: %[[GLocalSize01:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
; CHECK-LLVM: %[[Ins01:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize01]], i32 0
; CHECK-LLVM: %[[GLocalSize11:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
; CHECK-LLVM: %[[Ins11:[0-9]+]] = insertelement <3 x i64> %[[Ins01]], i64 %[[GLocalSize11]], i32 1
; CHECK-LLVM: %[[GLocalSize21:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
; CHECK-LLVM: %[[Ins21:[0-9]+]] = insertelement <3 x i64> %[[Ins11]], i64 %[[GLocalSize21]], i32 2
; CHECK-LLVM: %[[Extract1:[0-9]+]] = extractelement <3 x i64> %[[Ins21]], i64 2
; CHECK-LLVM: mul i64 %[[Extract]], %[[Extract1]]
; CHECK-LLVM: mul i64 %[[GLocalSize0]], %[[GLocalSize2]]
%mul = mul i64 %4, %5
ret void
}

; Function Attrs: alwaysinline convergent nounwind mustprogress
define spir_func void @foo_i8gep() {
entry:
%GroupID = alloca [3 x i64], align 8
%0 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
%1 = getelementptr i8, ptr addrspace(4) %0, i64 0
; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
%2 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
%3 = getelementptr i8, ptr addrspace(4) %2, i64 16
%4 = load i64, ptr addrspace(4) %1, align 32
%5 = load i64, ptr addrspace(4) %3, align 8
; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
; CHECK-LLVM: mul i64 %[[GLocalSize0]], %[[GLocalSize2]]
%mul = mul i64 %4, %5
ret void
}
12 changes: 0 additions & 12 deletions test/transcoding/builtin_vars_gep.ll
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,8 @@ define spir_kernel void @f() {
entry:
%0 = load i64, i64 addrspace(1)* getelementptr (<3 x i64>, <3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId, i64 0, i64 0), align 32
; CHECK-OCL-IR: %[[#ID1:]] = call spir_func i64 @_Z13get_global_idj(i32 0)
; CHECK-OCL-IR: %[[#VEC1:]] = insertelement <3 x i64> undef, i64 %[[#ID1]], i32 0
; CHECK-OCL-IR: %[[#ID2:]] = call spir_func i64 @_Z13get_global_idj(i32 1)
; CHECK-OCL-IR: %[[#VEC2:]] = insertelement <3 x i64> %[[#VEC1]], i64 %[[#ID2]], i32 1
; CHECK-OCL-IR: %[[#ID3:]] = call spir_func i64 @_Z13get_global_idj(i32 2)
; CHECK-OCL-IR: %[[#VEC3:]] = insertelement <3 x i64> %[[#VEC2]], i64 %[[#ID3]], i32 2
; CHECK-OCL-IR: %[[#]] = extractelement <3 x i64> %[[#VEC3]], i64 0

; CHECK-SPV-IR: %[[#ID1:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 0)
; CHECK-SPV-IR: %[[#VEC1:]] = insertelement <3 x i64> undef, i64 %[[#ID1]], i32 0
; CHECK-SPV-IR: %[[#ID2:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 1)
; CHECK-SPV-IR: %[[#VEC2:]] = insertelement <3 x i64> %[[#VEC1]], i64 %[[#ID2]], i32 1
; CHECK-SPV-IR: %[[#ID3:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 2)
; CHECK-SPV-IR: %[[#VEC3:]] = insertelement <3 x i64> %[[#VEC2]], i64 %[[#ID3]], i32 2
; CHECK-SPV-IR: %[[#]] = extractelement <3 x i64> %[[#VEC3]], i64 0

ret void
}

0 comments on commit 2289037

Please sign in to comment.