From 7986e0cad10f3bf9efbbe31110ece68af5cb8751 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Wed, 22 Jan 2025 07:29:44 -0600 Subject: [PATCH] [TOSA] bug fix infer shape for slice (#113497) This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1 added tests to check: - size = -1 - size is out of bound - start is out of bound Signed-off-by: Tai Ly --- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 36 +++++++++++++++- mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 42 +++++++++++++++++++ 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 75b7da708cbb58..de5ff61b5848e3 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -901,8 +901,40 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, SliceOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { - inferredReturnShapes.push_back( - ShapedTypeComponents(convertToMlirShape(adaptor.getSize()))); + auto start = adaptor.getStart(); + auto size = adaptor.getSize(); + + // if size[i] is -1, all remaining elements in dimension i are included + // in the slice, similar to TF. + ShapeAdaptor inputShape(adaptor.getInput1().getType()); + // initialize outputShape to all unknown + SmallVector outputShape(size.size(), ShapedType::kDynamic); + if (inputShape.hasRank()) { + for (size_t i = 0; i < size.size(); i++) { + if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 && + (ShapedType::isDynamic(inputShape.getDimSize(i)) || + start[i] < inputShape.getDimSize(i))) { + // size[i] is not 0 and not < -1, and start[i] is in valid range + if (ShapedType::isDynamic(inputShape.getDimSize(i))) { + // input shape has unknown dim[i] - only valid if size[i] > 0 + if (size[i] > 0) { + outputShape[i] = size[i]; + } + } else { + // input shape has known dim[i] + if (size[i] == -1) { + outputShape[i] = inputShape.getDimSize(i) - start[i]; + } else if (start[i] + size[i] <= inputShape.getDimSize(i)) { + // start[i] + size[i] is within bound of input shape's dim[i] + outputShape[i] = size[i]; + } + } + } + } + } else { + outputShape = convertToMlirShape(size); + } + inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 8ab7284019f965..44cc6acd7e97a0 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -532,6 +532,48 @@ func.func @test_slice(%arg0 : tensor) -> () { // ----- +// CHECK-LABEL: @test_slice_size_minus_one +func.func @test_slice_size_minus_one(%arg0 : tensor) -> () { + // CHECK: tosa.slice %arg0 {size = array, start = array} : (tensor) -> tensor + // this checks following + // dim 0: size=-1, input dim=? => inferred output dim is ? + // dim 1: size=-1 => inferred output dim is input_dim - start + // dim 2: size=-1, start=-1 => inferred output dim is ? + // dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound + %2= tosa.slice %arg0 { start = array, size = array } : (tensor) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_slice_size_out_of_bound +func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () { + // CHECK: tosa.slice %arg0 {size = array, start = array} : (tensor<8x8x8x?xi32>) -> tensor + // this checks following + // dim 0: size=0 => inferred output dim is ? + // dim 1: size=-2 => inferred output dim is ? + // dim 3: start+size out of bound because size too big: inferred output dim is ? + // dim 4: size=4, input dim=? => inferred output dim is 4 + %2= tosa.slice %arg0 { start = array, size = array } : (tensor<8x8x8x?xi32>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_slice_start_out_of_bound +func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () { + // CHECK: tosa.slice %arg0 {size = array, start = array} : (tensor<8x8x8x?xi32>) -> tensor + // this checks following + // dim 0: start=-1 => inferred output dim is ? + // dim 1: start=8 => inferred output dim is ? + // dim 2: start+size out of bound: inferred output dim is ? + // dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4 + %2= tosa.slice %arg0 { start = array, size = array } : (tensor<8x8x8x?xi32>) -> tensor + return +} + +// ----- + // CHECK-LABEL: @test_slice_dynamic func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () { // CHECK: tosa.slice %arg0 {size = array, start = array} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>