-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOSA] bug fix infer shape for slice #113497
Merged
Merged
+76
−2
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) ChangesThis fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1 added tests to check:
Full diff: /~https://github.com/llvm/llvm-project/pull/113497.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 631d3c48f2df02..01312584652049 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -844,8 +844,40 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
SliceOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- inferredReturnShapes.push_back(
- ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
+ auto start = adaptor.getStart();
+ auto size = adaptor.getSize();
+
+ // if size[i] is -1, all remaining elements in dimension i are included
+ // in the slice, similar to TF.
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
+ // initialize outputShape to all unknown
+ SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
+ if (inputShape.hasRank()) {
+ for (size_t i = 0; i < size.size(); i++) {
+ if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
+ (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
+ start[i] < inputShape.getDimSize(i))) {
+ // size[i] is not 0 and not < -1, and start[i] is in valid range
+ if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
+ // input shape has unknown dim[i] - only valid if size[i] > 0
+ if (size[i] > 0) {
+ outputShape[i] = size[i];
+ }
+ } else {
+ // input shape has known dim[i]
+ if (size[i] == -1) {
+ outputShape[i] = inputShape.getDimSize(i) - start[i];
+ } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
+ // start[i] + size[i] is within bound of input shape's dim[i]
+ outputShape[i] = size[i];
+ }
+ }
+ }
+ }
+ } else {
+ outputShape = convertToMlirShape(size);
+ }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e93..d2314698afa925 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -532,6 +532,48 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
// -----
+// CHECK-LABEL: @test_slice_size_minus_one
+func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: -1, -1, -1, -1>, start = array<i64: 0, 1, -1, 8>} : (tensor<?x8x8x8xi32>) -> tensor<?x7x?x?xi32>
+ // this checks following
+ // dim 0: size=-1, input dim=? => inferred output dim is ?
+ // dim 1: size=-1 => inferred output dim is input_dim - start
+ // dim 2: size=-1, start=-1 => inferred output dim is ?
+ // dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound
+ %2= tosa.slice %arg0 { start = array<i64: 0, 1, -1, 8>, size = array<i64: -1, -1, -1, -1> } : (tensor<?x8x8x8xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_size_out_of_bound
+func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: 0, -2, 9, 4>, start = array<i64: 0, 0, 0, 0>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+ // this checks following
+ // dim 0: size=0 => inferred output dim is ?
+ // dim 1: size=-2 => inferred output dim is ?
+ // dim 3: start+size out of bound because size too big: inferred output dim is ?
+ // dim 4: size=4, input dim=? => inferred output dim is 4
+ %2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_start_out_of_bound
+func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: 1, 1, 3, 4>, start = array<i64: -1, 8, 6, 8000000>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+ // this checks following
+ // dim 0: start=-1 => inferred output dim is ?
+ // dim 1: start=8 => inferred output dim is ?
+ // dim 2: start+size out of bound: inferred output dim is ?
+ // dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4
+ %2= tosa.slice %arg0 { start = array<i64: -1, 8, 6, 8000000>, size = array<i64: 1, 1, 3, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_slice_dynamic
func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
// CHECK: tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>
|
@Tai78641 can we rebase? |
Tai78641
force-pushed
the
pr_fix_slice_infer_shape2
branch
from
January 17, 2025 20:24
db2a192
to
3c16454
Compare
done |
This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1 added tests to check: - size = -1 - size is out of bound - start is out of bound Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I8b59502a93cb332fe5c9a9f87970b83742538126
Tai78641
force-pushed
the
pr_fix_slice_infer_shape2
branch
from
January 17, 2025 22:22
3c16454
to
cd57590
Compare
GeorgeARM
approved these changes
Jan 17, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1
added tests to check: