From 374b714ae2b93720521d0b7fb77a798d7fe4cb98 Mon Sep 17 00:00:00 2001 From: Lu Jiao Date: Thu, 28 Oct 2021 10:41:59 +0800 Subject: [PATCH 1/6] [SPIRV] Add support of [[vk::ext_type_def]] this is related /~https://github.com/microsoft/DirectXShaderCompiler/issues/3919 --- tools/clang/include/clang/Basic/Attr.td | 8 +++ .../clang/include/clang/SPIRV/SpirvContext.h | 8 +++ .../include/clang/SPIRV/SpirvInstruction.h | 12 ++-- tools/clang/include/clang/SPIRV/SpirvType.h | 19 +++++++ tools/clang/lib/SPIRV/CapabilityVisitor.cpp | 5 +- tools/clang/lib/SPIRV/EmitVisitor.cpp | 51 +++++++++++++++-- tools/clang/lib/SPIRV/EmitVisitor.h | 2 + tools/clang/lib/SPIRV/LowerTypeVisitor.cpp | 5 ++ tools/clang/lib/SPIRV/SpirvBuilder.cpp | 7 ++- tools/clang/lib/SPIRV/SpirvContext.cpp | 19 +++++++ tools/clang/lib/SPIRV/SpirvEmitter.cpp | 57 +++++++++++++++++-- tools/clang/lib/SPIRV/SpirvEmitter.h | 3 + tools/clang/lib/SPIRV/SpirvInstruction.cpp | 16 ++++-- tools/clang/lib/SPIRV/SpirvType.cpp | 9 ++- tools/clang/lib/Sema/SemaHLSL.cpp | 26 +++++++-- .../spv.intrinsicTypeInteger.hlsl | 21 +++++++ .../spv.intrinsicTypeRayquery.hlsl | 40 +++++++++++++ .../unittests/SPIRV/CodeGenSpirvTest.cpp | 2 + 18 files changed, 280 insertions(+), 30 deletions(-) create mode 100644 tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeInteger.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl diff --git a/tools/clang/include/clang/Basic/Attr.td b/tools/clang/include/clang/Basic/Attr.td index 9a0e89a32c..3a9f0cc8cc 100644 --- a/tools/clang/include/clang/Basic/Attr.td +++ b/tools/clang/include/clang/Basic/Attr.td @@ -1145,6 +1145,14 @@ def VKReferenceExt : InheritableAttr { let Documentation = [Undocumented]; } +def VKTypeDefExt : InheritableAttr { + let Spellings = [CXX11<"vk", "ext_type_def">]; + let Subjects = SubjectList<[Function], ErrorDiag>; + let Args = [UnsignedArgument<"id">, UnsignedArgument<"opcode">]; + let LangOpts = [SPIRV]; + let Documentation = [Undocumented]; +} + // Global variables that are of scalar type def ScalarGlobalVar : SubsetSubjecthasGlobalStorage() && S->getType()->isScalarType()}]>; diff --git a/tools/clang/include/clang/SPIRV/SpirvContext.h b/tools/clang/include/clang/SPIRV/SpirvContext.h index 7e730cc491..d3baef3c2b 100644 --- a/tools/clang/include/clang/SPIRV/SpirvContext.h +++ b/tools/clang/include/clang/SPIRV/SpirvContext.h @@ -288,6 +288,13 @@ class SpirvContext { return rayQueryTypeKHR; } + const SpirvIntrinsicType * + getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode, + llvm::ArrayRef constants, + SpirvIntrinsicType *elementTy); + + SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId); + /// --- Hybrid type getter functions --- /// /// Concrete SpirvType objects represent a SPIR-V type completely. Hybrid @@ -467,6 +474,7 @@ class SpirvContext { llvm::DenseMap pointerTypes; llvm::SmallVector hybridPointerTypes; llvm::DenseSet functionTypes; + llvm::DenseMap spirvIntrinsicTypes; const AccelerationStructureTypeNV *accelerationStructureTypeNV; const RayQueryTypeKHR *rayQueryTypeKHR; diff --git a/tools/clang/include/clang/SPIRV/SpirvInstruction.h b/tools/clang/include/clang/SPIRV/SpirvInstruction.h index fba2316496..0e77458c81 100644 --- a/tools/clang/include/clang/SPIRV/SpirvInstruction.h +++ b/tools/clang/include/clang/SPIRV/SpirvInstruction.h @@ -1110,10 +1110,13 @@ class SpirvConstant : public SpirvInstruction { } bool isSpecConstant() const; + void setLiteral(bool literal = true) { literalConstant = literal; } + bool isLiteral() { return literalConstant; } protected: - SpirvConstant(Kind, spv::Op, const SpirvType *); - SpirvConstant(Kind, spv::Op, QualType); + SpirvConstant(Kind, spv::Op, const SpirvType *, bool literal = false); + SpirvConstant(Kind, spv::Op, QualType, bool literal = false); + bool literalConstant; }; class SpirvConstantBoolean : public SpirvConstant { @@ -1141,7 +1144,7 @@ class SpirvConstantBoolean : public SpirvConstant { class SpirvConstantInteger : public SpirvConstant { public: SpirvConstantInteger(QualType type, llvm::APInt value, - bool isSpecConst = false, bool literal = false); + bool isSpecConst = false); DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantInteger) @@ -1155,12 +1158,9 @@ class SpirvConstantInteger : public SpirvConstant { bool invokeVisitor(Visitor *v) override; llvm::APInt getValue() const { return value; } - void setLiteral(bool l = true) { isLiteral = l; } - bool getLiteral() { return isLiteral; } private: llvm::APInt value; - bool isLiteral; }; class SpirvConstantFloat : public SpirvConstant { diff --git a/tools/clang/include/clang/SPIRV/SpirvType.h b/tools/clang/include/clang/SPIRV/SpirvType.h index 2158eb540a..44a877517c 100644 --- a/tools/clang/include/clang/SPIRV/SpirvType.h +++ b/tools/clang/include/clang/SPIRV/SpirvType.h @@ -49,6 +49,7 @@ class SpirvType { TK_Function, TK_AccelerationStructureNV, TK_RayQueryKHR, + TK_SpirvIntrinsicType, // Order matters: all the following are hybrid types TK_HybridStruct, TK_HybridPointer, @@ -412,6 +413,24 @@ class RayQueryTypeKHR : public SpirvType { } }; +class SpirvConstant; +class SpirvIntrinsicType : public SpirvType { +public: + SpirvIntrinsicType(unsigned typeOp, llvm::ArrayRef constants, + SpirvIntrinsicType *elementTy); + static bool classof(const SpirvType *t) { + return t->getKind() == TK_SpirvIntrinsicType; + } + unsigned getOpCode() const { return typeOpCode; } + llvm::ArrayRef getLiterals() const { return literals; } + SpirvIntrinsicType *getElemType() const { return elementType; } + +private: + unsigned typeOpCode; + llvm::SmallVector literals; + SpirvIntrinsicType *elementType; +}; + class HybridType : public SpirvType { public: static bool classof(const SpirvType *t) { diff --git a/tools/clang/lib/SPIRV/CapabilityVisitor.cpp b/tools/clang/lib/SPIRV/CapabilityVisitor.cpp index a3250a7b83..aed9adf0a2 100644 --- a/tools/clang/lib/SPIRV/CapabilityVisitor.cpp +++ b/tools/clang/lib/SPIRV/CapabilityVisitor.cpp @@ -529,8 +529,9 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) { } case spv::Op::OpRayQueryInitializeKHR: { auto rayQueryInst = dyn_cast(instr); - if (rayQueryInst->hasCullFlags()) { - addCapability(spv::Capability::RayTraversalPrimitiveCullingKHR); + if (rayQueryInst && rayQueryInst->hasCullFlags()) { + addCapability( + spv::Capability::RayTraversalPrimitiveCullingKHR); } break; diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index cd72488a72..56459d539d 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -1884,10 +1884,9 @@ bool EmitVisitor::visit(SpirvIntrinsicInstruction *inst) { } for (const auto operand : inst->getOperands()) { - // TODO: Handle Literals with other types. - auto literalOperand = dyn_cast(operand); - if (literalOperand && literalOperand->getLiteral()) { - curInst.push_back(literalOperand->getValue().getZExtValue()); + auto literalOperand = dyn_cast(operand); + if (literalOperand && literalOperand->isLiteral()) { + typeHandler.emitLiteral(literalOperand, curInst); } else { curInst.push_back(getOrAssignResultId(operand)); } @@ -2452,6 +2451,17 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { curTypeInst.push_back(id); finalizeTypeInstruction(); } + else if (const auto *spvIntrinsicType = dyn_cast(type)) { + initTypeInstruction(static_cast(spvIntrinsicType->getOpCode())); + curTypeInst.push_back(id); + if (spvIntrinsicType->getElemType()) { + curTypeInst.push_back(emitType(spvIntrinsicType->getElemType())); + } + for (auto& literal : spvIntrinsicType->getLiterals()) { + emitLiteral(literal, curTypeInst); + } + finalizeTypeInstruction(); + } // Hybrid Types // Note: The type lowering pass should lower all types to SpirvTypes. // Therefore, if we find a hybrid type when going through the emitting pass, @@ -2467,6 +2477,39 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { return id; } +template +void EmitTypeHandler::emitLiteral(const SpirvConstant *literal, + VecType &outInst) { + if (auto boolLiteral = dyn_cast(literal)) { + outInst.push_back(static_cast(boolLiteral->getValue())); + } else if (auto intLiteral = dyn_cast(literal)) { + const auto &literalVal = intLiteral->getValue(); + bool positive = !literalVal.isNegative(); + if (literalVal.getBitWidth() <= 32) { + outInst.push_back(positive ? literalVal.getZExtValue() + : literalVal.getSExtValue()); + } else { + assert(literalVal.getBitWidth() == 64); + uint64_t val = + positive ? literalVal.getZExtValue() : literalVal.getSExtValue(); + outInst.push_back(static_cast(val)); + outInst.push_back(static_cast(val >> 32)); + } + } else if (auto fLiteral = dyn_cast(literal)) { + const auto &literalVal = fLiteral->getValue(); + const auto bitwidth = + llvm::APFloat::getSizeInBits(literalVal.getSemantics()); + if (bitwidth <= 32) { + outInst.push_back(literalVal.bitcastToAPInt().getZExtValue()); + } else { + assert(bitwidth == 64); + uint64_t val = literalVal.bitcastToAPInt().getZExtValue(); + outInst.push_back(static_cast(val)); + outInst.push_back(static_cast(val >> 32)); + } + } +} + void EmitTypeHandler::emitDecoration(uint32_t typeResultId, spv::Decoration decoration, llvm::ArrayRef decorationParams, diff --git a/tools/clang/lib/SPIRV/EmitVisitor.h b/tools/clang/lib/SPIRV/EmitVisitor.h index 38ec433818..38b2431fc1 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.h +++ b/tools/clang/lib/SPIRV/EmitVisitor.h @@ -109,6 +109,8 @@ class EmitTypeHandler { uint32_t getOrCreateConstantComposite(SpirvConstantComposite *); uint32_t getOrCreateConstantNull(SpirvConstantNull *); uint32_t getOrCreateConstantBool(SpirvConstantBoolean *); + template + void emitLiteral(const SpirvConstant *, vecType &outInst); private: void initTypeInstruction(spv::Op op); diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp index 92aa8f67e1..ad4379ff2f 100644 --- a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp +++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp @@ -593,6 +593,11 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type, if (name == "RayQuery") return spvContext.getRayQueryTypeKHR(); + if (name == "ext_type") { + auto typeId = hlsl::GetHLSLResourceTemplateUInt(type); + return spvContext.getCreatedSpirvIntrinsicType(typeId); + } + if (name == "StructuredBuffer" || name == "RWStructuredBuffer" || name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") { // StructureBuffer will be translated into an OpTypeStruct with one diff --git a/tools/clang/lib/SPIRV/SpirvBuilder.cpp b/tools/clang/lib/SPIRV/SpirvBuilder.cpp index ccf74b8c1c..cfef113a92 100644 --- a/tools/clang/lib/SPIRV/SpirvBuilder.cpp +++ b/tools/clang/lib/SPIRV/SpirvBuilder.cpp @@ -1001,10 +1001,13 @@ SpirvInstruction *SpirvBuilder::createSpirvIntrInstExt( SpirvExtInstImport *set = (instSet.size() == 0) ? nullptr : getExtInstSet(instSet); + + if (retType != QualType() && retType->isVoidType()) { + retType = QualType(); + } auto *inst = new (context) SpirvIntrinsicInstruction( - retType->isVoidType() ? QualType() : retType, opcode, operands, - extensions, set, capablities, loc); + retType, opcode, operands, extensions, set, capablities, loc); insertPoint->addInstruction(inst); return inst; } diff --git a/tools/clang/lib/SPIRV/SpirvContext.cpp b/tools/clang/lib/SPIRV/SpirvContext.cpp index 9f4c988120..bccb47dd43 100644 --- a/tools/clang/lib/SPIRV/SpirvContext.cpp +++ b/tools/clang/lib/SPIRV/SpirvContext.cpp @@ -527,5 +527,24 @@ void SpirvContext::moveDebugTypesToModule(SpirvModule *module) { typeTemplateParams.clear(); } +const SpirvIntrinsicType * +SpirvContext::getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode, + llvm::ArrayRef constants, + SpirvIntrinsicType *elementTy) { + if (spirvIntrinsicTypes[typeId] == nullptr) { + spirvIntrinsicTypes[typeId] = + new (this) SpirvIntrinsicType(typeOpCode, constants, elementTy); + } + return spirvIntrinsicTypes[typeId]; +} + +SpirvIntrinsicType * +SpirvContext::getCreatedSpirvIntrinsicType(unsigned typeId) { + if (spirvIntrinsicTypes.find(typeId) == spirvIntrinsicTypes.end()){ + return nullptr; + } + return spirvIntrinsicTypes[typeId]; +} + } // end namespace spirv } // end namespace clang diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index eeea430520..666c217dbe 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -18,9 +18,9 @@ #include "dxc/HlslIntrinsicOp.h" #include "spirv-tools/optimizer.hpp" #include "clang/SPIRV/AstTypeProbe.h" +#include "clang/SPIRV/String.h" #include "clang/Sema/Sema.h" #include "llvm/ADT/StringExtras.h" - #include "InitListHandler.h" #include "dxc/DXIL/DxilConstants.h" @@ -2337,8 +2337,11 @@ SpirvInstruction *SpirvEmitter::doCallExpr(const CallExpr *callExpr) { return doCXXMemberCallExpr(memberCall); auto funcDecl = callExpr->getDirectCallee(); - if (funcDecl && funcDecl->hasAttr()) { - return processSpvIntrinsicCallExpr(callExpr); + if (funcDecl) { + if (funcDecl->hasAttr()) + return processSpvIntrinsicCallExpr(callExpr); + else if(funcDecl->hasAttr()) + return processSpvIntrinsicTypeDef(callExpr); } // Intrinsic functions such as 'dot' or 'mul' if (hlsl::IsIntrinsicOp(funcDecl)) { @@ -12530,7 +12533,7 @@ SpirvEmitter::processSpvIntrinsicCallExpr(const CallExpr *expr) { } spvArgs.push_back(argInst); } else if (param->hasAttr()) { - auto constArg = dyn_cast(argInst); + auto constArg = dyn_cast(argInst); assert(constArg != nullptr); constArg->setLiteral(); spvArgs.push_back(argInst); @@ -12601,6 +12604,52 @@ SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr) { execModesParams, expr->getExprLoc()); } +SpirvInstruction * +SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) { + auto funcDecl = expr->getDirectCallee(); + auto typeDefAttr = funcDecl->getAttr(); + SpirvIntrinsicType *elementType = nullptr; + SmallVector constants; + llvm::SmallVector capbilities; + llvm::SmallVector extensions; + + for (auto &attr : funcDecl->getAttrs()) { + if (auto capAttr = dyn_cast(attr)) { + capbilities.push_back(capAttr->getCapability()); + } else if (auto extAttr = dyn_cast(attr)) { + extensions.push_back(extAttr->getName()); + } + } + + const auto args = expr->getArgs(); + for (uint32_t i = 0; i < expr->getNumArgs(); ++i) { + auto param = funcDecl->getParamDecl(i); + const Expr *arg = args[i]->IgnoreParenLValueCasts(); + if (param->hasAttr()) { + auto typeId = hlsl::GetHLSLResourceTemplateUInt(arg->getType()); + elementType = spvContext.getCreatedSpirvIntrinsicType(typeId); + } else if (param->hasAttr()) { + SpirvInstruction *argInst = doExpr(arg); + auto constArg = dyn_cast(argInst); + assert(constArg != nullptr); + constArg->setLiteral(); + constants.push_back(constArg); + } + } + spvContext.getSpirvIntrinsicType( + typeDefAttr->getId(), typeDefAttr->getOpcode(), constants, elementType); + + // Emit dummy OpNop with no semantic meaning, with possible extension and + // capabilities + + SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt( + static_cast(spv::Op::OpNop), QualType(), {}, extensions, {}, + capbilities, expr->getExprLoc()); + retVal->setRValue(); + + return retVal; +} + bool SpirvEmitter::spirvToolsValidate(std::vector *mod, std::string *messages) { spvtools::SpirvTools tools(featureManager.getTargetEnv()); diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 900c63da82..521d44bf76 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -596,6 +596,9 @@ class SpirvEmitter : public ASTConsumer { hlsl::IntrinsicOp opcode); /// Process spirv intrinsic instruction SpirvInstruction *processSpvIntrinsicCallExpr(const CallExpr *expr); + + /// Process spirv intrinsic type definition + SpirvInstruction *processSpvIntrinsicTypeDef(const CallExpr *expr); /// Custom intrinsic to support basic buffer_reference use case SpirvInstruction *processRawBufferLoad(const CallExpr *callExpr); diff --git a/tools/clang/lib/SPIRV/SpirvInstruction.cpp b/tools/clang/lib/SPIRV/SpirvInstruction.cpp index 1d4ee08c32..fb4c6cf887 100644 --- a/tools/clang/lib/SPIRV/SpirvInstruction.cpp +++ b/tools/clang/lib/SPIRV/SpirvInstruction.cpp @@ -476,15 +476,19 @@ SpirvCompositeConstruct::SpirvCompositeConstruct( resultType, loc), consituents(constituentsVec.begin(), constituentsVec.end()) {} -SpirvConstant::SpirvConstant(Kind kind, spv::Op op, const SpirvType *spvType) +SpirvConstant::SpirvConstant(Kind kind, spv::Op op, const SpirvType *spvType, + bool literal) : SpirvInstruction(kind, op, QualType(), - /*SourceLocation*/ {}) { + /*SourceLocation*/ {}), + literalConstant(literal) { setResultType(spvType); } -SpirvConstant::SpirvConstant(Kind kind, spv::Op op, QualType resultType) +SpirvConstant::SpirvConstant(Kind kind, spv::Op op, QualType resultType, + bool literal) : SpirvInstruction(kind, op, resultType, - /*SourceLocation*/ {}) {} + /*SourceLocation*/ {}), + literalConstant(literal) {} bool SpirvConstant::isSpecConstant() const { return opcode == spv::Op::OpSpecConstant || @@ -509,11 +513,11 @@ bool SpirvConstantBoolean::operator==(const SpirvConstantBoolean &that) const { } SpirvConstantInteger::SpirvConstantInteger(QualType type, llvm::APInt val, - bool isSpecConst, bool literal) + bool isSpecConst) : SpirvConstant(IK_ConstantInteger, isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant, type), - value(val), isLiteral(literal) { + value(val) { assert(type->isIntegerType()); } diff --git a/tools/clang/lib/SPIRV/SpirvType.cpp b/tools/clang/lib/SPIRV/SpirvType.cpp index d155767714..91ba944416 100644 --- a/tools/clang/lib/SPIRV/SpirvType.cpp +++ b/tools/clang/lib/SPIRV/SpirvType.cpp @@ -11,7 +11,7 @@ //===----------------------------------------------------------------------===// #include "clang/SPIRV/SpirvType.h" - +#include "clang/SPIRV/SpirvInstruction.h" #include namespace clang { @@ -167,6 +167,13 @@ bool RuntimeArrayType::operator==(const RuntimeArrayType &that) const { (!stride.hasValue() || stride.getValue() == that.stride.getValue()); } +SpirvIntrinsicType::SpirvIntrinsicType( + unsigned typeOp, llvm::ArrayRef constants, + SpirvIntrinsicType *eleTy) + : SpirvType(TK_SpirvIntrinsicType, "spirvIntrinsicType"), + typeOpCode(typeOp), literals(constants.begin(), constants.end()), + elementType(eleTy) {} + StructType::StructType(llvm::ArrayRef fieldsVec, llvm::StringRef name, bool isReadOnly, StructInterfaceType iface) diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index bf240bb27a..8569667da6 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -181,6 +181,7 @@ enum ArBasicKind { #ifdef ENABLE_SPIRV_CODEGEN AR_OBJECT_VK_SUBPASS_INPUT, AR_OBJECT_VK_SUBPASS_INPUT_MS, + AR_OBJECT_VK_SPV_INTRINSIC_TYPE, #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -472,6 +473,7 @@ const UINT g_uBasicKindProps[] = #ifdef ENABLE_SPIRV_CODEGEN BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_VK_SUBPASS_INPUT BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_VK_SUBPASS_INPUT_MS + BPROP_OBJECT, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE use recordType #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -1395,6 +1397,7 @@ const ArBasicKind g_ArBasicKindsAsTypes[] = #ifdef ENABLE_SPIRV_CODEGEN AR_OBJECT_VK_SUBPASS_INPUT, AR_OBJECT_VK_SUBPASS_INPUT_MS, + AR_OBJECT_VK_SPV_INTRINSIC_TYPE, #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -1486,7 +1489,8 @@ const uint8_t g_ArBasicKindsTemplateCount[] = // SPIRV change starts #ifdef ENABLE_SPIRV_CODEGEN 1, // AR_OBJECT_VK_SUBPASS_INPUT - 1, // AR_OBJECT_VK_SUBPASS_INPUT_MS + 1, // AR_OBJECT_VK_SUBPASS_INPUT_MS, + 1, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -1587,6 +1591,7 @@ const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] = #ifdef ENABLE_SPIRV_CODEGEN { 0, MipsFalse, SampleFalse }, // AR_OBJECT_VK_SUBPASS_INPUT (SubpassInput) { 0, MipsFalse, SampleFalse }, // AR_OBJECT_VK_SUBPASS_INPUT_MS (SubpassInputMS) + { 0, MipsFalse, SampleFalse }, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -1706,6 +1711,7 @@ const char* g_ArBasicTypeNames[] = #ifdef ENABLE_SPIRV_CODEGEN "SubpassInput", "SubpassInputMS", + "ext_type", #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -3588,11 +3594,14 @@ class HLSLExternalSource : public ExternalSemaSource { else if (kind == AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY) { recordDecl = DeclareUIntTemplatedTypeWithHandle(*m_context, "FeedbackTexture2DArray", "kind"); } +#ifdef ENABLE_SPIRV_CODEGEN + else if (kind == AR_OBJECT_VK_SPV_INTRINSIC_TYPE) { + recordDecl = DeclareUIntTemplatedTypeWithHandle(*m_context, "ext_type", "id"); + } +#endif else if (templateArgCount == 0) { recordDecl = DeclareRecordTypeWithHandle(*m_context, typeName); - } - else - { + } else { DXASSERT(templateArgCount == 1 || templateArgCount == 2, "otherwise a new case has been added"); TypeSourceInfo* typeDefault = TemplateHasDefaultType(kind) ? float4TypeSourceInfo : nullptr; @@ -12095,6 +12104,12 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, A.getRange(), S.Context, unsigned(ValidateAttributeIntArg(S, A)), A.getAttributeSpellingListIndex()); break; + case AttributeList::AT_VKTypeDefExt: + declAttr = ::new (S.Context) VKTypeDefExtAttr( + A.getRange(), S.Context, unsigned(ValidateAttributeIntArg(S, A)), + unsigned(ValidateAttributeIntArg(S, A, 1)), + A.getAttributeSpellingListIndex()); + break; default: Handled = false; return; @@ -12877,7 +12892,8 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth, // Validate that Vulkan specific feature is only used when targeting SPIR-V if (!getLangOpts().SPIRV) { if (basicKind == ArBasicKind::AR_OBJECT_VK_SUBPASS_INPUT || - basicKind == ArBasicKind::AR_OBJECT_VK_SUBPASS_INPUT_MS) { + basicKind == ArBasicKind::AR_OBJECT_VK_SUBPASS_INPUT_MS || + basicKind == ArBasicKind::AR_OBJECT_VK_SPV_INTRINSIC_TYPE) { Diag(D.getLocStart(), diag::err_hlsl_vulkan_specific_feature) << g_ArBasicTypeNames[basicKind]; result = false; diff --git a/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeInteger.hlsl b/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeInteger.hlsl new file mode 100644 index 0000000000..247f188baf --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeInteger.hlsl @@ -0,0 +1,21 @@ +// RUN: %dxc -T ps_6_0 -E main -spirv + +[[vk::ext_type_def(0, 21)]] +void createTypeInt([[vk::ext_literal]] int sizeInBits, + [[vk::ext_literal]] int signedness); + +[[vk::ext_type_def(1, 23)]] +void createTypeVector([[vk::ext_reference]] ext_type<0> typeInt, + [[vk::ext_literal]] int componentCount); + +//CHECK: %spirvIntrinsicType = OpTypeInt 32 0 +//CHECK: %spirvIntrinsicType_0 = OpTypeVector %spirvIntrinsicType 4 + +ext_type<0> foo1; +ext_type<1> foo2; +float main() : SV_Target +{ + createTypeInt(32, 0); + createTypeVector(foo1, 4); + return 0.0; +} diff --git a/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl b/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl new file mode 100644 index 0000000000..f18ffa04c6 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl @@ -0,0 +1,40 @@ +// RUN: %dxc -T cs_6_5 -E main -spirv + +[[vk::ext_capability(/* RayQueryKHR */ 4472)]] +[[vk::ext_extension("SPV_KHR_ray_query")]] +[[vk::ext_type_def(/* Unique id for type */ 2, + /* OpTypeRayQueryKHR */ 4472)]] +void createTypeRayQueryKHR(); + +[[vk::ext_type_def(/* Unique id for type */ 3, + /* OpTypeAccelerationStructureKHR */ 5341)]] +void createAcceleStructureType(); + +[[vk::ext_instruction(/* OpRayQueryTerminateKHR */ 4474)]] +void rayQueryTerminateEXT( + [[vk::ext_reference]] ext_type<2> rq); + +ext_type<3> as : register(t0); + +[[vk::ext_instruction(/* OpRayQueryInitializeKHR */ 4473)]] +void rayQueryInitializeEXT([[vk::ext_reference]] ext_type<2> rayQuery, ext_type<3> as, uint rayFlags, uint cullMask, float3 origin, float tMin, float3 direction, float tMax); + +[[vk::ext_instruction(/* OpRayQueryTerminateKHR */ 4474)]] +void rayQueryTerminateEXT( + [[vk::ext_reference]] ext_type<2> rq ); + +//CHECK: %spirvIntrinsicType = OpTypeAccelerationStructureKHR +//CHECK: %spirvIntrinsicType_0 = OpTypeRayQueryKHR + +//CHECK: OpRayQueryInitializeKHR %rq {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} +//CHECK: OpRayQueryTerminateKHR %rq + +[numthreads(64, 1, 1)] +void main() +{ + createTypeRayQueryKHR(); + createAcceleStructureType(); + ext_type<2> rq; + rayQueryInitializeEXT(rq, as, 0, 0, float3(0, 0, 0), 0.0, float3(1,1,1), 1.0); + rayQueryTerminateEXT(rq); +} diff --git a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp index fbf27c0355..37e248787b 100644 --- a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp +++ b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp @@ -1348,6 +1348,8 @@ TEST_F(FileTest, IntrinsicsSpirv) { runFileTest("spv.intrinsicDecorate.hlsl", Expect::Success, false); runFileTest("spv.intrinsicExecutionMode.hlsl", Expect::Success, false); runFileTest("spv.intrinsicStorageClass.hlsl", Expect::Success, false); + runFileTest("spv.intrinsicTypeInteger.hlsl"); + runFileTest("spv.intrinsicTypeRayquery.hlsl", Expect::Success, false); runFileTest("spv.intrinsic.reference.error.hlsl", Expect::Failure); } TEST_F(FileTest, IntrinsicsVkReadClock) { From 1fb3fd3b7d53e4a5149f332dd79a6aedf304d1e5 Mon Sep 17 00:00:00 2001 From: Jaebaek Seo Date: Wed, 17 Nov 2021 17:24:03 -0500 Subject: [PATCH 2/6] vk namespace --- tools/clang/include/clang/AST/HlslTypes.h | 3 +++ tools/clang/lib/AST/ASTContextHLSL.cpp | 9 ++++++++- tools/clang/lib/Sema/SemaHLSL.cpp | 20 ++++++++++++------- .../spv.intrinsicTypeInteger.hlsl | 6 +++--- 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/tools/clang/include/clang/AST/HlslTypes.h b/tools/clang/include/clang/AST/HlslTypes.h index 73fd883d0c..9f01dd6183 100644 --- a/tools/clang/include/clang/AST/HlslTypes.h +++ b/tools/clang/include/clang/AST/HlslTypes.h @@ -340,6 +340,9 @@ clang::CXXRecordDecl* DeclareTemplateTypeWithHandle( clang::CXXRecordDecl* DeclareUIntTemplatedTypeWithHandle( clang::ASTContext& context, llvm::StringRef typeName, llvm::StringRef templateParamName); +clang::CXXRecordDecl *DeclareUIntTemplatedTypeWithHandleInDeclContext( + clang::ASTContext &context, clang::DeclContext *declContext, + llvm::StringRef typeName, llvm::StringRef templateParamName); clang::CXXRecordDecl *DeclareConstantBufferViewType(clang::ASTContext& context, bool bTBuf); clang::CXXRecordDecl* DeclareRayQueryType(clang::ASTContext& context); clang::CXXRecordDecl *DeclareResourceType(clang::ASTContext &context, diff --git a/tools/clang/lib/AST/ASTContextHLSL.cpp b/tools/clang/lib/AST/ASTContextHLSL.cpp index 6a6e9a8ded..f849e8ae63 100644 --- a/tools/clang/lib/AST/ASTContextHLSL.cpp +++ b/tools/clang/lib/AST/ASTContextHLSL.cpp @@ -840,8 +840,15 @@ CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams( CXXRecordDecl* hlsl::DeclareUIntTemplatedTypeWithHandle( ASTContext& context, StringRef typeName, StringRef templateParamName) { + return DeclareUIntTemplatedTypeWithHandleInDeclContext( + context, context.getTranslationUnitDecl(), typeName, templateParamName); +} + +CXXRecordDecl *hlsl::DeclareUIntTemplatedTypeWithHandleInDeclContext( + ASTContext &context, DeclContext *declContext, StringRef typeName, + StringRef templateParamName) { // template FeedbackTexture2D[Array] { ... } - BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), typeName); + BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName); typeDeclBuilder.addIntegerTemplateParam(templateParamName, context.UnsignedIntTy); typeDeclBuilder.startDefinition(); typeDeclBuilder.addField("h", context.UnsignedIntTy); // Add an 'h' field to hold the handle. diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 8569667da6..a3387173cf 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -3596,7 +3596,9 @@ class HLSLExternalSource : public ExternalSemaSource { } #ifdef ENABLE_SPIRV_CODEGEN else if (kind == AR_OBJECT_VK_SPV_INTRINSIC_TYPE) { - recordDecl = DeclareUIntTemplatedTypeWithHandle(*m_context, "ext_type", "id"); + recordDecl = DeclareUIntTemplatedTypeWithHandleInDeclContext( + *m_context, m_vkNSDecl, typeName, "id"); + recordDecl->setImplicit(true); } #endif else if (templateArgCount == 0) { @@ -3721,12 +3723,6 @@ class HLSLExternalSource : public ExternalSemaSource { m_sema = &S; S.addExternalSource(this); - AddObjectTypes(); - AddStdIsEqualImplementation(context, S); - for (auto && intrinsic : m_intrinsicTables) { - AddIntrinsicTableMethods(intrinsic); - } - #ifdef ENABLE_SPIRV_CODEGEN if (m_sema->getLangOpts().SPIRV) { // Create the "vk" namespace which contains Vulkan-specific intrinsics. @@ -3736,7 +3732,17 @@ class HLSLExternalSource : public ExternalSemaSource { SourceLocation(), &context.Idents.get("vk"), /*PrevDecl*/ nullptr); context.getTranslationUnitDecl()->addDecl(m_vkNSDecl); + } +#endif // ENABLE_SPIRV_CODEGEN + AddObjectTypes(); + AddStdIsEqualImplementation(context, S); + for (auto &&intrinsic : m_intrinsicTables) { + AddIntrinsicTableMethods(intrinsic); + } + +#ifdef ENABLE_SPIRV_CODEGEN + if (m_sema->getLangOpts().SPIRV) { // Add Vulkan-specific intrinsics. AddVkIntrinsicFunctions(); AddVkIntrinsicConstants(); diff --git a/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeInteger.hlsl b/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeInteger.hlsl index 247f188baf..8f0c5d5431 100644 --- a/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeInteger.hlsl +++ b/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeInteger.hlsl @@ -5,14 +5,14 @@ void createTypeInt([[vk::ext_literal]] int sizeInBits, [[vk::ext_literal]] int signedness); [[vk::ext_type_def(1, 23)]] -void createTypeVector([[vk::ext_reference]] ext_type<0> typeInt, +void createTypeVector([[vk::ext_reference]] vk::ext_type<0> typeInt, [[vk::ext_literal]] int componentCount); //CHECK: %spirvIntrinsicType = OpTypeInt 32 0 //CHECK: %spirvIntrinsicType_0 = OpTypeVector %spirvIntrinsicType 4 -ext_type<0> foo1; -ext_type<1> foo2; +vk::ext_type<0> foo1; +vk::ext_type<1> foo2; float main() : SV_Target { createTypeInt(32, 0); From bed79332497d8e9b837a98e500acb181967334e4 Mon Sep 17 00:00:00 2001 From: Lu Jiao Date: Wed, 24 Nov 2021 11:17:41 +0800 Subject: [PATCH 3/6] Fix clang hlsl/spirv tests --- tools/clang/lib/Sema/SemaHLSL.cpp | 2 +- .../test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index a3387173cf..8d51ccf5c9 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -3595,7 +3595,7 @@ class HLSLExternalSource : public ExternalSemaSource { recordDecl = DeclareUIntTemplatedTypeWithHandle(*m_context, "FeedbackTexture2DArray", "kind"); } #ifdef ENABLE_SPIRV_CODEGEN - else if (kind == AR_OBJECT_VK_SPV_INTRINSIC_TYPE) { + else if (kind == AR_OBJECT_VK_SPV_INTRINSIC_TYPE && m_vkNSDecl) { recordDecl = DeclareUIntTemplatedTypeWithHandleInDeclContext( *m_context, m_vkNSDecl, typeName, "id"); recordDecl->setImplicit(true); diff --git a/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl b/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl index f18ffa04c6..bd0715828d 100644 --- a/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl +++ b/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl @@ -12,16 +12,16 @@ void createAcceleStructureType(); [[vk::ext_instruction(/* OpRayQueryTerminateKHR */ 4474)]] void rayQueryTerminateEXT( - [[vk::ext_reference]] ext_type<2> rq); + [[vk::ext_reference]] vk::ext_type<2> rq); -ext_type<3> as : register(t0); +vk::ext_type<3> as : register(t0); [[vk::ext_instruction(/* OpRayQueryInitializeKHR */ 4473)]] -void rayQueryInitializeEXT([[vk::ext_reference]] ext_type<2> rayQuery, ext_type<3> as, uint rayFlags, uint cullMask, float3 origin, float tMin, float3 direction, float tMax); +void rayQueryInitializeEXT([[vk::ext_reference]] vk::ext_type<2> rayQuery, vk::ext_type<3> as, uint rayFlags, uint cullMask, float3 origin, float tMin, float3 direction, float tMax); [[vk::ext_instruction(/* OpRayQueryTerminateKHR */ 4474)]] void rayQueryTerminateEXT( - [[vk::ext_reference]] ext_type<2> rq ); + [[vk::ext_reference]] vk::ext_type<2> rq ); //CHECK: %spirvIntrinsicType = OpTypeAccelerationStructureKHR //CHECK: %spirvIntrinsicType_0 = OpTypeRayQueryKHR @@ -34,7 +34,7 @@ void main() { createTypeRayQueryKHR(); createAcceleStructureType(); - ext_type<2> rq; + vk::ext_type<2> rq; rayQueryInitializeEXT(rq, as, 0, 0, float3(0, 0, 0), 0.0, float3(1,1,1), 1.0); rayQueryTerminateEXT(rq); } From 23ba4cb154a7a08dac809b668147e09f6f56da51 Mon Sep 17 00:00:00 2001 From: Lu Jiao Date: Wed, 24 Nov 2021 17:24:09 +0800 Subject: [PATCH 4/6] Address review points --- .../clang/include/clang/SPIRV/SpirvContext.h | 3 +- tools/clang/include/clang/SPIRV/SpirvType.h | 27 +++++-- tools/clang/lib/SPIRV/EmitVisitor.cpp | 78 ++++++++++++------- tools/clang/lib/SPIRV/EmitVisitor.h | 4 + tools/clang/lib/SPIRV/SpirvContext.cpp | 9 +-- tools/clang/lib/SPIRV/SpirvEmitter.cpp | 20 +++-- tools/clang/lib/SPIRV/SpirvType.cpp | 6 +- 7 files changed, 93 insertions(+), 54 deletions(-) diff --git a/tools/clang/include/clang/SPIRV/SpirvContext.h b/tools/clang/include/clang/SPIRV/SpirvContext.h index d3baef3c2b..1d728353cc 100644 --- a/tools/clang/include/clang/SPIRV/SpirvContext.h +++ b/tools/clang/include/clang/SPIRV/SpirvContext.h @@ -290,8 +290,7 @@ class SpirvContext { const SpirvIntrinsicType * getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode, - llvm::ArrayRef constants, - SpirvIntrinsicType *elementTy); + llvm::ArrayRef operands); SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId); diff --git a/tools/clang/include/clang/SPIRV/SpirvType.h b/tools/clang/include/clang/SPIRV/SpirvType.h index 44a877517c..baf9c304ab 100644 --- a/tools/clang/include/clang/SPIRV/SpirvType.h +++ b/tools/clang/include/clang/SPIRV/SpirvType.h @@ -413,22 +413,35 @@ class RayQueryTypeKHR : public SpirvType { } }; -class SpirvConstant; +class SpirvInstruction; +struct SpvIntrinsicTypeOperand { + SpvIntrinsicTypeOperand(SpirvType *type_operand) + : operand_as_type(type_operand), isTypeOperand(true) {} + SpvIntrinsicTypeOperand(SpirvInstruction *inst_operand) + : operand_as_inst(inst_operand), isTypeOperand(false) {} + union { + SpirvType *operand_as_type; + SpirvInstruction *operand_as_inst; + }; + bool isTypeOperand; +}; + class SpirvIntrinsicType : public SpirvType { public: - SpirvIntrinsicType(unsigned typeOp, llvm::ArrayRef constants, - SpirvIntrinsicType *elementTy); + SpirvIntrinsicType(unsigned typeOp, + llvm::ArrayRef inOps); + static bool classof(const SpirvType *t) { return t->getKind() == TK_SpirvIntrinsicType; } unsigned getOpCode() const { return typeOpCode; } - llvm::ArrayRef getLiterals() const { return literals; } - SpirvIntrinsicType *getElemType() const { return elementType; } + llvm::ArrayRef getOperands() const { + return operands; + } private: unsigned typeOpCode; - llvm::SmallVector literals; - SpirvIntrinsicType *elementType; + llvm::SmallVector operands; }; class HybridType : public SpirvType { diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index 56459d539d..eece2db637 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -2450,15 +2450,22 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { initTypeInstruction(spv::Op::OpTypeRayQueryKHR); curTypeInst.push_back(id); finalizeTypeInstruction(); - } - else if (const auto *spvIntrinsicType = dyn_cast(type)) { + } else if (const auto *spvIntrinsicType = + dyn_cast(type)) { initTypeInstruction(static_cast(spvIntrinsicType->getOpCode())); curTypeInst.push_back(id); - if (spvIntrinsicType->getElemType()) { - curTypeInst.push_back(emitType(spvIntrinsicType->getElemType())); - } - for (auto& literal : spvIntrinsicType->getLiterals()) { - emitLiteral(literal, curTypeInst); + for (const SpvIntrinsicTypeOperand &operand : + spvIntrinsicType->getOperands()) { + if (operand.isTypeOperand) { + curTypeInst.push_back(emitType(operand.operand_as_type)); + } else { + auto *literal = dyn_cast(operand.operand_as_inst); + if (literal && literal->isLiteral()) { + emitLiteral(literal, curTypeInst); + } else { + curTypeInst.push_back(getOrAssignResultId(operand.operand_as_inst)); + } + } } finalizeTypeInstruction(); } @@ -2477,36 +2484,47 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { return id; } +template +void EmitTypeHandler::emitIntLiteral(const SpirvConstantInteger *intLiteral, + vecType &outInst) { + const auto &literalVal = intLiteral->getValue(); + bool positive = !literalVal.isNegative(); + if (literalVal.getBitWidth() <= 32) { + outInst.push_back(positive ? literalVal.getZExtValue() + : literalVal.getSExtValue()); + } else { + assert(literalVal.getBitWidth() == 64); + uint64_t val = + positive ? literalVal.getZExtValue() : literalVal.getSExtValue(); + outInst.push_back(static_cast(val)); + outInst.push_back(static_cast(val >> 32)); + } +} + +template +void EmitTypeHandler::emitFloatLiteral(const SpirvConstantFloat *fLiteral, + vecType &outInst) { + const auto &literalVal = fLiteral->getValue(); + const auto bitwidth = llvm::APFloat::getSizeInBits(literalVal.getSemantics()); + if (bitwidth <= 32) { + outInst.push_back(literalVal.bitcastToAPInt().getZExtValue()); + } else { + assert(bitwidth == 64); + uint64_t val = literalVal.bitcastToAPInt().getZExtValue(); + outInst.push_back(static_cast(val)); + outInst.push_back(static_cast(val >> 32)); + } +} + template void EmitTypeHandler::emitLiteral(const SpirvConstant *literal, VecType &outInst) { if (auto boolLiteral = dyn_cast(literal)) { outInst.push_back(static_cast(boolLiteral->getValue())); } else if (auto intLiteral = dyn_cast(literal)) { - const auto &literalVal = intLiteral->getValue(); - bool positive = !literalVal.isNegative(); - if (literalVal.getBitWidth() <= 32) { - outInst.push_back(positive ? literalVal.getZExtValue() - : literalVal.getSExtValue()); - } else { - assert(literalVal.getBitWidth() == 64); - uint64_t val = - positive ? literalVal.getZExtValue() : literalVal.getSExtValue(); - outInst.push_back(static_cast(val)); - outInst.push_back(static_cast(val >> 32)); - } + emitIntLiteral(intLiteral, outInst); } else if (auto fLiteral = dyn_cast(literal)) { - const auto &literalVal = fLiteral->getValue(); - const auto bitwidth = - llvm::APFloat::getSizeInBits(literalVal.getSemantics()); - if (bitwidth <= 32) { - outInst.push_back(literalVal.bitcastToAPInt().getZExtValue()); - } else { - assert(bitwidth == 64); - uint64_t val = literalVal.bitcastToAPInt().getZExtValue(); - outInst.push_back(static_cast(val)); - outInst.push_back(static_cast(val >> 32)); - } + emitFloatLiteral(fLiteral, outInst); } } diff --git a/tools/clang/lib/SPIRV/EmitVisitor.h b/tools/clang/lib/SPIRV/EmitVisitor.h index 38b2431fc1..f979247d69 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.h +++ b/tools/clang/lib/SPIRV/EmitVisitor.h @@ -111,6 +111,10 @@ class EmitTypeHandler { uint32_t getOrCreateConstantBool(SpirvConstantBoolean *); template void emitLiteral(const SpirvConstant *, vecType &outInst); + template + void emitFloatLiteral(const SpirvConstantFloat *, vecType &outInst); + template + void emitIntLiteral(const SpirvConstantInteger *, vecType &outInst); private: void initTypeInstruction(spv::Op op); diff --git a/tools/clang/lib/SPIRV/SpirvContext.cpp b/tools/clang/lib/SPIRV/SpirvContext.cpp index bccb47dd43..3db12aac54 100644 --- a/tools/clang/lib/SPIRV/SpirvContext.cpp +++ b/tools/clang/lib/SPIRV/SpirvContext.cpp @@ -527,13 +527,12 @@ void SpirvContext::moveDebugTypesToModule(SpirvModule *module) { typeTemplateParams.clear(); } -const SpirvIntrinsicType * -SpirvContext::getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode, - llvm::ArrayRef constants, - SpirvIntrinsicType *elementTy) { +const SpirvIntrinsicType *SpirvContext::getSpirvIntrinsicType( + unsigned typeId, unsigned typeOpCode, + llvm::ArrayRef operands) { if (spirvIntrinsicTypes[typeId] == nullptr) { spirvIntrinsicTypes[typeId] = - new (this) SpirvIntrinsicType(typeOpCode, constants, elementTy); + new (this) SpirvIntrinsicType(typeOpCode, operands); } return spirvIntrinsicTypes[typeId]; } diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 666c217dbe..cae6c6a68f 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -12609,7 +12609,6 @@ SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) { auto funcDecl = expr->getDirectCallee(); auto typeDefAttr = funcDecl->getAttr(); SpirvIntrinsicType *elementType = nullptr; - SmallVector constants; llvm::SmallVector capbilities; llvm::SmallVector extensions; @@ -12621,23 +12620,32 @@ SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) { } } + SmallVector operands; const auto args = expr->getArgs(); for (uint32_t i = 0; i < expr->getNumArgs(); ++i) { auto param = funcDecl->getParamDecl(i); const Expr *arg = args[i]->IgnoreParenLValueCasts(); if (param->hasAttr()) { - auto typeId = hlsl::GetHLSLResourceTemplateUInt(arg->getType()); - elementType = spvContext.getCreatedSpirvIntrinsicType(typeId); + auto *recType = param->getType()->getAs(); + if (recType && recType->getDecl()->getName() == "ext_type") { + auto typeId = hlsl::GetHLSLResourceTemplateUInt(arg->getType()); + auto *typeArg = spvContext.getCreatedSpirvIntrinsicType(typeId); + operands.emplace_back(typeArg); + } else { + operands.emplace_back(doExpr(arg)); + } } else if (param->hasAttr()) { SpirvInstruction *argInst = doExpr(arg); auto constArg = dyn_cast(argInst); assert(constArg != nullptr); constArg->setLiteral(); - constants.push_back(constArg); + operands.emplace_back(constArg); + } else { + operands.emplace_back(loadIfGLValue(arg)); } } - spvContext.getSpirvIntrinsicType( - typeDefAttr->getId(), typeDefAttr->getOpcode(), constants, elementType); + spvContext.getSpirvIntrinsicType(typeDefAttr->getId(), + typeDefAttr->getOpcode(), operands); // Emit dummy OpNop with no semantic meaning, with possible extension and // capabilities diff --git a/tools/clang/lib/SPIRV/SpirvType.cpp b/tools/clang/lib/SPIRV/SpirvType.cpp index 91ba944416..b191fcd20b 100644 --- a/tools/clang/lib/SPIRV/SpirvType.cpp +++ b/tools/clang/lib/SPIRV/SpirvType.cpp @@ -168,11 +168,9 @@ bool RuntimeArrayType::operator==(const RuntimeArrayType &that) const { } SpirvIntrinsicType::SpirvIntrinsicType( - unsigned typeOp, llvm::ArrayRef constants, - SpirvIntrinsicType *eleTy) + unsigned typeOp, llvm::ArrayRef inOps) : SpirvType(TK_SpirvIntrinsicType, "spirvIntrinsicType"), - typeOpCode(typeOp), literals(constants.begin(), constants.end()), - elementType(eleTy) {} + typeOpCode(typeOp), operands(inOps.begin(), inOps.end()) {} StructType::StructType(llvm::ArrayRef fieldsVec, llvm::StringRef name, bool isReadOnly, From 13d1bc88be31ce9390ce389137fad1f4ba3721f3 Mon Sep 17 00:00:00 2001 From: Lu Jiao Date: Wed, 24 Nov 2021 18:23:35 +0800 Subject: [PATCH 5/6] Fix clang compiling --- tools/clang/lib/SPIRV/SpirvEmitter.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index cae6c6a68f..3357e548b5 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -12608,7 +12608,6 @@ SpirvInstruction * SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) { auto funcDecl = expr->getDirectCallee(); auto typeDefAttr = funcDecl->getAttr(); - SpirvIntrinsicType *elementType = nullptr; llvm::SmallVector capbilities; llvm::SmallVector extensions; @@ -12649,7 +12648,6 @@ SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) { // Emit dummy OpNop with no semantic meaning, with possible extension and // capabilities - SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt( static_cast(spv::Op::OpNop), QualType(), {}, extensions, {}, capbilities, expr->getExprLoc()); From 6fda39c5505aa99696be717e13220fef1baececd Mon Sep 17 00:00:00 2001 From: Lu Jiao Date: Mon, 29 Nov 2021 16:57:37 +0800 Subject: [PATCH 6/6] Fix memory leakage of the spirvIntrinsicTypes --- tools/clang/lib/SPIRV/SpirvContext.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tools/clang/lib/SPIRV/SpirvContext.cpp b/tools/clang/lib/SPIRV/SpirvContext.cpp index 3db12aac54..9c180a3ef6 100644 --- a/tools/clang/lib/SPIRV/SpirvContext.cpp +++ b/tools/clang/lib/SPIRV/SpirvContext.cpp @@ -95,6 +95,11 @@ SpirvContext::~SpirvContext() { for (auto &typePair : typeTemplateParams) typePair.second->releaseMemory(); + + for (auto &pair : spirvIntrinsicTypes) { + assert(pair.second); + pair.second->~SpirvIntrinsicType(); + } } inline uint32_t log2ForBitwidth(uint32_t bitwidth) {