Skip to content

Commit

Permalink
modify squeeze handle axes.size=0
Browse files Browse the repository at this point in the history
  • Loading branch information
shentanyue committed Jul 5, 2022
1 parent 2a06a6e commit 893de00
Showing 1 changed file with 33 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -718,62 +718,55 @@ void NCHW2NHWCDataLayoutConverter::ConvertSqueeze(core::Operation* operation) {
auto axes_operand = input_operands[1];
// Recalculate the perm according to the dimorder vector of the input operand
auto input_permutation = GetPermutation(input_operand);
std::vector<int32_t> origin_data_layout =
std::vector<int32_t> identity_permutation =
IdentityPermutation(input_dimensions_count);
std::vector<int32_t> axes;
if (axes_operand && (axes_operand->length / sizeof(int32_t)) > 0) {
auto axes_count = axes_operand->length / sizeof(int32_t);
auto axes_data = reinterpret_cast<int32_t*>(axes_operand->buffer);
axes = std::vector<int32_t>(axes_data, axes_data + axes_count);
// Recalculate the axes according to the dimorder vector of the input
// operand
for (int32_t i = 0; i < axes_count; i++) {
if (axes_data[i] < 0) {
axes_data[i] += input_dimensions_count;
}
TransposeAxis(axes_data[i], input_permutation);
}
} else {
for (int32_t i = 0; i < input_dimensions_count; i++) {
if (input_operand->type.dimensions.data[i] == 1) {
axes.push_back(i);
// Delete the dimension corresponding to the axis of the
// identity_permutation
for (auto it = identity_permutation.begin();
it != identity_permutation.end();) {
if (*it == axes_data[i]) {
it = identity_permutation.erase(it);
} else {
++it;
}
}
}
}
for (int32_t i = 0; i < axes.size(); i++) {
if (axes[i] < 0) {
axes[i] += input_dimensions_count;
}
// Delete the dimension corresponding to the axis of the origin_data_layout
for (auto it = origin_data_layout.begin();
it != origin_data_layout.end();) {
if (*it == axes[i]) {
it = origin_data_layout.erase(it);
} else {
++it;
// Delete the dimension corresponding to the axis of the input_permutation
TransposeAxis(axes_data[i], input_permutation);
for (auto it = input_permutation.begin();
it != input_permutation.end();) {
if (*it == axes_data[i]) {
it = input_permutation.erase(it);
} else {
++it;
}
}
}
TransposeAxis(axes[i], input_permutation);
// Delete the dimension corresponding to the axis of the input_permutation
for (auto it = input_permutation.begin(); it != input_permutation.end();) {
if (*it == axes[i]) {
it = input_permutation.erase(it);
} else {
++it;
}
// Calculate the distance between current data layout and origin data layout
std::vector<int32_t> output_permutation;
for (auto identity_data : identity_permutation) {
int32_t index = std::distance(input_permutation.begin(),
std::find(input_permutation.begin(),
input_permutation.end(),
identity_data));
output_permutation.push_back(index);
}
TransposeOperand(output_operand, output_permutation);
SetPermutation(output_operand, output_permutation);
} else {
// Skip NCHW2NHWC conversion
SetPermutation(output_operand,
IdentityPermutation(output_dimensions_count));
}
// Calculate the distance between current data layout and hchw data layout
std::vector<int32_t> output_permutation;
for (auto nchw_data : origin_data_layout) {
int32_t index = std::distance(
input_permutation.begin(),
std::find(
input_permutation.begin(), input_permutation.end(), nchw_data));
output_permutation.push_back(index);
}
TransposeOperand(output_operand, output_permutation);
SetPermutation(output_operand, output_permutation);
}

void NCHW2NHWCDataLayoutConverter::ConvertSplit(core::Operation* operation) {
Expand Down

0 comments on commit 893de00

Please sign in to comment.