Skip to content

Commit

Permalink
Move OCL-IR translation logic related to EIS instructions to SPIRVToO…
Browse files Browse the repository at this point in the history
…CL pass

This change also fixes vstore translation for SPV-IR generation path and adds
test for vstore and printf.
  • Loading branch information
aratajew authored and AlexeySotkin committed May 14, 2021
1 parent ebaacc0 commit 67d3e27
Show file tree
Hide file tree
Showing 13 changed files with 1,612 additions and 78 deletions.
4 changes: 3 additions & 1 deletion lib/SPIRV/OCLUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,9 @@ class OCLBuiltinFuncMangleInfo : public SPIRV::BuiltinFuncMangleInfo {
if (NameRef.startswith("async_work_group")) {
addUnsignedArg(-1);
setArgAttr(1, SPIR::ATTR_CONST);
} else if (NameRef.startswith("write_imageui"))
} else if (NameRef.startswith("printf"))
setVarArg(1);
else if (NameRef.startswith("write_imageui"))
addUnsignedArg(2);
else if (NameRef.equals("prefetch")) {
addUnsignedArg(1);
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/SPIRVInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,10 @@ template <> inline void SPIRVMap<std::string, Op, SPIRVOpaqueType>::init() {
// Check if the module contains llvm.loop.* metadata
bool hasLoopMetadata(const Module *M);

// Check if CI is a call to instruction from OpenCL Extended Instruction Set.
// If so, return it's extended opcode in ExtOp.
bool isSPIRVOCLExtInst(const CallInst *CI, OCLExtOpKind *ExtOp);

// check LLVM Intrinsics type(s) for validity
bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM);
} // namespace SPIRV
Expand Down
87 changes: 12 additions & 75 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4224,88 +4224,25 @@ bool SPIRVToLLVM::transAlign(SPIRVValue *BV, Value *V) {
return true;
}

void SPIRVToLLVM::transOCLVectorLoadStore(std::string &UnmangledName,
std::vector<SPIRVWord> &BArgs) {
if (UnmangledName.find("vload") == 0 &&
UnmangledName.find("n") != std::string::npos) {
if (BArgs.back() != 1) {
std::stringstream SS;
SS << BArgs.back();
UnmangledName.replace(UnmangledName.find("n"), 1, SS.str());
} else {
UnmangledName.erase(UnmangledName.find("n"), 1);
}
BArgs.pop_back();
} else if (UnmangledName.find("vstore") == 0) {
if (UnmangledName.find("n") != std::string::npos) {
auto T = BM->getValueType(BArgs[0]);
if (T->isTypeVector()) {
auto W = T->getVectorComponentCount();
std::stringstream SS;
SS << W;
UnmangledName.replace(UnmangledName.find("n"), 1, SS.str());
} else {
UnmangledName.erase(UnmangledName.find("n"), 1);
}
}
if (UnmangledName.find("_r") != std::string::npos) {
UnmangledName.replace(
UnmangledName.find("_r"), 2,
std::string("_") +
SPIRSPIRVFPRoundingModeMap::rmap(
static_cast<SPIRVFPRoundingModeKind>(BArgs.back())));
BArgs.pop_back();
}
}
}

// printf is not mangled. The function type should have just one argument.
// read_image*: the second argument should be mangled as sampler.
Instruction *SPIRVToLLVM::transOCLBuiltinFromExtInst(SPIRVExtInst *BC,
BasicBlock *BB) {
assert(BB && "Invalid BB");
std::string MangledName;
SPIRVWord EntryPoint = BC->getExtOp();
std::string UnmangledName;
std::vector<SPIRVWord> BArgs = BC->getArguments();
auto ExtOp = static_cast<OCLExtOpKind>(BC->getExtOp());
std::string UnmangledName = OCLExtOpMap::map(ExtOp);

assert(BM->getBuiltinSet(BC->getExtSetId()) == SPIRVEIS_OpenCL &&
"Not OpenCL extended instruction");

bool IsPrintf = (EntryPoint == OpenCLLIB::Printf);
UnmangledName = OCLExtOpMap::map(static_cast<OCLExtOpKind>(EntryPoint));

SPIRVDBG(spvdbgs() << "[transOCLBuiltinFromExtInst] OrigUnmangledName: "
<< UnmangledName << '\n');
transOCLVectorLoadStore(UnmangledName, BArgs);

std::vector<Type *> ArgTypes = transTypeVector(BC->getValueTypes(BArgs));

// TODO: we should always produce SPIR-V friendly IR and apply lowering
// later if needed
if (IsPrintf) {
ArgTypes.resize(1);
}

std::vector<Type *> ArgTypes = transTypeVector(BC->getArgTypes());
Type *RetTy = transType(BC->getType());
if (BM->getDesiredBIsRepresentation() != BIsRepresentation::SPIRVFriendlyIR) {
// Convert extended instruction into an OpenCL built-in
if (IsPrintf) {
MangledName = "printf";
} else {
mangleOpenClBuiltin(UnmangledName, ArgTypes, MangledName);
}
} else {
MangledName = getSPIRVFriendlyIRFunctionName(
static_cast<OCLExtOpKind>(EntryPoint), ArgTypes, RetTy);
}
std::string MangledName =
getSPIRVFriendlyIRFunctionName(ExtOp, ArgTypes, RetTy);

SPIRVDBG(spvdbgs() << "[transOCLBuiltinFromExtInst] ModifiedUnmangledName: "
SPIRVDBG(spvdbgs() << "[transOCLBuiltinFromExtInst] UnmangledName: "
<< UnmangledName << " MangledName: " << MangledName
<< '\n');

FunctionType *FT = FunctionType::get(RetTy, ArgTypes,
/* IsVarArg */ IsPrintf);
FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
Function *F = M->getFunction(MangledName);
if (!F) {
F = Function::Create(FT, GlobalValue::ExternalLinkage, MangledName, M);
Expand All @@ -4315,17 +4252,17 @@ Instruction *SPIRVToLLVM::transOCLBuiltinFromExtInst(SPIRVExtInst *BC,
if (isFuncReadNone(UnmangledName))
F->addFnAttr(Attribute::ReadNone);
}
auto Args = transValue(BC->getValues(BArgs), F, BB);
auto Args = transValue(BC->getArgValues(), F, BB);
SPIRVDBG(dbgs() << "[transOCLBuiltinFromExtInst] Function: " << *F
<< ", Args: ";
for (auto &I
: Args) dbgs()
<< *I << ", ";
dbgs() << '\n');
CallInst *Call = CallInst::Create(F, Args, BC->getName(), BB);
setCallingConv(Call);
addFnAttr(Call, Attribute::NoUnwind);
return transOCLBuiltinPostproc(BC, Call, BB, UnmangledName);
CallInst *CI = CallInst::Create(F, Args, BC->getName(), BB);
setCallingConv(CI);
addFnAttr(CI, Attribute::NoUnwind);
return CI;
}

// SPIR-V only contains language version. Use OpenCL language version as
Expand Down
2 changes: 0 additions & 2 deletions lib/SPIRV/SPIRVReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ class SPIRVToLLVM {
SPIRVInstruction *BI, BasicBlock *BB);
Instruction *transOCLBuiltinFromInst(SPIRVInstruction *BI, BasicBlock *BB);
Instruction *transSPIRVBuiltinFromInst(SPIRVInstruction *BI, BasicBlock *BB);
void transOCLVectorLoadStore(std::string &UnmangledName,
std::vector<SPIRVWord> &BArgs);

/// Post-process translated LLVM module for OpenCL.
bool postProcessOCL();
Expand Down
109 changes: 109 additions & 0 deletions lib/SPIRV/SPIRVToOCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,32 @@ void SPIRVToOCL::visitCallInst(CallInst &CI) {
if (!F)
return;

OCLExtOpKind ExtOp;
if (isSPIRVOCLExtInst(&CI, &ExtOp)) {
switch (ExtOp) {
case OpenCLLIB::Vloadn:
case OpenCLLIB::Vloada_halfn:
case OpenCLLIB::Vload_halfn:
visitCallSPIRVVLoadn(&CI, ExtOp);
break;
case OpenCLLIB::Vstoren:
case OpenCLLIB::Vstore_halfn:
case OpenCLLIB::Vstorea_halfn:
case OpenCLLIB::Vstore_half_r:
case OpenCLLIB::Vstore_halfn_r:
case OpenCLLIB::Vstorea_halfn_r:
visitCallSPIRVVStore(&CI, ExtOp);
break;
case OpenCLLIB::Printf:
visitCallSPIRVPrintf(&CI, ExtOp);
break;
default:
visitCallSPIRVOCLExt(&CI, ExtOp);
break;
}
return;
}

auto MangledName = F->getName();
StringRef DemangledName;
Op OC = OpNop;
Expand Down Expand Up @@ -648,6 +674,89 @@ void SPIRVToOCL::visitCallSPIRVBuiltin(CallInst *CI, Op OC) {
&Attrs);
}

void SPIRVToOCL::visitCallSPIRVOCLExt(CallInst *CI, OCLExtOpKind Kind) {
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
mutateCallInstOCL(
M, CI,
[=](CallInst *, std::vector<Value *> &Args) {
return OCLExtOpMap::map(Kind);
},
&Attrs);
}

void SPIRVToOCL::visitCallSPIRVVLoadn(CallInst *CI, OCLExtOpKind Kind) {
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
mutateCallInstOCL(
M, CI,
[=](CallInst *, std::vector<Value *> &Args) {
std::string Name = OCLExtOpMap::map(Kind);
if (ConstantInt *C = dyn_cast<ConstantInt>(Args.back())) {
uint64_t NumComponents = C->getZExtValue();
std::stringstream SS;
SS << NumComponents;
Name.replace(Name.find("n"), 1, SS.str());
}
Args.pop_back();
return Name;
},
&Attrs);
}

void SPIRVToOCL::visitCallSPIRVVStore(CallInst *CI, OCLExtOpKind Kind) {
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
mutateCallInstOCL(
M, CI,
[=](CallInst *, std::vector<Value *> &Args) {
std::string Name = OCLExtOpMap::map(Kind);
if (Kind == OpenCLLIB::Vstore_half_r ||
Kind == OpenCLLIB::Vstore_halfn_r ||
Kind == OpenCLLIB::Vstorea_halfn_r) {
auto C = cast<ConstantInt>(Args.back());
auto RoundingMode =
static_cast<SPIRVFPRoundingModeKind>(C->getZExtValue());
Name.replace(Name.find("_r"), 2,
std::string("_") +
SPIRSPIRVFPRoundingModeMap::rmap(RoundingMode));
Args.pop_back();
}

if (Kind == OpenCLLIB::Vstore_halfn ||
Kind == OpenCLLIB::Vstore_halfn_r ||
Kind == OpenCLLIB::Vstorea_halfn ||
Kind == OpenCLLIB::Vstorea_halfn_r || Kind == OpenCLLIB::Vstoren) {
if (auto DataType = dyn_cast<VectorType>(Args[0]->getType())) {
uint64_t NumElements = DataType->getElementCount().getValue();
assert((NumElements == 2 || NumElements == 3 || NumElements == 4 ||
NumElements == 8 || NumElements == 16) &&
"Unsupported vector size for vstore instruction!");
std::stringstream SS;
SS << NumElements;
Name.replace(Name.find("n"), 1, SS.str());
}
}

return Name;
},
&Attrs);
}

void SPIRVToOCL::visitCallSPIRVPrintf(CallInst *CI, OCLExtOpKind Kind) {
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
CallInst *NewCI = mutateCallInstOCL(
M, CI,
[=](CallInst *, std::vector<Value *> &Args) {
return OCLExtOpMap::map(OpenCLLIB::Printf);
},
&Attrs);

// Clang represents printf function without mangling
std::string TargetName = "printf";
if (Function *F = M->getFunction(TargetName))
NewCI->setCalledFunction(F);
else
NewCI->getCalledFunction()->setName(TargetName);
}

std::string SPIRVToOCL::getGroupBuiltinPrefix(CallInst *CI) {
std::string Prefix;
auto ES = getArgAsScope(CI, 0);
Expand Down
13 changes: 13 additions & 0 deletions lib/SPIRV/SPIRVToOCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,19 @@ class SPIRVToOCL : public ModulePass, public InstVisitor<SPIRVToOCL> {
/// No change with arguments.
void visitCallSPIRVBuiltin(CallInst *CI, Op OC);

/// Transform __spirv_ocl* instructions (OpenCL Extended Instruction Set)
/// to OpenCL builtins.
void visitCallSPIRVOCLExt(CallInst *CI, OCLExtOpKind Kind);

/// Transform __spirv_ocl_vstore* to corresponding vstore OpenCL instruction
void visitCallSPIRVVStore(CallInst *CI, OCLExtOpKind Kind);

/// Transform __spirv_ocl_vloadn to OpenCL vload[2|4|8|16]
void visitCallSPIRVVLoadn(CallInst *CI, OCLExtOpKind Kind);

/// Transform __spirv_ocl_printf to (i8 addrspace(2)*, ...) @printf
void visitCallSPIRVPrintf(CallInst *CI, OCLExtOpKind Kind);

/// Get prefix work_/sub_ for OCL group builtin functions.
/// Assuming the first argument of \param CI is a constant integer for
/// workgroup/subgroup scope enums.
Expand Down
29 changes: 29 additions & 0 deletions lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,35 @@ bool hasLoopMetadata(const Module *M) {
return false;
}

bool isSPIRVOCLExtInst(const CallInst *CI, OCLExtOpKind *ExtOp) {
StringRef DemangledName;
if (!oclIsBuiltin(CI->getCalledFunction()->getName(), DemangledName))
return false;
StringRef S = DemangledName;
if (!S.startswith(kSPIRVName::Prefix))
return false;
S = S.drop_front(strlen(kSPIRVName::Prefix));
auto Loc = S.find(kSPIRVPostfix::Divider);
auto ExtSetName = S.substr(0, Loc);
SPIRVExtInstSetKind Set = SPIRVEIS_Count;
if (!SPIRVExtSetShortNameMap::rfind(ExtSetName.str(), &Set))
return false;

if (Set != SPIRVEIS_OpenCL)
return false;

auto ExtOpName = S.substr(Loc + 1);
auto PostFixPos = ExtOpName.find("_R");
ExtOpName = ExtOpName.substr(0, PostFixPos);

OCLExtOpKind EOC;
if (!OCLExtOpMap::rfind(ExtOpName.str(), &EOC))
return false;

*ExtOp = EOC;
return true;
}

// Returns true if type(s) and number of elements (if vector) is valid
bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
switch (II->getIntrinsicID()) {
Expand Down
17 changes: 17 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1865,6 +1865,23 @@ class SPIRVExtInst : public SPIRVFunctionCallGeneric<OpExtInst, 5> {
return Index == 3;
}
}
std::vector<SPIRVValue *> getArgValues() {
std::vector<SPIRVValue *> VArgs;
for (size_t I = 0; I < Args.size(); ++I) {
if (isOperandLiteral(I))
VArgs.push_back(Module->getLiteralAsConstant(Args[I]));
else
VArgs.push_back(getValue(Args[I]));
}
return VArgs;
}
std::vector<SPIRVType *> getArgTypes() {
std::vector<SPIRVType *> ArgTypes;
auto VArgs = getArgValues();
for (auto VArg : VArgs)
ArgTypes.push_back(VArg->getType());
return ArgTypes;
}

protected:
SPIRVExtInstSetKind ExtSetKind;
Expand Down
Loading

0 comments on commit 67d3e27

Please sign in to comment.