From 6b53fe69e60b555e40c77a5c2b8529a55a0db3e2 Mon Sep 17 00:00:00 2001 From: Sven van Haastregt Date: Sun, 4 Sep 2022 11:51:12 +0100 Subject: [PATCH] Dot product bugfix to include more floating point types (#1578) (#1610) Switched the visitCallDot check to use isFloatingPointTy for scalar floating point operands. Bugfix for previous change regarding integer dot product. (cherry picked from commit 71e01b53a8ac7b0c15f4bb3cd73bcefdcc137954) Co-authored-by: Jakub Czarnecki --- lib/SPIRV/OCLToSPIRV.cpp | 3 +- .../dot_product_OCLtoSPIRV_half.ll | 41 +++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 test/transcoding/dot_product_OCLtoSPIRV_half.ll diff --git a/lib/SPIRV/OCLToSPIRV.cpp b/lib/SPIRV/OCLToSPIRV.cpp index 866c487972..95e52cc838 100644 --- a/lib/SPIRV/OCLToSPIRV.cpp +++ b/lib/SPIRV/OCLToSPIRV.cpp @@ -323,8 +323,7 @@ void OCLToSPIRVBase::visitCallInst(CallInst &CI) { return; } if (DemangledName == kOCLBuiltinName::Dot && - (CI.getOperand(0)->getType()->isFloatTy() || - CI.getOperand(1)->getType()->isDoubleTy())) { + CI.getOperand(0)->getType()->isFloatingPointTy()) { visitCallDot(&CI); return; } diff --git a/test/transcoding/dot_product_OCLtoSPIRV_half.ll b/test/transcoding/dot_product_OCLtoSPIRV_half.ll new file mode 100644 index 0000000000..dc188851e9 --- /dev/null +++ b/test/transcoding/dot_product_OCLtoSPIRV_half.ll @@ -0,0 +1,41 @@ +; RUN: llvm-as < %s -o %t.bc +; RUN: llvm-spirv -s %t.bc -o %t.regularized.bc +; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_integer_dot_product -o %t-spirv.spv +; RUN: spirv-val %t-spirv.spv +; RUN: llvm-dis %t.regularized.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM +; RUN: llvm-spirv %t.bc -spirv-text --spirv-ext=+SPV_KHR_integer_dot_product -o - | FileCheck %s --check-prefix=CHECK-SPIRV + +;CHECK-LLVM: fmul + +;CHECK-SPIRV: FMul + +target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = "spir" + +; Function Attrs: convergent norecurse nounwind +define spir_kernel void @test1(half %ha, half %hb) local_unnamed_addr #0 !kernel_arg_addr_space !3 !kernel_arg_access_qual !4 !kernel_arg_type !5 !kernel_arg_base_type !6 !kernel_arg_type_qual !7 { +entry: + %call = tail call spir_func half @_Z3dotDhDh(half %ha, half %hb) #2 + ret void +} + +; Function Attrs: convergent +declare spir_func half @_Z3dotDhDh(half, half) local_unnamed_addr #1 + +attributes #0 = { convergent norecurse nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pocharer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="128" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #1 = { convergent "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pocharer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #2 = { convergent nounwind } + +!llvm.module.flags = !{!0} +!opencl.ocl.version = !{!1} +!opencl.spir.version = !{!1} +!llvm.ident = !{!2} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{i32 2, i32 0} +!2 = !{!"clang version 11.0.0 (/~https://github.com/c199914007/llvm.git 8b94769313ca84cb9370b60ed008501ff692cb71)"} +!3 = !{i32 0, i32 0} +!4 = !{!"none", !"none"} +!5 = !{!"half", !"half"} +!6 = !{!"half", !"half"} +!7 = !{!"", !""}