diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt index 12b4fc402c390f9..1ee105f0ceb98bf 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt @@ -3,8 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc) add_mlir_interface(TosaInterfaces) set(LLVM_TARGET_DEFINITIONS TosaOps.td) -mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls) -mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs) +mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa) +mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa) add_public_tablegen_target(MLIRTosaAttributesIncGen) set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index ab6daa39708d13c..63572f287b7ddec 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -73,7 +73,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> { let arguments = (ins Tosa_Tensor4D:$input, - Tosa_IntArrayAttr2:$kernel, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr4:$pad, @@ -102,9 +101,8 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> { let arguments = (ins Tosa_Tensor4D:$input, - 4DTensorOf<[Tosa_Weight]>:$weight, + TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, - Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, @@ -132,9 +130,8 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> { let arguments = (ins Tosa_Tensor5D:$input, - TensorRankOf<[Tosa_Weight], [5]>:$weight, + TosaTensorRankOf<[Tosa_Weight], [5]>:$weight, Tosa_Tensor1D:$bias, - Tosa_IntArrayAttr6:$pad, Tosa_IntArrayAttr3:$stride, Tosa_IntArrayAttr3:$dilation, @@ -163,9 +160,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> { let arguments = (ins Tosa_Tensor4D:$input, - 4DTensorOf<[Tosa_Weight]>:$weight, + TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, - Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, @@ -232,7 +228,7 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> { let arguments = (ins Tosa_Tensor2D:$input, - 2DTensorOf<[Tosa_Weight]>:$weight, + TosaTensorRankOf<[Tosa_Weight], [2]>:$weight, Tosa_Tensor1D:$bias, OptionalAttr:$quantization_info ); @@ -347,9 +343,8 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> { let arguments = (ins Tosa_Tensor4D:$input, - 4DTensorOf<[Tosa_Weight]>:$filter, + TosaTensorRankOf<[Tosa_Weight], [4]>:$filter, Tosa_Tensor1D:$bias, - Tosa_IntArrayAttr4:$out_pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttrUpto4:$out_shape, @@ -641,12 +636,12 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [ }]; let arguments = (ins - I1Tensor:$input1, - I1Tensor:$input2 + Tosa_I1Tensor:$input1, + Tosa_I1Tensor:$input2 ); let results = (outs - I1Tensor:$z + Tosa_I1Tensor:$z ); } @@ -708,12 +703,12 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [ }]; let arguments = (ins - I1Tensor:$input1, - I1Tensor:$input2 + Tosa_I1Tensor:$input1, + Tosa_I1Tensor:$input2 ); let results = (outs - I1Tensor:$z + Tosa_I1Tensor:$z ); } @@ -731,12 +726,12 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [ }]; let arguments = (ins - I1Tensor:$input1, - I1Tensor:$input2 + Tosa_I1Tensor:$input1, + Tosa_I1Tensor:$input2 ); let results = (outs - I1Tensor:$z + Tosa_I1Tensor:$z ); } @@ -1085,11 +1080,11 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not", }]; let arguments = (ins - I1Tensor:$input1 + Tosa_I1Tensor:$input1 ); let results = (outs - I1Tensor:$output + Tosa_I1Tensor:$output ); } @@ -1208,7 +1203,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> { }]; let arguments = (ins - I1Tensor:$pred, + Tosa_I1Tensor:$pred, Tosa_Tensor:$on_true, Tosa_Tensor:$on_false ); @@ -1249,7 +1244,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [ ); let results = (outs - I1Tensor:$output + Tosa_I1Tensor:$output ); let extraClassDeclaration = [{ @@ -1277,7 +1272,7 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> { ); let results = (outs - I1Tensor:$output + Tosa_I1Tensor:$output ); let hasFolder = 1; @@ -1300,7 +1295,7 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal", ); let results = (outs - I1Tensor:$output + Tosa_I1Tensor:$output ); let hasFolder = 1; @@ -1721,7 +1716,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose", let arguments = (ins Tosa_Tensor:$input1, - Tosa_Int32Or64Tensor:$perms + Tosa_Int32Tensor:$perms ); let results = ( @@ -1729,7 +1724,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose", ); let extraClassDeclaration = [{ - LogicalResult getConstantPerms(llvm::SmallVector &perms); + LogicalResult getConstantPerms(llvm::SmallVector &perms); }]; let hasCanonicalizer = 1; @@ -1755,7 +1750,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> { let arguments = (ins Tosa_Tensor3D:$values, - 2DTensorOf<[Tosa_Int32]>:$indices + TosaTensorRankOf<[Tosa_Int32], [2]>:$indices ); let results = (outs @@ -1776,7 +1771,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> { let arguments = (ins Tosa_Tensor3D:$values_in, - 2DTensorOf<[Tosa_Int32]>:$indices, + TosaTensorRankOf<[Tosa_Int32], [2]>:$indices, Tosa_Tensor3D:$input ); @@ -1947,10 +1942,11 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure, ); let results = (outs - TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output + TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output ); let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -2054,7 +2050,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if", }]; let arguments = (ins - I1Tensor:$cond, + Tosa_I1Tensor:$cond, Variadic:$inputs ); diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 14fc9c7a6730cc2..c3a0128e95a84bb 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -82,58 +82,83 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat], def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8, Tosa_QuantizedInt, AnyFloat]>; +//===----------------------------------------------------------------------===// +// TOSA Tensor Conformance +//===----------------------------------------------------------------------===// + +def HasNo0Dimensions : And<[ + IsRankedTensorTypePred, + CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>; + +class TosaTensorOf< + list allowedTypes, string summary = "tosa-conformant tensor"> + : TensorOf], summary>; + +class TosaRankedTensorOf< + list allowedTypes, list preds = [], string summary = "tosa-conformant ranked tensor"> + : RankedTensorOf; + +class TosaUnrankedTensorOf allowedTypes, list preds = [], string summary = "tosa-conformant unranked tensor"> + : UnrankedTensorOf; + +class TosaTensorRankOf allowedTypes, list ranks> + : TosaRankedTensorOf], + !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">; + //===----------------------------------------------------------------------===// // Tensor types //===----------------------------------------------------------------------===// -def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>; -def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>; +def Tosa_I1Tensor : TosaTensorOf<[I1]>; +def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>; +def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>; -def Tosa_FloatTensor : TensorOf<[AnyFloat]>; +def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>; // Either ranked or unranked tensor of TOSA supported element types. -def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; +def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>; // Must be ranked but no further constraints -def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>; +def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>; // Any tensor element type allowed in Tosa ops. def Tosa_ElementType : Type, "tosa.dtype">; class Tosa_TensorOfOrNone allowedTypes, string description = ""> : - AnyTypeOf<[TensorOf, NoneType], description>; + AnyTypeOf<[TosaTensorOf, NoneType], description>; //===----------------------------------------------------------------------===// // Tensor types with constrained ranks. //===----------------------------------------------------------------------===// // Rank-0 (scalar) tensor -def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>; +def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>; // We include unranked tensors as a supported type for all possible tosa // Tensors as unranked does not guarantee invalid. If unranked tensors exist // they should be shape propagate used Tosa's shape inference pass and verified // to not include any remaining unranked tensors. -def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>; +def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>; -def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">; -def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">; -def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">; -def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">; -def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">; +def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">; +def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">; +def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [3]>], "3-d tosa-conformant tensor", "::mlir::TensorType">; +def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">; +def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">; // Ranked tensors up to given rank. def Tosa_Tensor1Dto4D : AnyTypeOf<[ - Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>; + Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>; def Tosa_Tensor1Dto6D : AnyTypeOf<[ - Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>; + Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>; def Tosa_TensorUpto4D : AnyTypeOf<[ - Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>; + Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>; def Tosa_Int32TensorUpto4D : AnyTypeOf<[ - Tosa_UnrankedTensor, TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>; + Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>; //===----------------------------------------------------------------------===// // Generic scalar, vector, or tensor of a particular type. @@ -142,7 +167,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[ class Tosa_TypeLike types, string description = ""> : TypeConstraint.predicate, VectorOf.predicate, - TensorOf.predicate]>, + TosaTensorOf.predicate]>, description>; def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">; diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h index ef40b348ab54996..90fea1f68beb582 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h @@ -216,6 +216,19 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc, return CreateOpAndInferShape(builder, resultTy, args...); } +// Apply an int32_t permutation to some input, that should be of the same +// size as perms. Perms should contain some permutation of 0 - perms.size() - 1. +template +SmallVector applyTOSAPermutation(ArrayRef input, + ArrayRef perms) { + SmallVector permuted; + size_t N = input.size(); + permuted.resize_for_overwrite(N); + for (size_t i = 0; i < N; i++) + permuted[i] = input[perms[i]]; + return permuted; +} + } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 77c3d2e87579102..fe53b499674324c 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -313,7 +313,7 @@ class ConvConverter : public OpConversionPattern { // convolution operation. // TODO(suderman): See if this can be efficiently folded - check whether // the input is used anywhere else, if not fold the constant. - SmallVector weightPerm; + SmallVector weightPerm; for (int i = 1; i < resultTy.getRank(); i++) weightPerm.push_back(i); weightPerm.push_back(0); @@ -321,7 +321,7 @@ class ConvConverter : public OpConversionPattern { SmallVector newWeightShape; for (auto dim : weightPerm) newWeightShape.push_back(weightShape[dim]); - auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm); + auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm); Value weightPermValue = rewriter.create(loc, weightPermAttr); Type newWeightTy = @@ -337,7 +337,7 @@ class ConvConverter : public OpConversionPattern { if (5 == inputTy.getRank()) { // TODO(suderman): See if this can be efficiently folded - check whether // the input is used anywhere else, if not fold the constant. - SmallVector weightPerm; + SmallVector weightPerm; for (int i = 1; i < resultTy.getRank(); i++) weightPerm.push_back(i); weightPerm.push_back(0); @@ -345,7 +345,7 @@ class ConvConverter : public OpConversionPattern { SmallVector newWeightShape; for (auto dim : weightPerm) newWeightShape.push_back(weightShape[dim]); - auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm); + auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm); Value weightPermValue = rewriter.create(loc, weightPermAttr); Type newWeightTy = @@ -1040,22 +1040,25 @@ class TransposeConverter : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const final { - SmallVector constantPerms; + SmallVector constantPerms; if (failed(op.getConstantPerms(constantPerms))) return failure(); Location loc = op.getLoc(); - // The verifier should have made sure we have a valid permutation tensor. - assert(isPermutationVector(constantPerms) && "Expected valid permutation"); + // The verifier should have made sure we have a valid TOSA permutation + // tensor. isPermutationVector doesn't actually check the TOSA perms we + // expect. SmallVector inputSizes = tensor::getMixedSizes(rewriter, loc, op.getInput1()); auto permutedSizes = - applyPermutation(inputSizes, constantPerms); + applyTOSAPermutation(inputSizes, constantPerms); auto permutedInit = rewriter.create( loc, permutedSizes, op.getInput1().getType().getElementType()); rewriter.replaceOpWithNewOp( - op, op.getInput1(), permutedInit, constantPerms); + op, op.getInput1(), permutedInit, + llvm::to_vector(llvm::map_range( + constantPerms, [](int32_t v) -> int64_t { return v; }))); return success(); } }; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index da9a93feac4d65a..03876a7c64d07c3 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -88,7 +88,7 @@ struct ConsolidateTransposeOptimization return rewriter.notifyMatchFailure(transposeOp, "input must be transpose operation"); - SmallVector transposePerms, innerTransposePerms; + SmallVector transposePerms, innerTransposePerms; if (transposeOp.getConstantPerms(transposePerms).failed()) return rewriter.notifyMatchFailure(transposeOp, "transpose perms must be constant"); @@ -497,8 +497,10 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) { return {}; auto resultETy = resultTy.getElementType(); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr)) return getInput1(); @@ -536,8 +538,10 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) { // IntDivOp inputs must be integer type, no need to check for quantized type auto resultETy = resultTy.getElementType(); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); if (lhsAttr && lhsAttr.isSplat()) { if (llvm::isa(resultETy) && lhsAttr.getSplatValue().isZero()) @@ -605,10 +609,13 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { return {}; auto resultETy = resultTy.getElementType(); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); const int64_t shift = llvm::isa(resultETy) ? getShift() : 0; + if (rhsTy == resultTy) { if (isSplatZero(resultETy, lhsAttr)) return lhsAttr.resizeSplat(resultTy); @@ -638,8 +645,10 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) { return {}; auto resultETy = resultTy.getElementType(); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr)) return getInput1(); @@ -681,8 +690,10 @@ struct APIntFoldGreaterEqual { OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { auto resultTy = llvm::dyn_cast(getType()); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); if (!lhsAttr || !rhsAttr) return {}; @@ -693,8 +704,10 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { auto resultTy = llvm::dyn_cast(getType()); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); if (!lhsAttr || !rhsAttr) return {}; @@ -706,8 +719,10 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { auto resultTy = llvm::dyn_cast(getType()); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); Value lhs = getInput1(); Value rhs = getInput2(); auto lhsTy = llvm::cast(lhs.getType()); @@ -838,14 +853,16 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { return {}; // reshape(const(x)) -> const(reshape-attr(x)) - if (auto operand = llvm::dyn_cast_if_present(adaptor.getInput1())) { + if (auto operand = + llvm::dyn_cast_if_present(adaptor.getInput1())) { // Constants must have static shape. if (!outputTy.hasStaticShape()) return {}; // Okay to duplicate splat constants. if (operand.isSplat()) - return SplatElementsAttr::get(outputTy, operand.getSplatValue()); + return SplatElementsAttr::get(outputTy, + operand.getSplatValue()); // Don't duplicate other constants. if (!getInput1().hasOneUse()) @@ -905,7 +922,8 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { auto operand = getInput(); auto operandTy = llvm::cast(operand.getType()); auto axis = getAxis(); - auto operandAttr = llvm::dyn_cast_if_present(adaptor.getInput()); + auto operandAttr = + llvm::dyn_cast_if_present(adaptor.getInput()); if (operandAttr) return operandAttr; @@ -954,7 +972,8 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { if (getOnTrue() == getOnFalse()) return getOnTrue(); - auto predicate = llvm::dyn_cast_if_present(adaptor.getPred()); + auto predicate = + llvm::dyn_cast_if_present(adaptor.getPred()); if (!predicate) return {}; @@ -975,7 +994,8 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { auto resultTy = llvm::cast(getType()); // Transposing splat values just means reshaping. - if (auto input = llvm::dyn_cast_if_present(adaptor.getInput1())) { + if (auto input = + llvm::dyn_cast_if_present(adaptor.getInput1())) { if (input.isSplat() && resultTy.hasStaticShape() && input.getType().getElementType() == resultTy.getElementType()) return input.reshape(resultTy); @@ -986,11 +1006,11 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { return {}; // Transpose is not the identity transpose. - SmallVector perms; + SmallVector perms; if (getConstantPerms(perms).failed()) return {}; - if (!llvm::equal(llvm::seq(0, perms.size()), perms)) + if (!llvm::equal(llvm::seq(0, perms.size()), perms)) return {}; return getInput1(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index d93db1b237f3164..0d0241fea5152ce 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -204,22 +204,6 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, // TOSA Operator Verifiers. //===----------------------------------------------------------------------===// -static bool hasZeroDimension(ShapedType shapedType) { - if (!shapedType.hasRank()) - return false; - - auto rank = shapedType.getRank(); - - for (int i = 0; i < rank; i++) { - if (shapedType.isDynamicDim(i)) - continue; - if (shapedType.getDimSize(i) == 0) - return true; - } - - return false; -} - template static LogicalResult verifyConvOp(T op) { // All TOSA conv ops have an input() and weight(). @@ -236,10 +220,6 @@ static LogicalResult verifyConvOp(T op) { return failure(); } - if (hasZeroDimension(inputType)) - return op.emitOpError() << "tensor has a dimension with size zero. Each " - "dimension of a tensor must have size >= 1"; - auto inputEType = inputType.getElementType(); auto weightEType = weightType.getElementType(); @@ -262,6 +242,29 @@ static LogicalResult verifyConvOp(T op) { "allowed for float type"); return failure(); } + return success(); +} + +LogicalResult tosa::ConstOp::verify() { + + auto attrType = llvm::dyn_cast(getValueAttr().getType()); + auto outputType = llvm::dyn_cast(getOutput().getType()); + + if (!attrType || !outputType) { + emitOpError("expected tensors for attr/result type"); + return failure(); + } + + if (auto result = llvm::dyn_cast( + outputType.getElementType())) { + if (result.getStorageType() == attrType.getElementType()) + return success(); + } + + if (attrType.getElementType() != outputType.getElementType()) { + emitOpError("expected same attr/result element types"); + return failure(); + } return success(); } @@ -283,9 +286,6 @@ LogicalResult tosa::ArgMaxOp::verify() { LogicalResult tosa::AvgPool2dOp::verify() { auto inputType = llvm::cast(getInput().getType()); - if (hasZeroDimension(inputType)) - return emitOpError() << "tensor has a dimension with size zero. Each " - "dimension of a tensor must have size >= 1"; auto inputETy = inputType.getElementType(); auto resultETy = llvm::cast(getType()).getElementType(); @@ -341,9 +341,9 @@ LogicalResult tosa::ClampOp::verify() { if (inputETy != outputETy) return emitOpError("input/output element types are incompatible."); - // if input datatype is float, check that the two min/max_fp attributes share - // the same type and that their type is either the same of the input's - // datatype, or a float type whose bitwidth > input datatype bitwidth + // If input datatype is float, check that the two min/max_fp attributes + // share the same type and that their type is either the same of the input's + // datatype, or a float type whose bitwidth > input datatype bitwidth. if (!inputETy.isInteger(dataTypeBitWidth)) { if (((maxFpType != minFpType) || (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <= @@ -383,7 +383,8 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, } } -/// Handles tosa.transpose_conv2d which has outpad and output shape attributes. +/// Handles tosa.transpose_conv2d which has outpad and output shape +/// attributes. static void buildTransConvOpWithQuantInfo( OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, @@ -420,9 +421,9 @@ static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result, } } -/// The tosa.matmul op is also intended to be generated where a fully_connected -/// op must be constructed where the weight is not a constant. In this case, -/// the fully_connected op must be expressed using matmul. +/// The tosa.matmul op is also intended to be generated where a +/// fully_connected op must be constructed where the weight is not a constant. +/// In this case, the fully_connected op must be expressed using matmul. /// TODO: Add link to the leglization document explaining this. static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, @@ -457,9 +458,9 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder, } } -/// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr -/// but avg_pool operator has its own builder as it has additional parameters -/// not part of the unary ops. +/// Both the tosa.avg_pool2d and unary ops use the same +/// UnaruOpQuantizationAttr but avg_pool operator has its own builder as it +/// has additional parameters not part of the unary ops. static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, @@ -526,8 +527,8 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, for (int i = 0, e = operands.size(); i != e; ++i) { auto shape = operands.getShape(i); if (!shape.hasRank()) { - // TODO(jennik): Update function to have better case handling for invalid - // operands and for ranked tensors. + // TODO(jennik): Update function to have better case handling for + // invalid operands and for ranked tensors. return failure(); } outRank = std::max(outRank, shape.getRank()); @@ -776,8 +777,8 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents( return success(); } - // If the input rank is unknown we can info the output rank using the padding - // shape's first dim. + // If the input rank is unknown we can info the output rank using the + // padding shape's first dim. if (!inputShape.hasRank()) { if (paddingShape.isDynamicDim(0)) { inferredReturnShapes.push_back(ShapedTypeComponents()); @@ -1000,10 +1001,6 @@ llvm::LogicalResult tosa::ReshapeOp::verify() { TensorType inputType = getInput1().getType(); RankedTensorType outputType = getType(); - if (hasZeroDimension(inputType) || hasZeroDimension(outputType)) - return emitOpError() << "tensor has a dimension with size zero. Each " - "dimension of a tensor must have size >= 1"; - if ((int64_t)getNewShape().size() != outputType.getRank()) return emitOpError() << "new shape does not match result rank"; @@ -1034,16 +1031,15 @@ llvm::LogicalResult tosa::ReshapeOp::verify() { return mlir::success(); } -LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector &perms) { +LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector &perms) { // Perms must be constants. DenseIntElementsAttr permsAttr; if (!matchPattern(getPerms(), m_Constant(&permsAttr))) return failure(); - // Transpose is not the identity transpose. - perms = llvm::to_vector( - llvm::map_range(permsAttr.getValues(), - [](const APInt &val) { return val.getSExtValue(); })); + perms.clear(); + for (auto v : permsAttr.getValues()) + perms.push_back(v.getSExtValue()); return success(); } @@ -1067,8 +1063,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents( return success(); } - // This would imply the number of permutations does not match the rank of the - // input which is illegal. + // This would imply the number of permutations does not match the rank of + // the input which is illegal. if (permsShape.getDimSize(0) != inputShape.getRank()) { return failure(); } @@ -1154,19 +1150,38 @@ LogicalResult tosa::TransposeOp::verify() { << " (output rank) but got size " << permType.getDimSize(0); - SmallVector constantPerms; + SmallVector constantPerms; if (succeeded(getConstantPerms(constantPerms))) { - // Assert that the permutation tensor has a rank, which means that the rank - // has been verified above. + // Assert that the permutation tensor has a rank, which means that the + // rank has been verified above. assert(permType.hasRank() && "Unexpectedly found permutation tensor without rank"); - if (!isPermutationVector(constantPerms)) + if (!llvm::all_of(constantPerms, + [&constantPerms](int32_t s) { + return s >= 0 && + static_cast(s) < constantPerms.size(); + }) || + !isPermutationVector(llvm::to_vector(llvm::map_range( + constantPerms, [](int32_t v) -> int64_t { return v; })))) return emitOpError() << "expected valid permutation tensor"; - if (inputType.hasRank() && !llvm::all_of(constantPerms, [&](int64_t s) { - return s < inputType.getRank(); - })) { - return emitOpError() << "permutation must be within input bounds"; + // Verify that the types of the input and output tensors are properly + // permuted. + if (inputType.hasRank() && outputType.hasRank()) { + assert(constantPerms.size() == static_cast(inputType.getRank()) && + inputType.getRank() == outputType.getRank()); + + for (auto i = 0; i < outputType.getRank(); i++) { + if (inputType.isDynamicDim(constantPerms[i]) || + outputType.isDynamicDim(i)) + continue; + + if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i)) + return emitOpError() + << "expected output tensor dim " << i << " to match " + << "input dim " << constantPerms[i] << " with value of " + << inputType.getDimSize(constantPerms[i]); + } } } return success(); @@ -1175,7 +1190,7 @@ LogicalResult tosa::TransposeOp::verify() { LogicalResult TransposeOp::reifyResultShapes( OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - SmallVector transposePerms; + SmallVector transposePerms; if (getConstantPerms(transposePerms).failed()) return failure(); @@ -1184,7 +1199,7 @@ LogicalResult TransposeOp::reifyResultShapes( SmallVector returnedDims(inputType.getRank()); for (auto dim : transposePerms) { - int64_t dimInInput = transposePerms[dim]; + int32_t dimInInput = transposePerms[dim]; if (inputType.isDynamicDim(dimInInput)) returnedDims[dim] = builder.create(getLoc(), input, dimInInput) @@ -1378,8 +1393,8 @@ static LogicalResult verifyReduceOp(T op) { << ")"; return failure(); } - // We can only verify the reduced dimension size to be 1 if this is not the - // special case of output rank == 0. + // We can only verify the reduced dimension size to be 1 if this is not + // the special case of output rank == 0. if (outputRank != 0) { auto outputShape = outputType.getShape(); if (!outputType.isDynamicDim(reduceAxis) && diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir index 39699ee315e6cb3..0d55d1899c713e3 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -1,6 +1,6 @@ -// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s -// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s -// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s +// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s +// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s +// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s // CHECK-LABEL: @matmul func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) { @@ -521,7 +521,7 @@ func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tens // CHECK-LABEL: @conv2d_i8 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () { - // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64> + // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32> // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0] // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32> // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) { @@ -542,7 +542,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi // CHECK-LABEL: @conv2d_f32 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () { - // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64> + // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32> // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0] // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32> diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir index c2bbfd5130ebcd0..73da2810abe0444 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir @@ -24,7 +24,7 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> { // check that tosa verify kick in func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> { - // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}} + // expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}} %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 8e19f87dbf4aa8c..2902c4a62009e9f 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -80,14 +80,14 @@ func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> { [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32> - %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64> + %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> // CHECK: %[[CST:.+]] = "tosa.const"() <{ // CHECK-SAME{LITERAL}: value = dense<[ // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]], // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]], // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]] // CHECK-SAME{LITERAL}: ]> - %1 = tosa.transpose %input, %perms : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32> + %1 = tosa.transpose %input, %perms : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<3x1x4x2xi32> // CHECK: return %[[CST]] return %1 : tensor<3x1x4x2xi32> } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 418f7687b3cce86..414bcfe237d7535 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1,6 +1,22 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate=strict-op-spec-alignment +func.func @test_const() -> tensor<1xf32> { + // expected-error@+1{{'tosa.const' op expected same attr/result element types}} + %0 = "tosa.const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xf32> + return %0 : tensor<1xf32> +} + +// ----- + +func.func @test_const_non_tensor_attr() { + // expected-error@+1{{tosa.const' op expected tensors for attr/result type}} + %0 = "tosa.const"() {value = dense<1.0> : vector} : () -> tensor + return +} + +// ----- + func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { // expected-error@+1 {{expect both input and weight to be float or not together, got 'f32' and 'i8'}} %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} @@ -148,6 +164,42 @@ func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) // ----- +func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> { + %perms = "tosa.const"() {value = dense<[-1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}} + %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32> + return %1 : tensor<*xi32> +} + +// ----- + +func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> { + %perms = "tosa.const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}} + %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32> + return %1 : tensor<*xi32> +} + +// ----- + +func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> { + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error@+1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}} + %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x4xi32> + return %1 : tensor<3x4xi32> +} + +// ----- + +func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> { + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error@+1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}} + %1 = tosa.transpose %arg0, %perms : (tensor<2x?xi32>, tensor<2xi32>) -> tensor<3x4xi32> + return %1 : tensor<3x4xi32> +} + +// ----- + func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> { %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> %1 = tosa.reshape %arg0 {new_shape = array} : (tensor<13x21x3xf32>) -> tensor<273x3xf32> @@ -269,7 +321,7 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () { // ----- func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () { - // expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}} + // expected-error@+1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<13x0x3xf32>'}} %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32> return } @@ -277,7 +329,7 @@ func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () // ----- func.func @test_reshape_zero_dim_input(%arg0 : tensor) -> () { - // expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}} + // expected-error@+1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor'}} %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<13x0x3xf32> return } @@ -341,7 +393,7 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> { // ----- func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> { - // expected-error@+1 {{'tosa.conv2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}} + // expected-error@+1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x29x0x4xf32>'}} %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32> return %0 : tensor<1x27x27x16xf32> @@ -350,8 +402,8 @@ func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: // ----- func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> { - // expected-error@+1 {{'tosa.conv2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} + // expected-error@+1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x?x0x4xf32>'}} + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x?x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32> return %0 : tensor<1x27x27x16xf32> } @@ -360,7 +412,7 @@ func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor< // ----- func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> { - // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}} + // expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x7x9xf32>'}} %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> @@ -369,7 +421,7 @@ func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) -> // ----- func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> { - // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}} + // expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}} %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> @@ -469,7 +521,7 @@ func.func @test_tile_io_rank_mismatch() { // CHECK-LABEL: @test_invalid_constant_permutation func.func @test_invalid_constant_permutation() { - // expected-error@+3 {{permutation must be within input bounds}} + // expected-error@+3 {{'tosa.transpose' op expected valid permutation tensor}} %0 = tensor.empty() : tensor<3x4x5xi32> %1 = arith.constant dense<[3, 0, 1]> : tensor<3xi32> %2 = tosa.transpose %0, %1 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<3x4x5xi32> @@ -480,7 +532,7 @@ func.func @test_invalid_constant_permutation() { // CHECK-LABEL: test_rank_size_constant_permutation func.func @test_rank_size_constant_permutation() { - // expected-error@+4 {{permutation must be within input bounds}} + // expected-error@+4 {{'tosa.transpose' op expected valid permutation tensor}} %0 = arith.constant 6 : index %1 = arith.constant dense<[0, 2]> : tensor<2xi32> %2 = tensor.empty(%0) : tensor @@ -492,7 +544,7 @@ func.func @test_rank_size_constant_permutation() { // CHECK-LABEL: test_large_constant_permutation func.func @test_large_constant_permutation() { - // expected-error@+4 {{permutation must be within input bounds}} + // expected-error@+4 {{'tosa.transpose' op expected valid permutation tensor}} %0 = arith.constant 6 : index %1 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32> %2 = tensor.empty(%0) : tensor @@ -504,7 +556,7 @@ func.func @test_large_constant_permutation() { // CHECK-LABEL: test_table_rank0_table func.func @test_table_rank0_table(%arg0: tensor<64xi16>, %arg1: tensor) { - // expected-error@+1 {{'tosa.table' op operand #1 must be 1-d tensor, but got 'tensor'}} + // expected-error@+1 {{'tosa.table' op operand #1 must be 1-d tosa-conformant tensor, but got 'tensor'}} %0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor) -> tensor<64xi16> return } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 525ee917ccd9fd2..a1600fd33c54b46 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -573,6 +573,22 @@ func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> { return %1 : tensor<3x13x21xf32> } +// ----- +// CHECK-LABEL: transpose_dynamic_dim +func.func @test_transpose_dynamic_dim(%arg0: tensor<13x?x3xf32>) -> tensor<3x13x?xf32> { + %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = tosa.transpose %arg0, %0 : (tensor<13x?x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32> + return %1 : tensor<3x13x?xf32> +} + +// ----- +// CHECK-LABEL: transpose_half_dynamic_dim +func.func @test_transpose_half_dynamic_dim(%arg0: tensor<13x3x3xf32>) -> tensor<3x13x?xf32> { + %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = tosa.transpose %arg0, %0 : (tensor<13x3x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32> + return %1 : tensor<3x13x?xf32> +} + // ----- // CHECK-LABEL: gather func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf32> {