diff --git a/paddle/fluid/operators/roi_align_op.h b/paddle/fluid/operators/roi_align_op.h index 8693f3c8f1172c..8e9e3982f2a02f 100644 --- a/paddle/fluid/operators/roi_align_op.h +++ b/paddle/fluid/operators/roi_align_op.h @@ -55,45 +55,35 @@ struct offsets_and_ratios { template std::vector> get_indexes_and_ratios( - std::size_t width // width - , - std::size_t height // , height - , - const T scaled_w // , roi_width - , - const T scaled_h // , roi_height - , - const T scaled_x // , roi_xmin - , - const T scaled_y // , roi_ymin - , - std::size_t mpx // , pooled_width - , - std::size_t mix // , roi_bin_grid_w - , - std::size_t mpy // , pooled_height - , - std::size_t miy // , roi_bin_grid_h - ) { - const auto ind_num = mpx * mix * mpy * miy; + std::size_t width, + std::size_t height, + const T roi_width, + const T roi_height, + const T roi_xmin, + const T roi_ymin, + std::size_t pooled_width, + std::size_t roi_bin_grid_w, + std::size_t pooled_height, + std::size_t roi_bin_grid_h) { + const auto ind_num = pooled_width * roi_bin_grid_w * pooled_height * roi_bin_grid_h; std::vector> interpolation_cords; interpolation_cords.reserve(ind_num); - const auto bin_w = scaled_w / mpx; - const auto bin_h = scaled_h / mpy; + const auto bin_w = roi_width / pooled_width; + const auto bin_h = roi_height / pooled_height; - for (std::size_t py = 0; py < mpy; py++) { - for (std::size_t px = 0; px < mpx; px++) { - for (std::size_t iy = 0; iy < miy; iy++) { + for (std::size_t py = 0; py < pooled_height; py++) { + for (std::size_t px = 0; px < pooled_width; px++) { + for (std::size_t iy = 0; iy < roi_bin_grid_h; iy++) { // calculate x of sample points - auto y = scaled_y + - bin_h * (py + static_cast(iy + .5f) / static_cast(miy)); - for (std::size_t ix = 0; ix < mix; ix++) { + auto y = roi_ymin + + bin_h * (py + static_cast(iy + .5f) / static_cast(roi_bin_grid_h)); + for (std::size_t ix = 0; ix < roi_bin_grid_w; ix++) { // calculate x of sample points auto x = - scaled_x + - bin_w * (px + static_cast(ix + .5f) / static_cast(mix)); + roi_xmin + + bin_w * (px + static_cast(ix + .5f) / static_cast(roi_bin_grid_w)); // deal with elements out of map if (y < -1.0 || y > height || x < -1.0 || x > width) {