From 2cd70b61c1b9268a1603632a835c07cf38b92be1 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Tue, 26 Oct 2021 20:50:34 +0800 Subject: [PATCH 1/2] fix_split --- .../metal/metal_kernel/texture/Common.metal | 2 + .../metal_kernel/texture/Split.inc.metal | 257 ------------------ .../metal/metal_kernel/texture/Split.metal | 250 ++++++++++++----- lite/kernels/metal/image_op/metal_params.h | 2 + .../metal/image_op/split_image_compute.mm | 24 +- 5 files changed, 206 insertions(+), 329 deletions(-) delete mode 100644 lite/backends/metal/metal_kernel/texture/Split.inc.metal diff --git a/lite/backends/metal/metal_kernel/texture/Common.metal b/lite/backends/metal/metal_kernel/texture/Common.metal index 8e52e1af0dd..733dc10efab 100644 --- a/lite/backends/metal/metal_kernel/texture/Common.metal +++ b/lite/backends/metal/metal_kernel/texture/Common.metal @@ -199,6 +199,8 @@ struct SplitParam { int32_t idim[4]; int32_t axis; int32_t offset; + int32_t num; + int32_t v_; int32_t trans[4]; int32_t vdim[4]; }; diff --git a/lite/backends/metal/metal_kernel/texture/Split.inc.metal b/lite/backends/metal/metal_kernel/texture/Split.inc.metal deleted file mode 100644 index 97455c70329..00000000000 --- a/lite/backends/metal/metal_kernel/texture/Split.inc.metal +++ /dev/null @@ -1,257 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#ifdef P - -#define CONCAT2(a, b) a##b -#define CONCAT2_(a, b) a##_##b -#define CONCAT3_(a, b, c) a##_##b##_##c -#define CONCAT4_(a, b, c, d) a##_##b##_##c##_##d -#define CONCAT5_(a, b, c, d, e) a##_##b##_##c##_##d##_##e - -#define FUNC(f, r, n, v) CONCAT4_(f, r, n, v) -#define VECTOR(p, n) CONCAT2(p, n) -#define FUNC_R(f, r) CONCAT2_(f, r) - -#if V == VX -#define VV x -#elif V == VY -#define VV y -#elif V == VZ -#define VV z -#elif V == VZZ -#define VV zz -#else -#define VV normal -#endif - -#if V == VY -kernel void FUNC(split, R, N, VV)(texture2d_array input[[texture(0)]], - texture2d_array out1[[texture(1)]], - texture2d_array out2[[texture(2)]], -#if N >= 3 - texture2d_array out3[[texture(3)]], -#endif // N >= 3 -#if N >= 4 - texture2d_array out4[[texture(4)]], -#endif // N >= 4 - constant SplitParam& sp[[buffer(0)]], - uint3 gid[[thread_position_in_grid]]) { - - VECTOR(P, 4) r = input.read(gid.xy, gid.z); - int y = gid.y - sp.offset; - if (y < sp.vdim[0]) { - out1.write(r, gid.xy, gid.z); - return; - } - y -= sp.vdim[0]; - if (y < sp.vdim[1]) { - out2.write(r, uint2(gid.x, y), gid.z); - return; - } -#if N >= 3 - y -= sp.vdim[1]; - if (y < sp.vdim[2]) { - out3.write(r, uint2(gid.x, y), gid.z); - return; - } -#endif // N >= 3 -#if N >= 4 - y -= sp.vdim[2]; - if (y < sp.vdim[3]) { - out4.write(r, uint2(gid.x, y), gid.z); - return; - } -#endif // N >= 4 -} -#endif // V == VY - -#if V == VX -kernel void FUNC(split, R, N, VV)(texture2d_array input[[texture(0)]], - texture2d_array out1[[texture(1)]], - texture2d_array out2[[texture(2)]], -#if N >= 3 - texture2d_array out3[[texture(3)]], -#endif // N >= 3 -#if N >= 4 - texture2d_array out4[[texture(4)]], -#endif // N >= 4 - constant SplitParam& sp[[buffer(0)]], - uint3 gid[[thread_position_in_grid]]) { - VECTOR(P, 4) r = input.read(gid.xy, gid.z); - int x = gid.x; - if (x < sp.vdim[0]) { - out1.write(r, gid.xy, gid.z); - return; - } - x -= sp.vdim[0]; - if (x < sp.vdim[1]) { - out2.write(r, uint2(x, gid.y), gid.z); - return; - } -#if N >= 3 - x -= sp.vdim[1]; - if (x < sp.vdim[2]) { - out3.write(r, uint2(x, gid.y), gid.z); - return; - } -#endif // N >= 3 -#if N >= 4 - x -= sp.vdim[2]; - if (x < sp.vdim[3]) { - out4.write(r, uint2(x, gid.y), gid.z); - return; - } -#endif // N >= 4 -} -#endif // V == VX - -#if V == VZ -kernel void FUNC(split, R, N, VV)(texture2d_array input[[texture(0)]], - texture2d_array out1[[texture(1)]], - texture2d_array out2[[texture(2)]], -#if N >= 3 - texture2d_array out3[[texture(3)]], -#endif // N >= 3 -#if N >= 4 - texture2d_array out4[[texture(4)]], -#endif // N >= 4 - constant SplitParam& sp[[buffer(0)]], - uint3 gid[[thread_position_in_grid]]) { - VECTOR(P, 4) r = input.read(gid.xy, gid.z); - int z = gid.z; - if (z < sp.vdim[0]) { - out1.write(r, gid.xy, z); - return; - } - z -= sp.vdim[0]; - if (z < sp.vdim[1]) { - out2.write(r, gid.xy, z); - return; - } -#if N >= 3 - z -= sp.vdim[1]; - if (z < sp.vdim[2]) { - out3.write(r, gid.xy, z); - return; - } -#endif // N >= 3 -#if N >= 4 - z -= sp.vdim[2]; - if (z < sp.vdim[3]) { - out4.write(r, gid.xy, z); - return; - } -#endif // N >= 4 -} -#endif // V == VZ - -#if V == VZZ -kernel void FUNC(split, R, N, VV)(texture2d_array input[[texture(0)]], - texture2d_array out1[[texture(1)]], - texture2d_array out2[[texture(2)]], -#if N >= 3 - texture2d_array out3[[texture(3)]], -#endif // N >= 3 -#if N >= 4 - texture2d_array out4[[texture(4)]], -#endif // N >= 4 - constant SplitParam& sp[[buffer(0)]], - uint3 gid[[thread_position_in_grid]]) { - int index = 0; - int z = gid.z; - if (z - (sp.vdim[0] + 3) / 4 < 0) { // output1 - VECTOR(P, 4) r = input.read(gid.xy, z); - int len = (gid.z + 1) * 4 - sp.vdim[0]; - for (int i = 0; i < len; i++) { - r[3 - i] = 0; - } - out1.write(r, gid.xy, gid.z); - return; - } - z -= (sp.vdim[0] + 3) / 4; - if (z - (sp.vdim[1] + 3) / 4 < 0) { - int z_origin = z * 4 + sp.vdim[0]; - int z_end = min(z_origin + 3, sp.vdim[0] + sp.vdim[1] - 1); - VECTOR(P, 4) r; - r[0] = 0; - r[1] = 0; - r[2] = 0; - r[3] = 0; - VECTOR(P, 4) r1 = input.read(gid.xy, z_origin / 4); - int start = z_origin % 4; - for (int i = start; i < 4 && i - start <= z_end - z_origin; i++) { - r[i - start] = r1[i]; - } - r1 = input.read(gid.xy, z_end / 4); - int end = z_end % 4; - for (int i = end; i >= 0 && end - i <= z_end - z_origin; i--) { - r[z_end - z_origin + i - end] = r1[i]; - } - out2.write(r, gid.xy, z); - return; - } -#if N >= 3 - z -= (sp.vdim[1] + 3) / 4; - if (z - (sp.vdim[2] + 3) / 4 < 0) { - int z_origin = z * 4 + sp.vdim[0] + sp.vdim[1]; - int z_end = min(z_origin + 3, sp.vdim[0] + sp.vdim[1] + sp.vdim[2] - 1); - VECTOR(P, 4) r; - r[0] = 0; - r[1] = 0; - r[2] = 0; - r[3] = 0; - VECTOR(P, 4) r1 = input.read(gid.xy, z_origin / 4); - int start = z_origin % 4; - for (int i = start; i < 4 && i - start <= z_end - z_origin; i++) { - r[i - start] = r1[i]; - } - r1 = input.read(gid.xy, z_end / 4); - int end = z_end % 4; - for (int i = end; i >= 0 && end - i <= z_end - z_origin; i--) { - r[z_end - z_origin + i - end] = r1[i]; - } - out3.write(r, gid.xy, z); - return; - } -#endif // N >= 3 -#if N >= 4 - z -= (sp.vdim[2] + 2) / 4; - if (z - (sp.vdim[3] + 2) / 4 < 0) { - int z_origin = z * 4 + sp.vdim[0] + sp.vdim[1] + sp.vdim[2]; - int z_end = min(z_origin + 3, sp.vdim[0] + sp.vdim[1] + sp.vdim[2] + sp.vdim[3] - 1); - VECTOR(P, 4) r; - r[0] = 0; - r[1] = 0; - r[2] = 0; - r[3] = 0; - VECTOR(P, 4) r1 = input.read(gid.xy, z_origin / 4); - int start = z_origin % 4; - for (int i = start; i < 4 && i - start <= z_end - z_origin; i++) { - r[i - start] = r1[i]; - } - r1 = input.read(gid.xy, z_end / 4); - int end = z_end % 4; - for (int i = end; i >= 0 && end - i <= z_end - z_origin; i--) { - r[z_end - z_origin + i - end] = r1[i]; - } - out4.write(r, gid.xy, z); - return; - } -#endif // N >= 4 -} -#endif // V == VZZ - -#undef VV -#endif diff --git a/lite/backends/metal/metal_kernel/texture/Split.metal b/lite/backends/metal/metal_kernel/texture/Split.metal index 66de0ff8186..53338186d52 100644 --- a/lite/backends/metal/metal_kernel/texture/Split.metal +++ b/lite/backends/metal/metal_kernel/texture/Split.metal @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2021 ftypeaddleftypeaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,72 +18,186 @@ using namespace metal; -#define VNORMAL 1 -#define VX 2 -#define VY 3 -#define VZ 4 +kernel void split(texture2d_array input[[texture(0)]], + texture2d_array out1[[texture(1)]], + texture2d_array out2[[texture(2)]], + texture2d_array out3[[texture(3)]], + texture2d_array out4[[texture(4)]], + constant SplitParam& sp[[buffer(0)]], + uint3 gid[[thread_position_in_grid]]) { + int n = sp.num; + int v_ = sp.v_; + ftype4 r = input.read(gid.xy, gid.z); + if (v_ == 1) { + int x = gid.x - sp.offset; + ; + if (x < sp.vdim[0]) { + out1.write(r, gid.xy, gid.z); + return; + } + x -= sp.vdim[0]; + if (x < sp.vdim[1]) { + out2.write(r, uint2(x, gid.y), gid.z); + return; + } + if (n >= 3) { + x -= sp.vdim[1]; + if (x < sp.vdim[2]) { + out3.write(r, uint2(x, gid.y), gid.z); + return; + } + } + if (n >= 4) { + x -= sp.vdim[2]; + if (x < sp.vdim[3]) { + out4.write(r, uint2(x, gid.y), gid.z); + return; + } + } + } else if (v_ == 2) { + int y = gid.y - sp.offset; + if (y < sp.vdim[0]) { + out1.write(r, gid.xy, gid.z); + return; + } + y -= sp.vdim[0]; + if (y < sp.vdim[1]) { + out2.write(r, uint2(gid.x, y), gid.z); + return; + } + if (n >= 3) { + y -= sp.vdim[1]; + if (y < sp.vdim[2]) { + out3.write(r, uint2(gid.x, y), gid.z); + return; + } + } + if (n >= 4) { + y -= sp.vdim[2]; + if (y < sp.vdim[3]) { + out4.write(r, uint2(gid.x, y), gid.z); + return; + } + } + } else if (v_ == 3) { + int z = gid.z; + if (z < sp.vdim[0]) { + out1.write(r, gid.xy, z); + return; + } + z -= sp.vdim[0]; + if (z < sp.vdim[1]) { + out2.write(r, gid.xy, z); + return; + } + if (n >= 3) { + z -= sp.vdim[1]; + if (z < sp.vdim[2]) { + out3.write(r, gid.xy, z); + return; + } + } + if (n >= 4) { + z -= sp.vdim[2]; + if (z < sp.vdim[3]) { + out4.write(r, gid.xy, z); + return; + } + } + } +} -// only support split_{2, 3, 4}_{2, 3, 4}_y_{float, half} -// only support split_{3, 4}_{2, 3, 4}_x_{float, half} -//// ssd-ar: (R=3, N=2, V=y) -#define V VY -#define R 3 -#define N 2 -#define P ftype -#include "Split.inc.metal" -#undef P -#undef N -#undef R -#undef V - -//// ssd-ar: (R=2, N=2, V=y) -#define V VY -#define R 2 -#define N 2 -#define P ftype -#include "Split.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VZZ -#define R 4 -#define N 2 -#define P ftype -#include "Split.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VZ -#define R 4 -#define N 2 -#define P ftype -#include "Split.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VZ -#define R 4 -#define N 3 -#define P ftype -#include "Split.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VZ -#define R 4 -#define N 4 -#define P ftype -#include "Split.inc.metal" -#undef P -#undef N -#undef R -#undef V +kernel void split_zz(texture2d_array input[[texture(0)]], + texture2d_array out1[[texture(1)]], + texture2d_array out2[[texture(2)]], + texture2d_array out3[[texture(3)]], + texture2d_array out4[[texture(4)]], + constant SplitParam& sp[[buffer(0)]], + uint3 gid[[thread_position_in_grid]]) { + int n = sp.num; + int v_ = sp.v_; + if (v_ == 4) { + int z = gid.z; + if (z - (sp.vdim[0] + 3) / 4 < 0) { // output1 + ftype4 r = input.read(gid.xy, z); + int len = (gid.z + 1) * 4 - sp.vdim[0]; + for (int i = 0; i < len; i++) { + r[3 - i] = 0; + } + out1.write(r, gid.xy, gid.z); + return; + } + z -= (sp.vdim[0] + 3) / 4; + if (z - (sp.vdim[1] + 3) / 4 < 0) { + int z_origin = z * 4 + sp.vdim[0]; + int z_end = min(z_origin + 3, sp.vdim[0] + sp.vdim[1] - 1); + ftype4 r; + r[0] = 0; + r[1] = 0; + r[2] = 0; + r[3] = 0; + ftype4 r1 = input.read(gid.xy, z_origin / 4); + int start = z_origin % 4; + for (int i = start; i < 4 && i - start <= z_end - z_origin; i++) { + r[i - start] = r1[i]; + } + r1 = input.read(gid.xy, z_end / 4); + int end = z_end % 4; + for (int i = end; i >= 0 && end - i <= z_end - z_origin; i--) { + r[z_end - z_origin + i - end] = r1[i]; + } + out2.write(r, gid.xy, z); + return; + } + if (n >= 3) { + z -= (sp.vdim[1] + 3) / 4; + if (z - (sp.vdim[2] + 3) / 4 < 0) { + int z_origin = z * 4 + sp.vdim[0] + sp.vdim[1]; + int z_end = min(z_origin + 3, sp.vdim[0] + sp.vdim[1] + sp.vdim[2] - 1); + ftype4 r; + r[0] = 0; + r[1] = 0; + r[2] = 0; + r[3] = 0; + ftype4 r1 = input.read(gid.xy, z_origin / 4); + int start = z_origin % 4; + for (int i = start; i < 4 && i - start <= z_end - z_origin; i++) { + r[i - start] = r1[i]; + } + r1 = input.read(gid.xy, z_end / 4); + int end = z_end % 4; + for (int i = end; i >= 0 && end - i <= z_end - z_origin; i--) { + r[z_end - z_origin + i - end] = r1[i]; + } + out3.write(r, gid.xy, z); + return; + } + } + if (n >= 4) { + z -= (sp.vdim[2] + 2) / 4; + if (z - (sp.vdim[3] + 2) / 4 < 0) { + int z_origin = z * 4 + sp.vdim[0] + sp.vdim[1] + sp.vdim[2]; + int z_end = + min(z_origin + 3, sp.vdim[0] + sp.vdim[1] + sp.vdim[2] + sp.vdim[3] - 1); + ftype4 r; + r[0] = 0; + r[1] = 0; + r[2] = 0; + r[3] = 0; + ftype4 r1 = input.read(gid.xy, z_origin / 4); + int start = z_origin % 4; + for (int i = start; i < 4 && i - start <= z_end - z_origin; i++) { + r[i - start] = r1[i]; + } + r1 = input.read(gid.xy, z_end / 4); + int end = z_end % 4; + for (int i = end; i >= 0 && end - i <= z_end - z_origin; i--) { + r[z_end - z_origin + i - end] = r1[i]; + } + out4.write(r, gid.xy, z); + return; + } + } + } +} diff --git a/lite/kernels/metal/image_op/metal_params.h b/lite/kernels/metal/image_op/metal_params.h index e72918d2692..0ae1925b07e 100644 --- a/lite/kernels/metal/image_op/metal_params.h +++ b/lite/kernels/metal/image_op/metal_params.h @@ -216,6 +216,8 @@ struct SplitMetalParam { int idim[4]; int axis; int offset; + int num; + int v_; int trans[4]; int vdim[4]; }; diff --git a/lite/kernels/metal/image_op/split_image_compute.mm b/lite/kernels/metal/image_op/split_image_compute.mm index 54ba93e63d8..90cb41e10e9 100644 --- a/lite/kernels/metal/image_op/split_image_compute.mm +++ b/lite/kernels/metal/image_op/split_image_compute.mm @@ -73,7 +73,8 @@ const auto& param = this->Param(); auto outputs = param.output; - size_t num = outputs.size(); + int num = outputs.size(); + int vaxis = 0; int irank = (int)input_buffer_->tensor_dim_.size(); // intput dims: CPU NCHW @@ -146,18 +147,33 @@ throw std::logic_error("ERROR: unsupported split type"); } + if (v_ == "normal") + vaxis = 0; + else if (v_ == "x") + vaxis = 1; + else if (v_ == "y") + vaxis = 2; + else if (v_ == "z") + vaxis = 3; + else if (v_ == "zz") + vaxis = 4; + SplitMetalParam metal_param = {{idm[0], idm[1], idm[2], idm[3]}, static_cast(axis), 0, + num, + vaxis, {trans[0], trans[1], trans[2], trans[3]}, {(int)vdim[0], (int)vdim[1], (int)vdim[2], (int)vdim[3]}}; params_buffer_ = std::make_shared(metal_context_, sizeof(metal_param), &metal_param); - std::string function_name = - "split_" + std::to_string(irank) + "_" + std::to_string(num) + "_" + v_; - function_name_ = function_name; + if (v_ == "zz") + function_name_ = "split_zz"; + else + function_name_ = "split"; + // pipline auto backend = (__bridge MetalContextImp*)metal_context_->backend(); pipline_ = [backend pipline:function_name_]; From 63704f671813f6eff215dc582e8d983fa06011d5 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Mon, 1 Nov 2021 19:53:01 +0800 Subject: [PATCH 2/2] Update Split.metal --- .../metal/metal_kernel/texture/Split.metal | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/lite/backends/metal/metal_kernel/texture/Split.metal b/lite/backends/metal/metal_kernel/texture/Split.metal index 53338186d52..1b928274a01 100644 --- a/lite/backends/metal/metal_kernel/texture/Split.metal +++ b/lite/backends/metal/metal_kernel/texture/Split.metal @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 ftypeaddleftypeaddle Authors. All Rights Reserved. +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -132,11 +132,7 @@ kernel void split_zz(texture2d_array input[[texture(0)]], if (z - (sp.vdim[1] + 3) / 4 < 0) { int z_origin = z * 4 + sp.vdim[0]; int z_end = min(z_origin + 3, sp.vdim[0] + sp.vdim[1] - 1); - ftype4 r; - r[0] = 0; - r[1] = 0; - r[2] = 0; - r[3] = 0; + ftype4 r = 0; ftype4 r1 = input.read(gid.xy, z_origin / 4); int start = z_origin % 4; for (int i = start; i < 4 && i - start <= z_end - z_origin; i++) { @@ -155,11 +151,7 @@ kernel void split_zz(texture2d_array input[[texture(0)]], if (z - (sp.vdim[2] + 3) / 4 < 0) { int z_origin = z * 4 + sp.vdim[0] + sp.vdim[1]; int z_end = min(z_origin + 3, sp.vdim[0] + sp.vdim[1] + sp.vdim[2] - 1); - ftype4 r; - r[0] = 0; - r[1] = 0; - r[2] = 0; - r[3] = 0; + ftype4 r = 0; ftype4 r1 = input.read(gid.xy, z_origin / 4); int start = z_origin % 4; for (int i = start; i < 4 && i - start <= z_end - z_origin; i++) { @@ -180,11 +172,7 @@ kernel void split_zz(texture2d_array input[[texture(0)]], int z_origin = z * 4 + sp.vdim[0] + sp.vdim[1] + sp.vdim[2]; int z_end = min(z_origin + 3, sp.vdim[0] + sp.vdim[1] + sp.vdim[2] + sp.vdim[3] - 1); - ftype4 r; - r[0] = 0; - r[1] = 0; - r[2] = 0; - r[3] = 0; + ftype4 r = 0; ftype4 r1 = input.read(gid.xy, z_origin / 4); int start = z_origin % 4; for (int i = start; i < 4 && i - start <= z_end - z_origin; i++) {