Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
backport #19393 to v1.x (#19396)
Browse files Browse the repository at this point in the history
* backport #19393 to v1.x

* other commits

* namespace fix

* added copying lib_api.cc into pip wheel for building extensions

* fixed setup.py
  • Loading branch information
samskalicky authored Oct 23, 2020
1 parent 9222fa1 commit fdc6022
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 15 deletions.
12 changes: 6 additions & 6 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,25 +912,25 @@ class Registry {

/*! \brief declare a variable with custom name */
#define MX_REGISTER_NAME_(Name) MXNet ## _CustomOp ## _
#define MX_REGISTER_DEF_(Name) CustomOp MX_REGISTER_NAME_(Name)
#define MX_REGISTER_DEF_(Name) mxnet::ext::CustomOp MX_REGISTER_NAME_(Name)

#define MX_REGISTER_PROP_NAME_(Name) MXNet ## _CustomSubProp ## _
#define MX_REGISTER_PROP_DEF_(Name) CustomPartitioner MX_REGISTER_PROP_NAME_(Name)
#define MX_REGISTER_PROP_DEF_(Name) mxnet::ext::CustomPartitioner MX_REGISTER_PROP_NAME_(Name)

#define MX_REGISTER_PASS_NAME_(Name) MXNet ## _CustomPass ## _
#define MX_REGISTER_PASS_DEF_(Name) CustomPass MX_REGISTER_PASS_NAME_(Name)
#define MX_REGISTER_PASS_DEF_(Name) mxnet::ext::CustomPass MX_REGISTER_PASS_NAME_(Name)

/*! \brief assign a var to a value */
#define REGISTER_OP(Name) MX_STR_CONCAT(MX_REGISTER_DEF_(Name), __COUNTER__) = \
Registry<CustomOp>::get()->add(MX_TOSTRING(Name))
mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->add(MX_TOSTRING(Name))

#define REGISTER_PARTITIONER(Name) \
MX_STR_CONCAT(MX_REGISTER_PROP_DEF_(Name), __COUNTER__) = \
Registry<CustomPartitioner>::get()->add(MX_TOSTRING(Name))
mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->add(MX_TOSTRING(Name))

#define REGISTER_PASS(Name) \
MX_STR_CONCAT(MX_REGISTER_PASS_DEF_(Name), __COUNTER__) = \
Registry<CustomPass>::get()->add(MX_TOSTRING(Name))
mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->add(MX_TOSTRING(Name))

/* -------------- BELOW ARE CTYPE FUNCTIONS PROTOTYPES --------------- */

Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ def _build_cache(self, *args):
'added to the parameter dicts.\n'
'Please check the backend.')

param = Parameter(name)
param = Parameter(name, dtype=param_data.dtype)
param._load_init(param_data, args[0].context)
pair = (False, param)

Expand Down
9 changes: 6 additions & 3 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,13 @@ def _reduce(self):
ctx = context.cpu()
if self._stype == 'default':
block = self.list_data()
if is_np_array():
data = sum([w.copyto(ctx) for w in block]) / len(block)
if len(block) > 1:
if is_np_array():
data = sum([w.copyto(ctx) for w in block]) / len(block)
else:
data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block)
else:
data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block)
data = self.data().copyto(ctx)
else:
# fetch all rows for 'row_sparse' param
all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
Expand Down
26 changes: 21 additions & 5 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,13 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
auto in_first = in_shape->begin();
auto in_last = in_first + in_shape->size() - extra_inputs;
mxnet::ShapeVector *sg_in_shapes = new mxnet::ShapeVector(in_first, in_last);
return mxnet::op::DefaultSubgraphOpShape(attrs, sg_in_shapes, out_shape);
bool res = mxnet::op::DefaultSubgraphOpShape(attrs, sg_in_shapes, out_shape);

// assign modified input shapes to ShapeVector
for (unsigned i = 0; i < sg_in_shapes->size(); ++i) {
SHAPE_ASSIGN_CHECK(*in_shape, i, sg_in_shapes->at(i));
}
return res;
};

// lambda function to call infer type
Expand Down Expand Up @@ -933,7 +939,12 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
auto in_last = in_first + in_type->size() - extra_inputs;
std::vector<int> *sg_in_types = new std::vector<int>(in_first, in_last);

return mxnet::op::DefaultSubgraphOpType(attrs, sg_in_types, out_type);
bool res = mxnet::op::DefaultSubgraphOpType(attrs, sg_in_types, out_type);
// copy and assign modified input types
for (size_t i = 0; i < sg_in_types->size(); i++) {
TYPE_ASSIGN_CHECK(*in_type, i, sg_in_types->at(i));
}
return res;
};

// lambda function to convert from external mutate_inputs to internal MXNet types
Expand Down Expand Up @@ -1033,8 +1044,13 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
auto in_last = in_first + in_stypes->size() - extra_inputs;
std::vector<int> *sg_in_stypes = new std::vector<int>(in_first, in_last);

return mxnet::op::DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
sg_in_stypes, out_stypes);
bool res = mxnet::op::DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
sg_in_stypes, out_stypes);
// copy and assign modified input storage types
for (size_t i = 0; i < sg_in_stypes->size(); i++) {
STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, sg_in_stypes->at(i));
}
return res;
};

// FGradient register lambda
Expand Down Expand Up @@ -1416,7 +1432,7 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
// this temp workspace holds memory allocated by custom library via OpResource
auto ndarray_alloc = [&](const mxnet::TShape &shape, Context ctx, int dtype,
std::string name, bool isArg) {
NDArray* arr = new NDArray(shape, ctx, dtype);
NDArray* arr = new NDArray(shape, ctx, false, dtype);
if (isArg) {
new_args.push_back(arr);
new_arg_names.push_back(name);
Expand Down
4 changes: 4 additions & 0 deletions tools/pip/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def has_ext_modules(self):
shutil.copytree(os.path.join(CURRENT_DIR, 'mxnet-build/3rdparty/tvm/nnvm/include/nnvm'),
os.path.join(CURRENT_DIR, 'mxnet/include/nnvm'))

# copy cc file for mxnet extensions
shutil.copy(os.path.join(CURRENT_DIR, 'mxnet-build/src/lib_api.cc'),
os.path.join(CURRENT_DIR, 'mxnet/src'))

package_name = 'mxnet'

variant = os.environ['mxnet_variant'].upper()
Expand Down

0 comments on commit fdc6022

Please sign in to comment.