From 4542870321c8c54277bc364e5f98731e8eeb2517 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Sun, 17 Mar 2019 19:49:57 -0700 Subject: [PATCH] Fix crashes on visualization (#14425) * Check for kernel in Pooling * Fix Leakyrelu visualization * Address review comments * Change all occurences to string format * Fix lint error --- python/mxnet/visualization.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index dd3a1df345d3..4101f749a583 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -205,7 +205,7 @@ def print_layer_summary(node, out_shape): print('=' * line_length) else: print('_' * line_length) - print('Total params: %s' % total_params) + print("Total params: {params}".format(params=total_params)) print('_' * line_length) def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None, node_attrs={}, @@ -337,24 +337,33 @@ def looks_like_weight(name): label = node["name"] attr["fillcolor"] = cm[0] elif op == "Convolution": - label = r"Convolution\n%s/%s, %s" % ("x".join(_str2tuple(node["attrs"]["kernel"])), - "x".join(_str2tuple(node["attrs"]["stride"])) - if "stride" in node["attrs"] else "1", - node["attrs"]["num_filter"]) + label = "Convolution\n{kernel}/{stride}, {filter}".format( + kernel="x".join(_str2tuple(node["attrs"]["kernel"])), + stride="x".join(_str2tuple(node["attrs"]["stride"])) + if "stride" in node["attrs"] else "1", + filter=node["attrs"]["num_filter"] + ) attr["fillcolor"] = cm[1] elif op == "FullyConnected": - label = r"FullyConnected\n%s" % node["attrs"]["num_hidden"] + label = "FullyConnected\n{hidden}".format(hidden=node["attrs"]["num_hidden"]) attr["fillcolor"] = cm[1] elif op == "BatchNorm": attr["fillcolor"] = cm[3] - elif op in ('Activation', 'LeakyReLU'): - label = r"%s\n%s" % (op, node["attrs"]["act_type"]) + elif op == 'Activation': + act_type = node["attrs"]["act_type"] + label = 'Activation\n{activation}'.format(activation=act_type) + attr["fillcolor"] = cm[2] + elif op == 'LeakyReLU': + attrs = node.get("attrs") + act_type = attrs.get("act_type", "Leaky") if attrs else "Leaky" + label = 'LeakyReLU\n{activation}'.format(activation=act_type) attr["fillcolor"] = cm[2] elif op == "Pooling": - label = r"Pooling\n%s, %s/%s" % (node["attrs"]["pool_type"], - "x".join(_str2tuple(node["attrs"]["kernel"])), - "x".join(_str2tuple(node["attrs"]["stride"])) - if "stride" in node["attrs"] else "1") + label = "Pooling\n{pooltype}, {kernel}/{stride}".format(pooltype=node["attrs"]["pool_type"], + kernel="x".join(_str2tuple(node["attrs"]["kernel"])) + if "kernel" in node["attrs"] else "[]", + stride="x".join(_str2tuple(node["attrs"]["stride"])) + if "stride" in node["attrs"] else "1") attr["fillcolor"] = cm[4] elif op in ("Concat", "Flatten", "Reshape"): attr["fillcolor"] = cm[5]