diff --git a/mmdnn/conversion/mxnet/mxnet_parser.py b/mmdnn/conversion/mxnet/mxnet_parser.py
index 189a21fc..e08d74a2 100644
--- a/mmdnn/conversion/mxnet/mxnet_parser.py
+++ b/mmdnn/conversion/mxnet/mxnet_parser.py
@@ -809,7 +809,7 @@ def rename_Embedding(self, source_node):
 
         self.set_output_shape(source_node, IR_node)
 
-    # IR only support elu from {'elu', 'leaky', 'prelu', 'rrelu'}
+    # IR only supports elu and prelu from {'elu', 'leaky', 'prelu', 'rrelu'}
     def rename_LeakyReLU(self, source_node):
         # judge whether meaningful
         assert "attr"
@@ -817,14 +817,19 @@ def rename_LeakyReLU(self, source_node):
         layer_attr = self._get_layer_attr(source_node)
 
         if "act_type" in layer_attr:
-            if not layer_attr["act_type"] == "elu":
+            if not layer_attr["act_type"] == "elu" and not layer_attr["act_type"] == "prelu":
                 print("Warning: Activation Type %s is not supported yet." % layer_attr["act_type"])
                 # return
 
         IR_node = self.IR_graph.node.add()
 
         # name, op
-        self._copy_and_reop(source_node, IR_node, "Elu")
+        if layer_attr['act_type'] == 'prelu':
+            self._copy_and_reop(source_node, IR_node, "PRelu")
+            # gamma: learnable slope parameter, stored by MXNet as "<name>_gamma"
+            self.set_weight(source_node.name, "gamma", self.weight_data.get(source_node.name + "_gamma").asnumpy())
+        else:  # All other cases set to 'Elu'
+            self._copy_and_reop(source_node, IR_node, "Elu")
 
         # input edge
         self.convert_inedge(source_node, IR_node)
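
Note (not part of the patch): a minimal sketch of the MXNet weight-naming convention the new prelu branch relies on. For a `LeakyReLU` symbol with `act_type='prelu'`, MXNet registers the learnable slope as an argument named `<layer_name>_gamma`, which is what `self.weight_data.get(source_node.name + "_gamma")` looks up. The layer name `act0` below is an arbitrary illustration.

```python
import mxnet as mx

data = mx.sym.Variable('data')
prelu = mx.sym.LeakyReLU(data=data, act_type='prelu', name='act0')

# The learnable gamma appears among the symbol's arguments as
# "<name>_gamma" -- the same suffix the parser appends when
# fetching the weight for the IR PRelu node.
print(prelu.list_arguments())   # ['data', 'act0_gamma']
```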