From ed7dc35bc7be26f24226c8fd8b00f91dbb1d3d74 Mon Sep 17 00:00:00 2001
From: xiaoxiaohehe001
Date: Sat, 29 Jan 2022 11:30:41 +0800
Subject: [PATCH 1/2] fix

---
 .../metal_kernel/texture/MaxKernel.metal      | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/lite/backends/metal/metal_kernel/texture/MaxKernel.metal b/lite/backends/metal/metal_kernel/texture/MaxKernel.metal
index 0fe0813ccdf..c107df440e1 100644
--- a/lite/backends/metal/metal_kernel/texture/MaxKernel.metal
+++ b/lite/backends/metal/metal_kernel/texture/MaxKernel.metal
@@ -77,3 +77,34 @@ kernel void arg_max_c(texture2d_array inTexture[[texture(0)
         outTexture.write(ftype4(index_r, index_g, index_b, index_a), gid.xy, gid.z);
     }
 }
+
+// Arg-max along the texture height axis.
+// For each (gid.x, gid.z) column, every row of inTexture is scanned and the
+// row index of the largest value is tracked independently per lane; ties keep
+// the lowest row index. The four indices are written as ftype4 to outTexture
+// at gid.xy in slice gid.z. Nothing is written unless param.orank == 4.
+kernel void arg_max_h(texture2d_array inTexture[[texture(0)]],
+    texture2d_array outTexture[[texture(1)]],
+    constant ArgParam& param[[buffer(0)]],
+    uint3 gid[[thread_position_in_grid]]) {
+    if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() ||
+        gid.z >= outTexture.get_array_size())
+        return;
+
+    // dimensions = 4, CPU is NCHW, GPU is NHWC
+    if (param.orank == 4) {
+        uint iAL = inTexture.get_height();
+        // Row 0 seeds the running maximum, so the scan below starts at row 1.
+        ftype4 guard_value = inTexture.read(uint2(gid.x, 0), gid.z);
+        int4 guard_index = int4(0);
+        for (uint i = 1; i < iAL; i++) {
+            ftype4 in = inTexture.read(uint2(gid.x, i), gid.z);
+            int4 idx = int4(i);
+            // Lanes where the running maximum still holds keep value and index.
+            bool4 flag = bool4(guard_value >= in);
+            guard_value = select(in, guard_value, flag);
+            guard_index = select(idx, guard_index, flag);
+        }
+        outTexture.write(ftype4(guard_index), gid.xy, gid.z);
+    }
+}
\ No newline at end of file
From a23d90c84272927ba100972e59fdc8f17e34869f Mon Sep 17 00:00:00 2001
From: xiaoxiaohehe001
Date: Tue, 22 Feb 2022 11:19:37 +0800
Subject: [PATCH 2/2] fix_conv2d

---
 .../metal/image_op/conv2d_image_compute.mm     | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff 
--git a/lite/kernels/metal/image_op/conv2d_image_compute.mm b/lite/kernels/metal/image_op/conv2d_image_compute.mm index 2cedb094a9e..6bc41e71331 100644 --- a/lite/kernels/metal/image_op/conv2d_image_compute.mm +++ b/lite/kernels/metal/image_op/conv2d_image_compute.mm @@ -110,19 +110,22 @@ if (metal_context_->use_mps()) { int input_c = static_cast(input_buffer_->tensor_dim_[1]); int output_c = static_cast(output_buffer_->tensor_dim_[1]); - // intput & output C channel must >=3 + // input channel must >=3 // attention: should be >=4, texture data layout is RGBA - if (input_c >= 3 && output_c >= 3) { - should_use_mps = true; + if (is_depthwise_) { + if (input_c >= 3 && output_c >= 3) { + should_use_mps = true; + } + } else { + if (input_c >= 3) { + should_use_mps = true; + } } } } if (IsWinoGrad(function_name_) || IsQuadruple(function_name_)) { should_use_mps = false; } - if (!is_depthwise_ && param.groups > 1) { - should_use_mps = false; - } if (param.bias) { if (!canMPSAddByChannel()) { should_use_mps = false; @@ -472,7 +475,7 @@ auto filter_h = static_cast(param.filter->dims()[2]); auto filter_w = static_cast(param.filter->dims()[3]); auto input_c = static_cast(input_buffer_->tensor_dim_[1]); - auto output_c = static_cast(output_buffer_->tensor_dim_[1]); + auto output_c = fmax(4, static_cast(output_buffer_->tensor_dim_[1])); MPSCNNConvolutionDescriptor* description = nil; if (is_depthwise_) { description = [MPSCNNDepthWiseConvolutionDescriptor @@ -491,6 +494,7 @@ description.strideInPixelsY = param.strides[1]; description.dilationRateX = (*param.dilations)[0]; description.dilationRateY = (*param.dilations)[1]; + if (!is_depthwise_) description.groups = param.groups; // active function switch (param.activation_param.active_type) { case lite_api::ActivationType::kRelu: {