diff --git a/lite/kernels/xpu/__xpu__squeeze_excitation_compute.cc b/lite/kernels/xpu/__xpu__squeeze_excitation_compute.cc index 6021555f8ea..94e2c64ed77 100644 --- a/lite/kernels/xpu/__xpu__squeeze_excitation_compute.cc +++ b/lite/kernels/xpu/__xpu__squeeze_excitation_compute.cc @@ -24,6 +24,7 @@ namespace kernels { namespace xpu { void XPUSqueezeExcitationCompute::PrepareForRun() { + auto& ctx = this->ctx_->As(); auto& param = this->template Param(); auto weight_ptr = param.filter->data(); auto weight_len = param.filter->numel(); @@ -35,21 +36,24 @@ void XPUSqueezeExcitationCompute::PrepareForRun() { paddle::lite::xpu::math::FindMaxAbs(weight_ptr, weight1_len); float weight2_max = paddle::lite::xpu::math::FindMaxAbs( weight_ptr + weight1_len, weight2_len); - std::vector weight_1_max_v(4, weight1_max); + int max_ptr_size = get_max_ptr_size(ctx.GetRawContext()); + std::vector weight_1_max_v(max_ptr_size, weight1_max); - weight1_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float)); + weight1_max_guard_ = + TargetWrapperXPU::MallocScratchPad(max_ptr_size * sizeof(float)); weight1_maxptr_ = reinterpret_cast(weight1_max_guard_->addr_); XPU_CALL(xpu_memcpy(weight1_maxptr_, weight_1_max_v.data(), - 4 * sizeof(float), + max_ptr_size * sizeof(float), XPUMemcpyKind::XPU_HOST_TO_DEVICE)); - std::vector weight_2_max_v(4, weight2_max); - weight2_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float)); + std::vector weight_2_max_v(max_ptr_size, weight2_max); + weight2_max_guard_ = + TargetWrapperXPU::MallocScratchPad(max_ptr_size * sizeof(float)); weight2_maxptr_ = reinterpret_cast(weight2_max_guard_->addr_); XPU_CALL(xpu_memcpy(weight2_maxptr_, weight_2_max_v.data(), - 4 * sizeof(float), + max_ptr_size * sizeof(float), XPUMemcpyKind::XPU_HOST_TO_DEVICE)); // quant