[MLIR][TOSA] add additional verification to TOSA #108133

Merged: GeorgeARM merged 1 commit into llvm:main from tosa-additional-verification on Sep 11, 2024

arteen1000
Contributor


Motivation:

Spec conformance. Allows assumptions to be made in TOSA code.


Changes Made:

Add full permutation tensor verification to tosa.TRANSPOSE. Previously, the verifier did not check that the permutation values were between 0 and (rank - 1).

Update tosa.TRANSPOSE perms data type to be strictly i32.

Verify input/output shapes for tosa.TRANSPOSE.

Add verifier to tosa.CONST, with consideration for quantization.

Fix TOSA conformance of the tensor type by disallowing dimensions of size 0 in ranked tensors, per the spec.
This is not the same as a rank-0 tensor, which has no dimensions at all. An example of a disallowed tensor is tensor<3x0xi32>. As a result, the number of elements in a TOSA tensor is always greater than 0.
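
As a rough illustration of the tosa.TRANSPOSE checks listed above, here is a minimal standalone sketch in plain C++. It is not the code from this patch: the names `isValidPermutation` and `shapesMatchTranspose` are hypothetical, it uses `std::vector` instead of MLIR types, and it assumes fully static shapes. It also assumes that "full" verification means each value in 0 to (rank - 1) appears exactly once.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the patch): returns true when `perms` is a
// permutation of 0 .. rank-1, i.e. every value is in range and appears once.
bool isValidPermutation(const std::vector<int32_t> &perms, std::size_t rank) {
  if (perms.size() != rank)
    return false;
  std::vector<bool> seen(rank, false);
  for (int32_t p : perms) {
    if (p < 0 || static_cast<std::size_t>(p) >= rank)
      return false; // value outside 0 .. rank-1
    const auto idx = static_cast<std::size_t>(p);
    if (seen[idx])
      return false; // duplicate value
    seen[idx] = true;
  }
  return true;
}

// Hypothetical helper: for static shapes, the output shape must satisfy
// output[i] == input[perms[i]] for every axis i.
bool shapesMatchTranspose(const std::vector<int64_t> &input,
                          const std::vector<int64_t> &output,
                          const std::vector<int32_t> &perms) {
  if (!isValidPermutation(perms, input.size()) ||
      output.size() != input.size())
    return false;
  for (std::size_t i = 0; i < perms.size(); ++i)
    if (output[i] != input[static_cast<std::size_t>(perms[i])])
      return false;
  return true;
}
```

Under these assumptions, a permutation such as [0, 0, 2] or [0, 3, 1] for a rank-3 input would be rejected, while [1, 2, 0] maps an input shape of 2x3x4 to an output shape of 3x4x2.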

----------
Motivation:
----------

Spec conformance. Allows assumptions to be made in TOSA
code.

------------
Changes Made:
------------

Add full permutation tensor verification to tosa.TRANSPOSE.
Previously, the verifier did not check that the permutation values were
between 0 and (rank - 1).

Update tosa.TRANSPOSE perms data type to be strictly i32.

Verify input/output shapes for tosa.TRANSPOSE.

Add verifier to tosa.CONST, with consideration for quantization.

Fix TOSA conformance of the tensor type by disallowing dimensions of size 0
in ranked tensors, per the spec.
This is not the same as a rank-0 tensor, which has no dimensions at all. An
example of a disallowed tensor is tensor<3x0xi32>. As a result, the number
of elements in a TOSA tensor is always greater than 0.

Signed-off-by: Arteen Abrishami <arteen.abrishami@arm.com>
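
The zero-size dimension rule in the commit message can be pictured with a small standalone sketch. This is plain C++ over a vector of static extents, not the TableGen predicate from the patch; `hasNoZeroDimensions` and `elementCount` are hypothetical names, and dynamic dimensions are not modeled.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical analogue of the rule: a ranked shape conforms when no
// dimension has size 0. A rank-0 shape (empty list) passes trivially.
bool hasNoZeroDimensions(const std::vector<int64_t> &shape) {
  return std::all_of(shape.begin(), shape.end(),
                     [](int64_t dim) { return dim != 0; });
}

// Hypothetical helper: number of elements of a static shape. The empty
// product is 1, so a rank-0 tensor holds exactly one element.
int64_t elementCount(const std::vector<int64_t> &shape) {
  int64_t count = 1;
  for (int64_t dim : shape)
    count *= dim;
  return count;
}
```

With these definitions, the shape of tensor<3x0xi32> fails the check and has zero elements, whereas a rank-0 shape passes and has one element, which is why the two cases are not the same.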

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In that case, you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Member

llvmbot commented Sep 11, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Arteen Abrishami (arteen1000)

Patch is 51.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108133.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt (+2-2)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+27-31)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+43-18)
  • (modified) mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (+13)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+12-9)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+42-22)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+75-60)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+5-5)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+63-11)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index 12b4fc402c390f..1ee105f0ceb98b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -3,8 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
 add_mlir_interface(TosaInterfaces)
 
 set(LLVM_TARGET_DEFINITIONS TosaOps.td)
-mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls)
-mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs)
+mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
+mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
 add_public_tablegen_target(MLIRTosaAttributesIncGen)
 
 set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ab6daa39708d13..63572f287b7dde 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -73,7 +73,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
@@ -102,9 +101,8 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    4DTensorOf<[Tosa_Weight]>:$weight,
+    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -132,9 +130,8 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
 
   let arguments = (ins
     Tosa_Tensor5D:$input,
-    TensorRankOf<[Tosa_Weight], [5]>:$weight,
+    TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
     Tosa_Tensor1D:$bias,
-
     Tosa_IntArrayAttr6:$pad,
     Tosa_IntArrayAttr3:$stride,
     Tosa_IntArrayAttr3:$dilation,
@@ -163,9 +160,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    4DTensorOf<[Tosa_Weight]>:$weight,
+    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -232,7 +228,7 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
 
   let arguments = (ins
     Tosa_Tensor2D:$input,
-    2DTensorOf<[Tosa_Weight]>:$weight,
+    TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
     Tosa_Tensor1D:$bias,
     OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
   );
@@ -347,9 +343,8 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    4DTensorOf<[Tosa_Weight]>:$filter,
+    TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
     Tosa_Tensor1D:$bias,
-
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttrUpto4:$out_shape,
@@ -641,12 +636,12 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    I1Tensor:$input2
+    Tosa_I1Tensor:$input1,
+    Tosa_I1Tensor:$input2
   );
 
   let results = (outs
-    I1Tensor:$z
+    Tosa_I1Tensor:$z
   );
 }
 
@@ -708,12 +703,12 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    I1Tensor:$input2
+    Tosa_I1Tensor:$input1,
+    Tosa_I1Tensor:$input2
   );
 
   let results = (outs
-    I1Tensor:$z
+    Tosa_I1Tensor:$z
   );
 }
 
@@ -731,12 +726,12 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    I1Tensor:$input2
+    Tosa_I1Tensor:$input1,
+    Tosa_I1Tensor:$input2
   );
 
   let results = (outs
-    I1Tensor:$z
+    Tosa_I1Tensor:$z
   );
 }
 
@@ -1085,11 +1080,11 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
   }];
 
   let arguments = (ins
-    I1Tensor:$input1
+    Tosa_I1Tensor:$input1
   );
 
   let results = (outs
-    I1Tensor:$output
+    Tosa_I1Tensor:$output
   );
 }
 
@@ -1208,7 +1203,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   }];
 
   let arguments = (ins
-    I1Tensor:$pred,
+    Tosa_I1Tensor:$pred,
     Tosa_Tensor:$on_true,
     Tosa_Tensor:$on_false
   );
@@ -1249,7 +1244,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
   );
 
   let results = (outs
-    I1Tensor:$output
+    Tosa_I1Tensor:$output
   );
 
   let extraClassDeclaration = [{
@@ -1277,7 +1272,7 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
   );
 
   let results = (outs
-    I1Tensor:$output
+    Tosa_I1Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1300,7 +1295,7 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
   );
 
   let results = (outs
-    I1Tensor:$output
+    Tosa_I1Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1721,7 +1716,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Int32Or64Tensor:$perms
+    Tosa_Int32Tensor:$perms
   );
 
   let results = (
@@ -1729,7 +1724,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
   );
 
   let extraClassDeclaration = [{
-    LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
+    LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
   }];
 
   let hasCanonicalizer = 1;
@@ -1755,7 +1750,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
 
   let arguments = (ins
     Tosa_Tensor3D:$values,
-    2DTensorOf<[Tosa_Int32]>:$indices
+    TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
   );
 
   let results = (outs
@@ -1776,7 +1771,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
 
   let arguments = (ins
     Tosa_Tensor3D:$values_in,
-    2DTensorOf<[Tosa_Int32]>:$indices,
+    TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
     Tosa_Tensor3D:$input
   );
 
@@ -1947,10 +1942,11 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
   );
 
   let results = (outs
-    TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
+    TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2054,7 +2050,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   }];
 
   let arguments = (ins
-    I1Tensor:$cond,
+    Tosa_I1Tensor:$cond,
     Variadic<Tosa_Tensor>:$inputs
   );
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 14fc9c7a6730cc..c3a0128e95a84b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -82,58 +82,83 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
 def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
                              Tosa_QuantizedInt, AnyFloat]>;
 
+//===----------------------------------------------------------------------===//
+// TOSA Tensor Conformance
+//===----------------------------------------------------------------------===//
+
+def HasNo0Dimensions : And<[
+    IsRankedTensorTypePred,
+    CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
+
+class TosaTensorOf<
+    list<Type> allowedTypes, string summary = "tosa-conformant tensor">
+    : TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
+
+class TosaRankedTensorOf<
+    list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant ranked tensor">
+    : RankedTensorOf<allowedTypes, !listconcat([HasNo0Dimensions], preds), summary>;
+
+class TosaUnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant unranked tensor">
+    : UnrankedTensorOf<allowedTypes, preds, summary>;
+
+class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
+    : TosaRankedTensorOf<allowedTypes,
+      [HasAnyRankOfPred<ranks>],
+      !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
+
 //===----------------------------------------------------------------------===//
 // Tensor types
 //===----------------------------------------------------------------------===//
 
-def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
-def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
+def Tosa_I1Tensor : TosaTensorOf<[I1]>;
+def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
+def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>;
 
-def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
+def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
 
 // Either ranked or unranked tensor of TOSA supported element types.
-def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
 
 // Must be ranked but no further constraints
-def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
 
 // Any tensor element type allowed in Tosa ops.
 def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
                                 AnyFloat.predicate]>, "tosa.dtype">;
 
 class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
-  AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
+  AnyTypeOf<[TosaTensorOf<allowedTypes>, NoneType], description>;
 
 //===----------------------------------------------------------------------===//
 // Tensor types with constrained ranks.
 //===----------------------------------------------------------------------===//
 
 // Rank-0 (scalar) tensor
-def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
+def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
 
 // We include unranked tensors as a supported type for all possible tosa
 // Tensors as unranked does not guarantee invalid. If unranked tensors exist
 // they should be shape propagate used Tosa's shape inference pass and verified
 // to not include any remaining unranked tensors.
-def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
 
-def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [3]>], "3-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
 
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
 def Tosa_Tensor1Dto6D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
 
 def Tosa_TensorUpto4D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
 
 def Tosa_Int32TensorUpto4D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
 
 //===----------------------------------------------------------------------===//
 // Generic scalar, vector, or tensor of a particular type.
@@ -142,7 +167,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
 class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
      AnyTypeOf<types>.predicate,
      VectorOf<types>.predicate,
-     TensorOf<types>.predicate]>,
+     TosaTensorOf<types>.predicate]>,
      description>;
 
 def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index ef40b348ab5499..90fea1f68beb58 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -216,6 +216,19 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
   return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
 }
 
+// Apply an int32_t permutation to some input, that should be of the same
+// size as perms. Perms should contain some permutation of 0 - perms.size() - 1.
+template <typename T>
+SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
+                                    ArrayRef<int32_t> perms) {
+  SmallVector<T> permuted;
+  size_t N = input.size();
+  permuted.resize_for_overwrite(N);
+  for (size_t i = 0; i < N; i++)
+    permuted[i] = input[perms[i]];
+  return permuted;
+}
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 77c3d2e8757910..fe53b499674324 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -313,7 +313,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
         // convolution operation.
         // TODO(suderman): See if this can be efficiently folded - check whether
         // the input is used anywhere else, if not fold the constant.
-        SmallVector<int64_t> weightPerm;
+        SmallVector<int32_t> weightPerm;
         for (int i = 1; i < resultTy.getRank(); i++)
           weightPerm.push_back(i);
         weightPerm.push_back(0);
@@ -321,7 +321,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
         SmallVector<int64_t> newWeightShape;
         for (auto dim : weightPerm)
           newWeightShape.push_back(weightShape[dim]);
-        auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+        auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
         Value weightPermValue =
             rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
         Type newWeightTy =
@@ -337,7 +337,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     if (5 == inputTy.getRank()) {
       // TODO(suderman): See if this can be efficiently folded - check whether
       // the input is used anywhere else, if not fold the constant.
-      SmallVector<int64_t> weightPerm;
+      SmallVector<int32_t> weightPerm;
       for (int i = 1; i < resultTy.getRank(); i++)
         weightPerm.push_back(i);
       weightPerm.push_back(0);
@@ -345,7 +345,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       SmallVector<int64_t> newWeightShape;
       for (auto dim : weightPerm)
         newWeightShape.push_back(weightShape[dim]);
-      auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+      auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
       Value weightPermValue =
           rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
       Type newWeightTy =
@@ -1040,22 +1040,25 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
 
   LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                 PatternRewriter &rewriter) const final {
-    SmallVector<int64_t> constantPerms;
+    SmallVector<int32_t> constantPerms;
     if (failed(op.getConstantPerms(constantPerms)))
       return failure();
 
     Location loc = op.getLoc();
-    // The verifier should have made sure we have a valid permutation tensor.
-    assert(isPermutationVector(constantPerms) && "Expected valid permutation");
+    // The verifier should have made sure we have a valid TOSA permutation
+    // tensor. isPermutationVector doesn't actually check the TOSA perms we
+    // expect.
     SmallVector<OpFoldResult> inputSizes =
         tensor::getMixedSizes(rewriter, loc, op.getInput1());
     auto permutedSizes =
-        applyPermutation<OpFoldResult>(inputSizes, constantPerms);
+        applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);
 
     auto permutedInit = rewriter.create<tensor::EmptyOp>(
         loc, permutedSizes, op.getInput1().getType().getElementType());
     rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
-        op, op.getInput1(), permutedInit, constantPerms);
+        op, op.getInput1(), permutedInit,
+        llvm::to_vector(llvm::map_range(
+            constantPerms, [](int32_t v) -> int64_t { return v; })));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index da9a93feac4d65..03876a7c64d07c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,7 +88,7 @@ struct ConsolidateTransposeOptimization
       return rewriter.notifyMatchFailure(transposeOp,
                                          "input must be transpose operation");
 
-    SmallVector<int64_t> transposePerms, innerTransposePerms;
+    SmallVector<int32_t> transposePerms, innerTransposePerms;
     if (transposeOp.getConstantPerms(transposePerms).failed())
       return rewriter.notifyMatchFailure(transposeOp,
                                          "transpose perms must be constant");
@@ -497,8 +497,10 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
     return getInput1();
@@ -536,8 +538,10 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
 
   // IntDivOp inputs must be integer type, no need to check for quantized type
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
   if (lhsAttr && lhsAttr.isSplat()) {
     if (llvm::isa<IntegerType>(resultETy) &&
         lhsAttr.getSplatValue<APInt>().isZero())
@@ -605,10 +609,13 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
+
   if (rhsTy == resultTy) {
     if (isSplatZero(resultETy, lhsAttr))
       return lhsAttr.resizeSplat(resultTy);
@@ -638,8 +645,10 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy....
[truncated]
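
For readers who want to experiment with the indexing convention used by `applyTOSAPermutation` in the patch (`permuted[i] = input[perms[i]]`), the following is a standalone sketch using `std::vector` in place of `SmallVector` and `ArrayRef`. The names `applyPermutationSketch` and `widenPerms` are hypothetical. As in the patch's comment, `perms` is assumed to be a valid permutation of 0 to (size - 1) with the same length as `input`; the sketch does not check this.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the patch's helper: element i of the result is
// taken from position perms[i] of the input.
// Precondition (unchecked): perms is a valid permutation of 0 .. N-1.
template <typename T>
std::vector<T> applyPermutationSketch(const std::vector<T> &input,
                                      const std::vector<int32_t> &perms) {
  std::vector<T> permuted(input.size());
  for (std::size_t i = 0; i < input.size(); ++i)
    permuted[i] = input[static_cast<std::size_t>(perms[i])];
  return permuted;
}

// Hypothetical helper: widen i32 permutation values to i64, similar in
// spirit to the map_range conversion done before building the
// linalg transpose in the diff above.
std::vector<int64_t> widenPerms(const std::vector<int32_t> &perms) {
  return std::vector<int64_t>(perms.begin(), perms.end());
}
```

For instance, applying the permutation [1, 2, 0] to the sizes [2, 3, 4] yields [3, 4, 2] under this convention.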

Contributor

@eric-k256 eric-k256 left a comment

This looks good to me, it should catch some cases that are currently slipping through the cracks.

Contributor

@GeorgeARM GeorgeARM left a comment

Great patch.

@GeorgeARM GeorgeARM merged commit a54efdb into llvm:main Sep 11, 2024
12 checks passed

@arteen1000 Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

@llvm-ci
Collaborator

llvm-ci commented Sep 11, 2024

LLVM Buildbot has detected a new failure on builder clang-aarch64-sve-vla-2stage running on linaro-g3-03 while building mlir at step 12 "ninja check 2".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/41/builds/1969

Here is the relevant piece of the build log for the reference
Step 12 (ninja check 2) failure: stage 2 checked (failure)
******************** TEST 'AddressSanitizer-aarch64-linux :: TestCases/Posix/halt_on_error-signals.c' FAILED ********************
Exit Code: 1

Command Output (stderr):
--
RUN: at line 3: /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/stage2/./bin/clang  -fsanitize=address -mno-omit-leaf-frame-pointer -fno-omit-frame-pointer -fno-optimize-sibling-calls -gline-tables-only   -Wthread-safety -Wthread-safety-reference -Wthread-safety-beta   -fsanitize-recover=address -pthread /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/llvm/compiler-rt/test/asan/TestCases/Posix/halt_on_error-signals.c -o /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/stage2/runtimes/runtimes-bins/compiler-rt/test/asan/AARCH64LinuxConfig/TestCases/Posix/Output/halt_on_error-signals.c.tmp
+ /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/stage2/./bin/clang -fsanitize=address -mno-omit-leaf-frame-pointer -fno-omit-frame-pointer -fno-optimize-sibling-calls -gline-tables-only -Wthread-safety -Wthread-safety-reference -Wthread-safety-beta -fsanitize-recover=address -pthread /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/llvm/compiler-rt/test/asan/TestCases/Posix/halt_on_error-signals.c -o /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/stage2/runtimes/runtimes-bins/compiler-rt/test/asan/AARCH64LinuxConfig/TestCases/Posix/Output/halt_on_error-signals.c.tmp
In file included from /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/llvm/compiler-rt/test/asan/TestCases/Posix/halt_on_error-signals.c:11:
In file included from /usr/include/stdio.h:27:
In file included from /usr/include/aarch64-linux-gnu/bits/libc-header-start.h:33:
/usr/include/features.h:194:3: warning: "_BSD_SOURCE and _SVID_SOURCE are deprecated, use _DEFAULT_SOURCE" [-W#warnings]
  194 | # warning "_BSD_SOURCE and _SVID_SOURCE are deprecated, use _DEFAULT_SOURCE"
      |   ^
1 warning generated.
RUN: at line 5: env ASAN_OPTIONS=halt_on_error=false:suppress_equal_pcs=false  /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/stage2/runtimes/runtimes-bins/compiler-rt/test/asan/AARCH64LinuxConfig/TestCases/Posix/Output/halt_on_error-signals.c.tmp 100 >/home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/stage2/runtimes/runtimes-bins/compiler-rt/test/asan/AARCH64LinuxConfig/TestCases/Posix/Output/halt_on_error-signals.c.tmp.log 2>&1 || true
+ env ASAN_OPTIONS=halt_on_error=false:suppress_equal_pcs=false /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/stage2/runtimes/runtimes-bins/compiler-rt/test/asan/AARCH64LinuxConfig/TestCases/Posix/Output/halt_on_error-signals.c.tmp 100
+ true
RUN: at line 7: FileCheck --check-prefix=CHECK-COLLISION /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/llvm/compiler-rt/test/asan/TestCases/Posix/halt_on_error-signals.c </home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/stage2/runtimes/runtimes-bins/compiler-rt/test/asan/AARCH64LinuxConfig/TestCases/Posix/Output/halt_on_error-signals.c.tmp.log || FileCheck --check-prefix=CHECK-NO-COLLISION /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/llvm/compiler-rt/test/asan/TestCases/Posix/halt_on_error-signals.c </home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/stage2/runtimes/runtimes-bins/compiler-rt/test/asan/AARCH64LinuxConfig/TestCases/Posix/Output/halt_on_error-signals.c.tmp.log
+ FileCheck --check-prefix=CHECK-COLLISION /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/llvm/compiler-rt/test/asan/TestCases/Posix/halt_on_error-signals.c
/home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/llvm/compiler-rt/test/asan/TestCases/Posix/halt_on_error-signals.c:29:22: error: CHECK-COLLISION: expected string not found in input
 // CHECK-COLLISION: AddressSanitizer: nested bug in the same thread, aborting
                     ^
<stdin>:1:1: note: scanning from here
=================================================================
^
<stdin>:55:85: note: possible intended match here
/home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/stage2/./bin/llvm-symbolizerAddressSanitizer: : nested bug in the same thread, aborting.
                                                                                    ^

Input file: <stdin>
Check file: /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/llvm/compiler-rt/test/asan/TestCases/Posix/halt_on_error-signals.c

-dump-input=help explains the following input dump.

Input was:
<<<<<<
            1: ================================================================= 
check:29'0     X~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ error: no match found
            2: ==1249730==ERROR: AddressSanitizer: use-after-poison on address 0xaaaae7fc1bc0 at pc 0xaaaae7f71cec bp 0xfbff875fe640 sp 0xfbff875fe638 
check:29'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            3: WRITE of size 1 at 0xaaaae7fc1bc0 thread T1 
check:29'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            4:  #0 0xaaaae7f71ce8 in error /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/llvm/compiler-rt/test/asan/TestCases/Posix/halt_on_error-signals.c:32:12 
check:29'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            5:  #1 0xaaaae7f72098 in receiver /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/llvm/compiler-rt/test/asan/TestCases/Posix/halt_on_error-signals.c:66:5 
check:29'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            6:  #2 0xaaaae7f31e3c in asan_thread_start(void*) /home/tcwg-buildbot/worker/clang-aarch64-sve-vla-2stage/llvm/compiler-rt/lib/asan/asan_interceptors.cpp:239:28 
check:29'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            .
            .
...

VitaNuo pushed a commit to VitaNuo/llvm-project that referenced this pull request Sep 12, 2024
@arteen1000 arteen1000 deleted the tosa-additional-verification branch September 14, 2024 06:33