Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOSA] bug fix infer shape for slice #113497

Merged
merged 1 commit into from
Jan 22, 2025

Conversation

Tai78641
Copy link
Contributor

@Tai78641 Tai78641 commented Oct 23, 2024

This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1

added tests to check:

  • size = -1
  • size is out of bound
  • start is out of bound

@llvmbot
Copy link
Member

llvmbot commented Oct 23, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

Changes

This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1

added tests to check:

  • size = -1
  • size is out of bound
  • start is out of bound

Full diff: /~https://github.com/llvm/llvm-project/pull/113497.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+34-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+42)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 631d3c48f2df02..01312584652049 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -844,8 +844,40 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     SliceOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  inferredReturnShapes.push_back(
-      ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
+  auto start = adaptor.getStart();
+  auto size = adaptor.getSize();
+
+  // if size[i] is -1, all remaining elements in dimension i are included
+  // in the slice, similar to TF.
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
+  // initialize outputShape to all unknown
+  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
+  if (inputShape.hasRank()) {
+    for (size_t i = 0; i < size.size(); i++) {
+      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
+          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
+           start[i] < inputShape.getDimSize(i))) {
+        // size[i] is not 0 and not < -1, and start[i] is in valid range
+        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
+          // input shape has unknown dim[i] - only valid if size[i] > 0
+          if (size[i] > 0) {
+            outputShape[i] = size[i];
+          }
+        } else {
+          // input shape has known dim[i]
+          if (size[i] == -1) {
+            outputShape[i] = inputShape.getDimSize(i) - start[i];
+          } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
+            // start[i] + size[i] is within bound of input shape's dim[i]
+            outputShape[i] = size[i];
+          }
+        }
+      }
+    }
+  } else {
+    outputShape = convertToMlirShape(size);
+  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e93..d2314698afa925 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -532,6 +532,48 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_slice_size_minus_one
+func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
+  // CHECK: tosa.slice %arg0 {size = array<i64: -1, -1, -1, -1>, start = array<i64: 0, 1, -1, 8>} : (tensor<?x8x8x8xi32>) -> tensor<?x7x?x?xi32>
+  // this checks following
+  //  dim 0: size=-1, input dim=? => inferred output dim is ?
+  //  dim 1: size=-1 => inferred output dim is input_dim - start
+  //  dim 2: size=-1, start=-1 => inferred output dim is ?
+  //  dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound
+  %2= tosa.slice %arg0 { start = array<i64: 0, 1, -1, 8>, size = array<i64: -1, -1, -1, -1> } : (tensor<?x8x8x8xi32>) -> tensor<?x?x?x?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_size_out_of_bound
+func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+  // CHECK: tosa.slice %arg0 {size = array<i64: 0, -2, 9, 4>, start = array<i64: 0, 0, 0, 0>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+  // this checks following
+  //  dim 0: size=0 => inferred output dim is ?
+  //  dim 1: size=-2 => inferred output dim is ?
+  //  dim 3: start+size out of bound because size too big: inferred output dim is ?
+  //  dim 4: size=4, input dim=? => inferred output dim is 4
+  %2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_start_out_of_bound
+func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+  // CHECK: tosa.slice %arg0 {size = array<i64: 1, 1, 3, 4>, start = array<i64: -1, 8, 6, 8000000>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+  // this checks following
+  //  dim 0: start=-1 => inferred output dim is ?
+  //  dim 1: start=8 => inferred output dim is ?
+  //  dim 2: start+size out of bound: inferred output dim is ?
+  //  dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4
+  %2= tosa.slice %arg0 { start = array<i64: -1, 8, 6, 8000000>, size = array<i64: 1, 1, 3, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_slice_dynamic
 func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
   // CHECK: tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>

@GeorgeARM
Copy link
Contributor

@Tai78641 can we rebase?

@Tai78641 Tai78641 force-pushed the pr_fix_slice_infer_shape2 branch from db2a192 to 3c16454 Compare January 17, 2025 20:24
@Tai78641
Copy link
Contributor Author

@Tai78641 can we rebase?

done

This fixes the infer output shape of TOSA slice op for start/size values
that are out-of-bound or -1

added tests to check:
  - size = -1
  - size is out of bound
  - start is out of bound

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I8b59502a93cb332fe5c9a9f87970b83742538126
@Tai78641 Tai78641 force-pushed the pr_fix_slice_infer_shape2 branch from 3c16454 to cd57590 Compare January 17, 2025 22:22
@GeorgeARM GeorgeARM merged commit 7986e0c into llvm:main Jan 22, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants