This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
However, let us now compute the output of each FC layer in the above network (fc0_output, fc1_output, ..., fc4_output). What I observe is that if I compute each FC output individually and sum the results, I do not get the same result as running the whole network together.
constituent_fc0 = fully_connected_symbols[0]
print(constituent_fc0.get_internals().list_outputs())
mod_cons_fc0 = mx.mod.Module(symbol=constituent_fc0, data_names=['data_0'], label_names=None)
mod_cons_fc0.bind(for_training=False, data_shapes=[('data_0', data_shape)])
mod_cons_fc0.set_params(mod.get_params()[0], mod.get_params()[1])
mod_cons_fc0.forward(mx.io.DataBatch([mx.nd.ones(data_shape)]))
o1 = mod_cons_fc0.get_outputs()[0]
# ...and so on for fc1, fc2, fc3, fc4 (giving o2, o3, o4, o5)
# then sum the individual outputs
print(mx.nd.add_n(o1, o2, o3, o4, o5))
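To make the mismatch explicit rather than comparing printed arrays by eye, a small tolerance-based check can help. The following is only a sketch: `outputs_match` is a hypothetical helper (not part of MXNet), and it assumes both results have already been converted to flat Python lists, for example with `o.asnumpy().ravel().tolist()`.

```python
import math

def outputs_match(summed, fused, rel_tol=1e-5, abs_tol=1e-6):
    """Return True if two flat lists of floats agree element by element.

    summed: sum of the individually computed FC outputs
    fused:  output of the full network that ends in add_n
    """
    if len(summed) != len(fused):
        return False
    return all(
        math.isclose(s, f, rel_tol=rel_tol, abs_tol=abs_tol)
        for s, f in zip(summed, fused)
    )

# Illustrative values only, not taken from the issue.
print(outputs_match([9.0, 12.0], [9.0, 12.0000001]))
print(outputs_match([9.0, 12.0], [8.0, 10.0]))
```

A tolerance is used because small floating-point differences between the two code paths are expected; a real bug shows up as a difference far larger than that.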
Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended labels: Bug
The add_n output memory overlaps with the input memory (due to FInplaceOption), but the forward function ElementwiseSumContainsDnsImpl() first zeroes the output memory, which also zeroes the overlapping input: Kernel<set_zero, cpu>::Launch(s, out_data.Size(), out_data.dptr());
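The aliasing hazard described above can be illustrated in plain Python. This is not MXNet code: `sum_zero_first` and `sum_alias_safe` are hypothetical functions, and Python lists stand in for tensor buffers. The sketch only shows why zeroing an output buffer that shares memory with one of the inputs silently drops that input from the sum.

```python
def sum_zero_first(inputs, out):
    """Zero the output, then accumulate every input into it.

    Correct only when `out` does not share memory with any input.
    """
    for i in range(len(out)):
        out[i] = 0.0          # if out aliases an input, that input is wiped here
    for x in inputs:
        for i in range(len(out)):
            out[i] += x[i]
    return out

def sum_alias_safe(inputs, out):
    """Compute the sum into a temporary before writing to the output."""
    total = [sum(vals) for vals in zip(*inputs)]
    out[:] = total
    return out

b = [3.0, 4.0]
c = [5.0, 6.0]

a = [1.0, 2.0]
buggy = list(sum_zero_first([a, b, c], out=a))   # out shares memory with input a

a = [1.0, 2.0]
safe = list(sum_alias_safe([a, b, c], out=a))

print(buggy)  # [8.0, 10.0]: the contribution of a is lost
print(safe)   # [9.0, 12.0]
```

In the sketch the fix is to read all inputs before writing; other options would be to skip the zeroing when the output aliases an input, or to disallow in-place execution for that code path.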
Problem:
With mxnet-mkl (1.4.0)
If the number of input symbols is greater than 4 and I perform add_n after an FC layer, it produces wrong results.
i.e.,
Minimum reproducible code below:
Run below code which is full network:
Output
@ZhennanQin @pengzhao-intel - Could you please help debug this issue?
Please Note: