Skip to content

Commit

Permalink
Fix crashes on visualization (apache#14425)
Browse files Browse the repository at this point in the history
* Check for kernel in Pooling

* Fix Leakyrelu visualization

* Address review comments

* Change all occurences to string format

* Fix lint error
  • Loading branch information
vandanavk authored and haohuw committed Jun 23, 2019
1 parent 764dbb1 commit 7557323
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions python/mxnet/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def print_layer_summary(node, out_shape):
print('=' * line_length)
else:
print('_' * line_length)
print('Total params: %s' % total_params)
print("Total params: {params}".format(params=total_params))
print('_' * line_length)

def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None, node_attrs={},
Expand Down Expand Up @@ -337,24 +337,33 @@ def looks_like_weight(name):
label = node["name"]
attr["fillcolor"] = cm[0]
elif op == "Convolution":
label = r"Convolution\n%s/%s, %s" % ("x".join(_str2tuple(node["attrs"]["kernel"])),
"x".join(_str2tuple(node["attrs"]["stride"]))
if "stride" in node["attrs"] else "1",
node["attrs"]["num_filter"])
label = "Convolution\n{kernel}/{stride}, {filter}".format(
kernel="x".join(_str2tuple(node["attrs"]["kernel"])),
stride="x".join(_str2tuple(node["attrs"]["stride"]))
if "stride" in node["attrs"] else "1",
filter=node["attrs"]["num_filter"]
)
attr["fillcolor"] = cm[1]
elif op == "FullyConnected":
label = r"FullyConnected\n%s" % node["attrs"]["num_hidden"]
label = "FullyConnected\n{hidden}".format(hidden=node["attrs"]["num_hidden"])
attr["fillcolor"] = cm[1]
elif op == "BatchNorm":
attr["fillcolor"] = cm[3]
elif op in ('Activation', 'LeakyReLU'):
label = r"%s\n%s" % (op, node["attrs"]["act_type"])
elif op == 'Activation':
act_type = node["attrs"]["act_type"]
label = 'Activation\n{activation}'.format(activation=act_type)
attr["fillcolor"] = cm[2]
elif op == 'LeakyReLU':
attrs = node.get("attrs")
act_type = attrs.get("act_type", "Leaky") if attrs else "Leaky"
label = 'LeakyReLU\n{activation}'.format(activation=act_type)
attr["fillcolor"] = cm[2]
elif op == "Pooling":
label = r"Pooling\n%s, %s/%s" % (node["attrs"]["pool_type"],
"x".join(_str2tuple(node["attrs"]["kernel"])),
"x".join(_str2tuple(node["attrs"]["stride"]))
if "stride" in node["attrs"] else "1")
label = "Pooling\n{pooltype}, {kernel}/{stride}".format(pooltype=node["attrs"]["pool_type"],
kernel="x".join(_str2tuple(node["attrs"]["kernel"]))
if "kernel" in node["attrs"] else "[]",
stride="x".join(_str2tuple(node["attrs"]["stride"]))
if "stride" in node["attrs"] else "1")
attr["fillcolor"] = cm[4]
elif op in ("Concat", "Flatten", "Reshape"):
attr["fillcolor"] = cm[5]
Expand Down

0 comments on commit 7557323

Please sign in to comment.