Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIRV] Add support of [[vk::ext_type_def]] #4068

Merged
merged 6 commits into from
Nov 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tools/clang/include/clang/AST/HlslTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ clang::CXXRecordDecl* DeclareTemplateTypeWithHandle(

clang::CXXRecordDecl* DeclareUIntTemplatedTypeWithHandle(
clang::ASTContext& context, llvm::StringRef typeName, llvm::StringRef templateParamName);
clang::CXXRecordDecl *DeclareUIntTemplatedTypeWithHandleInDeclContext(
clang::ASTContext &context, clang::DeclContext *declContext,
llvm::StringRef typeName, llvm::StringRef templateParamName);
clang::CXXRecordDecl *DeclareConstantBufferViewType(clang::ASTContext& context, bool bTBuf);
clang::CXXRecordDecl* DeclareRayQueryType(clang::ASTContext& context);
clang::CXXRecordDecl *DeclareResourceType(clang::ASTContext &context,
Expand Down
8 changes: 8 additions & 0 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,14 @@ def VKReferenceExt : InheritableAttr {
let Documentation = [Undocumented];
}

def VKTypeDefExt : InheritableAttr {
let Spellings = [CXX11<"vk", "ext_type_def">];
let Subjects = SubjectList<[Function], ErrorDiag>;
let Args = [UnsignedArgument<"id">, UnsignedArgument<"opcode">];
let LangOpts = [SPIRV];
let Documentation = [Undocumented];
}

// Global variables that are of scalar type
def ScalarGlobalVar : SubsetSubject<Var, [{S->hasGlobalStorage() && S->getType()->isScalarType()}]>;

Expand Down
7 changes: 7 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ class SpirvContext {
return rayQueryTypeKHR;
}

const SpirvIntrinsicType *
getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);

SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId);

/// --- Hybrid type getter functions ---
///
/// Concrete SpirvType objects represent a SPIR-V type completely. Hybrid
Expand Down Expand Up @@ -467,6 +473,7 @@ class SpirvContext {
llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
llvm::SmallVector<const HybridPointerType *, 8> hybridPointerTypes;
llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
llvm::DenseMap<unsigned, SpirvIntrinsicType*> spirvIntrinsicTypes;
const AccelerationStructureTypeNV *accelerationStructureTypeNV;
const RayQueryTypeKHR *rayQueryTypeKHR;

Expand Down
12 changes: 6 additions & 6 deletions tools/clang/include/clang/SPIRV/SpirvInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1110,10 +1110,13 @@ class SpirvConstant : public SpirvInstruction {
}

bool isSpecConstant() const;
void setLiteral(bool literal = true) { literalConstant = literal; }
bool isLiteral() { return literalConstant; }

protected:
SpirvConstant(Kind, spv::Op, const SpirvType *);
SpirvConstant(Kind, spv::Op, QualType);
SpirvConstant(Kind, spv::Op, const SpirvType *, bool literal = false);
SpirvConstant(Kind, spv::Op, QualType, bool literal = false);
bool literalConstant;
};

class SpirvConstantBoolean : public SpirvConstant {
Expand Down Expand Up @@ -1141,7 +1144,7 @@ class SpirvConstantBoolean : public SpirvConstant {
class SpirvConstantInteger : public SpirvConstant {
public:
SpirvConstantInteger(QualType type, llvm::APInt value,
bool isSpecConst = false, bool literal = false);
bool isSpecConst = false);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantInteger)

Expand All @@ -1155,12 +1158,9 @@ class SpirvConstantInteger : public SpirvConstant {
bool invokeVisitor(Visitor *v) override;

llvm::APInt getValue() const { return value; }
void setLiteral(bool l = true) { isLiteral = l; }
bool getLiteral() { return isLiteral; }

private:
llvm::APInt value;
bool isLiteral;
};

class SpirvConstantFloat : public SpirvConstant {
Expand Down
32 changes: 32 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvType.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class SpirvType {
TK_Function,
TK_AccelerationStructureNV,
TK_RayQueryKHR,
TK_SpirvIntrinsicType,
// Order matters: all the following are hybrid types
TK_HybridStruct,
TK_HybridPointer,
Expand Down Expand Up @@ -412,6 +413,37 @@ class RayQueryTypeKHR : public SpirvType {
}
};

class SpirvInstruction;
struct SpvIntrinsicTypeOperand {
SpvIntrinsicTypeOperand(SpirvType *type_operand)
: operand_as_type(type_operand), isTypeOperand(true) {}
SpvIntrinsicTypeOperand(SpirvInstruction *inst_operand)
: operand_as_inst(inst_operand), isTypeOperand(false) {}
union {
SpirvType *operand_as_type;
SpirvInstruction *operand_as_inst;
};
bool isTypeOperand;
};

class SpirvIntrinsicType : public SpirvType {
public:
SpirvIntrinsicType(unsigned typeOp,
llvm::ArrayRef<SpvIntrinsicTypeOperand> inOps);

static bool classof(const SpirvType *t) {
return t->getKind() == TK_SpirvIntrinsicType;
}
unsigned getOpCode() const { return typeOpCode; }
llvm::ArrayRef<SpvIntrinsicTypeOperand> getOperands() const {
return operands;
}

private:
unsigned typeOpCode;
llvm::SmallVector<SpvIntrinsicTypeOperand, 3> operands;
};

class HybridType : public SpirvType {
public:
static bool classof(const SpirvType *t) {
Expand Down
9 changes: 8 additions & 1 deletion tools/clang/lib/AST/ASTContextHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,15 @@ CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams(

CXXRecordDecl* hlsl::DeclareUIntTemplatedTypeWithHandle(
ASTContext& context, StringRef typeName, StringRef templateParamName) {
return DeclareUIntTemplatedTypeWithHandleInDeclContext(
context, context.getTranslationUnitDecl(), typeName, templateParamName);
}

CXXRecordDecl *hlsl::DeclareUIntTemplatedTypeWithHandleInDeclContext(
ASTContext &context, DeclContext *declContext, StringRef typeName,
StringRef templateParamName) {
// template<uint kind> FeedbackTexture2D[Array] { ... }
BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), typeName);
BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName);
typeDeclBuilder.addIntegerTemplateParam(templateParamName, context.UnsignedIntTy);
typeDeclBuilder.startDefinition();
typeDeclBuilder.addField("h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
Expand Down
5 changes: 3 additions & 2 deletions tools/clang/lib/SPIRV/CapabilityVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,9 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
}
case spv::Op::OpRayQueryInitializeKHR: {
auto rayQueryInst = dyn_cast<SpirvRayQueryOpKHR>(instr);
if (rayQueryInst->hasCullFlags()) {
addCapability(spv::Capability::RayTraversalPrimitiveCullingKHR);
if (rayQueryInst && rayQueryInst->hasCullFlags()) {
addCapability(
spv::Capability::RayTraversalPrimitiveCullingKHR);
}

break;
Expand Down
69 changes: 65 additions & 4 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1884,10 +1884,9 @@ bool EmitVisitor::visit(SpirvIntrinsicInstruction *inst) {
}

for (const auto operand : inst->getOperands()) {
// TODO: Handle Literals with other types.
auto literalOperand = dyn_cast<SpirvConstantInteger>(operand);
if (literalOperand && literalOperand->getLiteral()) {
curInst.push_back(literalOperand->getValue().getZExtValue());
auto literalOperand = dyn_cast<SpirvConstant>(operand);
if (literalOperand && literalOperand->isLiteral()) {
typeHandler.emitLiteral(literalOperand, curInst);
} else {
curInst.push_back(getOrAssignResultId<SpirvInstruction>(operand));
}
Expand Down Expand Up @@ -2451,6 +2450,24 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
initTypeInstruction(spv::Op::OpTypeRayQueryKHR);
curTypeInst.push_back(id);
finalizeTypeInstruction();
} else if (const auto *spvIntrinsicType =
dyn_cast<SpirvIntrinsicType>(type)) {
initTypeInstruction(static_cast<spv::Op>(spvIntrinsicType->getOpCode()));
curTypeInst.push_back(id);
for (const SpvIntrinsicTypeOperand &operand :
spvIntrinsicType->getOperands()) {
if (operand.isTypeOperand) {
curTypeInst.push_back(emitType(operand.operand_as_type));
} else {
auto *literal = dyn_cast<SpirvConstant>(operand.operand_as_inst);
if (literal && literal->isLiteral()) {
emitLiteral(literal, curTypeInst);
} else {
curTypeInst.push_back(getOrAssignResultId(operand.operand_as_inst));
}
}
}
finalizeTypeInstruction();
}
// Hybrid Types
// Note: The type lowering pass should lower all types to SpirvTypes.
Expand All @@ -2467,6 +2484,50 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
return id;
}

template <typename vecType>
void EmitTypeHandler::emitIntLiteral(const SpirvConstantInteger *intLiteral,
vecType &outInst) {
const auto &literalVal = intLiteral->getValue();
bool positive = !literalVal.isNegative();
if (literalVal.getBitWidth() <= 32) {
outInst.push_back(positive ? literalVal.getZExtValue()
: literalVal.getSExtValue());
} else {
assert(literalVal.getBitWidth() == 64);
uint64_t val =
positive ? literalVal.getZExtValue() : literalVal.getSExtValue();
outInst.push_back(static_cast<unsigned>(val));
outInst.push_back(static_cast<unsigned>(val >> 32));
}
}

template <typename vecType>
void EmitTypeHandler::emitFloatLiteral(const SpirvConstantFloat *fLiteral,
vecType &outInst) {
const auto &literalVal = fLiteral->getValue();
const auto bitwidth = llvm::APFloat::getSizeInBits(literalVal.getSemantics());
if (bitwidth <= 32) {
outInst.push_back(literalVal.bitcastToAPInt().getZExtValue());
} else {
assert(bitwidth == 64);
uint64_t val = literalVal.bitcastToAPInt().getZExtValue();
outInst.push_back(static_cast<unsigned>(val));
outInst.push_back(static_cast<unsigned>(val >> 32));
}
}

template <typename VecType>
void EmitTypeHandler::emitLiteral(const SpirvConstant *literal,
VecType &outInst) {
if (auto boolLiteral = dyn_cast<SpirvConstantBoolean>(literal)) {
outInst.push_back(static_cast<unsigned>(boolLiteral->getValue()));
} else if (auto intLiteral = dyn_cast<SpirvConstantInteger>(literal)) {
emitIntLiteral(intLiteral, outInst);
} else if (auto fLiteral = dyn_cast<SpirvConstantFloat>(literal)) {
emitFloatLiteral(fLiteral, outInst);
}
}

void EmitTypeHandler::emitDecoration(uint32_t typeResultId,
spv::Decoration decoration,
llvm::ArrayRef<uint32_t> decorationParams,
Expand Down
6 changes: 6 additions & 0 deletions tools/clang/lib/SPIRV/EmitVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ class EmitTypeHandler {
uint32_t getOrCreateConstantComposite(SpirvConstantComposite *);
uint32_t getOrCreateConstantNull(SpirvConstantNull *);
uint32_t getOrCreateConstantBool(SpirvConstantBoolean *);
template <typename vecType>
void emitLiteral(const SpirvConstant *, vecType &outInst);
template <typename vecType>
void emitFloatLiteral(const SpirvConstantFloat *, vecType &outInst);
template <typename vecType>
void emitIntLiteral(const SpirvConstantInteger *, vecType &outInst);

private:
void initTypeInstruction(spv::Op op);
Expand Down
5 changes: 5 additions & 0 deletions tools/clang/lib/SPIRV/LowerTypeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,11 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
if (name == "RayQuery")
return spvContext.getRayQueryTypeKHR();

if (name == "ext_type") {
auto typeId = hlsl::GetHLSLResourceTemplateUInt(type);
return spvContext.getCreatedSpirvIntrinsicType(typeId);
}

if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") {
// StructureBuffer<S> will be translated into an OpTypeStruct with one
Expand Down
7 changes: 5 additions & 2 deletions tools/clang/lib/SPIRV/SpirvBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1001,10 +1001,13 @@ SpirvInstruction *SpirvBuilder::createSpirvIntrInstExt(

SpirvExtInstImport *set =
(instSet.size() == 0) ? nullptr : getExtInstSet(instSet);

if (retType != QualType() && retType->isVoidType()) {
retType = QualType();
}

auto *inst = new (context) SpirvIntrinsicInstruction(
retType->isVoidType() ? QualType() : retType, opcode, operands,
extensions, set, capablities, loc);
retType, opcode, operands, extensions, set, capablities, loc);
insertPoint->addInstruction(inst);
return inst;
}
Expand Down
23 changes: 23 additions & 0 deletions tools/clang/lib/SPIRV/SpirvContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ SpirvContext::~SpirvContext() {

for (auto &typePair : typeTemplateParams)
typePair.second->releaseMemory();

for (auto &pair : spirvIntrinsicTypes) {
assert(pair.second);
pair.second->~SpirvIntrinsicType();
}
}

inline uint32_t log2ForBitwidth(uint32_t bitwidth) {
Expand Down Expand Up @@ -527,5 +532,23 @@ void SpirvContext::moveDebugTypesToModule(SpirvModule *module) {
typeTemplateParams.clear();
}

const SpirvIntrinsicType *SpirvContext::getSpirvIntrinsicType(
unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands) {
if (spirvIntrinsicTypes[typeId] == nullptr) {
spirvIntrinsicTypes[typeId] =
new (this) SpirvIntrinsicType(typeOpCode, operands);
}
return spirvIntrinsicTypes[typeId];
}

SpirvIntrinsicType *
SpirvContext::getCreatedSpirvIntrinsicType(unsigned typeId) {
if (spirvIntrinsicTypes.find(typeId) == spirvIntrinsicTypes.end()){
return nullptr;
}
return spirvIntrinsicTypes[typeId];
}

} // end namespace spirv
} // end namespace clang
Loading