Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][TOSA] add additional verification to TOSA #108133

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
add_mlir_interface(TosaInterfaces)

set(LLVM_TARGET_DEFINITIONS TosaOps.td)
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs)
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
add_public_tablegen_target(MLIRTosaAttributesIncGen)

set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
Expand Down
58 changes: 27 additions & 31 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,

Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
Expand Down Expand Up @@ -102,9 +101,8 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
4DTensorOf<[Tosa_Weight]>:$weight,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,

Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
Expand Down Expand Up @@ -132,9 +130,8 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {

let arguments = (ins
Tosa_Tensor5D:$input,
TensorRankOf<[Tosa_Weight], [5]>:$weight,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,

Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
Expand Down Expand Up @@ -163,9 +160,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
4DTensorOf<[Tosa_Weight]>:$weight,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,

Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
Expand Down Expand Up @@ -232,7 +228,7 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {

let arguments = (ins
Tosa_Tensor2D:$input,
2DTensorOf<[Tosa_Weight]>:$weight,
TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
Tosa_Tensor1D:$bias,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
);
Expand Down Expand Up @@ -347,9 +343,8 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
4DTensorOf<[Tosa_Weight]>:$filter,
TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
Tosa_Tensor1D:$bias,

Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttrUpto4:$out_shape,
Expand Down Expand Up @@ -641,12 +636,12 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
}];

let arguments = (ins
I1Tensor:$input1,
I1Tensor:$input2
Tosa_I1Tensor:$input1,
Tosa_I1Tensor:$input2
);

let results = (outs
I1Tensor:$z
Tosa_I1Tensor:$z
);
}

Expand Down Expand Up @@ -708,12 +703,12 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
}];

let arguments = (ins
I1Tensor:$input1,
I1Tensor:$input2
Tosa_I1Tensor:$input1,
Tosa_I1Tensor:$input2
);

let results = (outs
I1Tensor:$z
Tosa_I1Tensor:$z
);
}

Expand All @@ -731,12 +726,12 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
}];

let arguments = (ins
I1Tensor:$input1,
I1Tensor:$input2
Tosa_I1Tensor:$input1,
Tosa_I1Tensor:$input2
);

let results = (outs
I1Tensor:$z
Tosa_I1Tensor:$z
);
}

Expand Down Expand Up @@ -1085,11 +1080,11 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
}];

let arguments = (ins
I1Tensor:$input1
Tosa_I1Tensor:$input1
);

let results = (outs
I1Tensor:$output
Tosa_I1Tensor:$output
);
}

Expand Down Expand Up @@ -1208,7 +1203,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
}];

let arguments = (ins
I1Tensor:$pred,
Tosa_I1Tensor:$pred,
Tosa_Tensor:$on_true,
Tosa_Tensor:$on_false
);
Expand Down Expand Up @@ -1249,7 +1244,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
);

let results = (outs
I1Tensor:$output
Tosa_I1Tensor:$output
);

let extraClassDeclaration = [{
Expand Down Expand Up @@ -1277,7 +1272,7 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
);

let results = (outs
I1Tensor:$output
Tosa_I1Tensor:$output
);

let hasFolder = 1;
Expand All @@ -1300,7 +1295,7 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
);

let results = (outs
I1Tensor:$output
Tosa_I1Tensor:$output
);

let hasFolder = 1;
Expand Down Expand Up @@ -1721,15 +1716,15 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",

let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Int32Or64Tensor:$perms
Tosa_Int32Tensor:$perms
);

let results = (
outs Tosa_Tensor:$output
);

let extraClassDeclaration = [{
LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
}];

let hasCanonicalizer = 1;
Expand All @@ -1755,7 +1750,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {

let arguments = (ins
Tosa_Tensor3D:$values,
2DTensorOf<[Tosa_Int32]>:$indices
TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
);

let results = (outs
Expand All @@ -1776,7 +1771,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {

let arguments = (ins
Tosa_Tensor3D:$values_in,
2DTensorOf<[Tosa_Int32]>:$indices,
TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
Tosa_Tensor3D:$input
);

Expand Down Expand Up @@ -1947,10 +1942,11 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);

let results = (outs
TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
);

let hasFolder = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2054,7 +2050,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
}];

let arguments = (ins
I1Tensor:$cond,
Tosa_I1Tensor:$cond,
Variadic<Tosa_Tensor>:$inputs
);

Expand Down
61 changes: 43 additions & 18 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -82,58 +82,83 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
Tosa_QuantizedInt, AnyFloat]>;

//===----------------------------------------------------------------------===//
// TOSA Tensor Conformance
//===----------------------------------------------------------------------===//

def HasNo0Dimensions : And<[
IsRankedTensorTypePred,
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;

class TosaTensorOf<
list<Type> allowedTypes, string summary = "tosa-conformant tensor">
: TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;

class TosaRankedTensorOf<
list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant ranked tensor">
: RankedTensorOf<allowedTypes, !listconcat([HasNo0Dimensions], preds), summary>;

class TosaUnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant unranked tensor">
: UnrankedTensorOf<allowedTypes, preds, summary>;

class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
: TosaRankedTensorOf<allowedTypes,
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;

//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//

def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
def Tosa_I1Tensor : TosaTensorOf<[I1]>;
def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>;

def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;

// Either ranked or unranked tensor of TOSA supported element types.
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;

// Must be ranked but no further constraints
def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;

// Any tensor element type allowed in Tosa ops.
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
AnyFloat.predicate]>, "tosa.dtype">;

class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
AnyTypeOf<[TosaTensorOf<allowedTypes>, NoneType], description>;

//===----------------------------------------------------------------------===//
// Tensor types with constrained ranks.
//===----------------------------------------------------------------------===//

// Rank-0 (scalar) tensor
def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;

// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
// they should be shape propagate used Tosa's shape inference pass and verified
// to not include any remaining unranked tensors.
def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;

def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [3]>], "3-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;

// Ranked tensors up to given rank.
def Tosa_Tensor1Dto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
def Tosa_Tensor1Dto6D : AnyTypeOf<[
Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;

def Tosa_TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;

def Tosa_Int32TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;

//===----------------------------------------------------------------------===//
// Generic scalar, vector, or tensor of a particular type.
Expand All @@ -142,7 +167,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
AnyTypeOf<types>.predicate,
VectorOf<types>.predicate,
TensorOf<types>.predicate]>,
TosaTensorOf<types>.predicate]>,
description>;

def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
Expand Down
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,19 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
}

// Apply an int32_t permutation to some input, that should be of the same
// size as perms. Perms should contain some permutation of 0 - perms.size() - 1.
template <typename T>
SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
ArrayRef<int32_t> perms) {
SmallVector<T> permuted;
size_t N = input.size();
permuted.resize_for_overwrite(N);
for (size_t i = 0; i < N; i++)
permuted[i] = input[perms[i]];
return permuted;
}

} // namespace tosa
} // namespace mlir

Expand Down
Loading
Loading