-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Conversation
Unsure why CI fails, I did not touch that part: |
66f6335
to
1bca54a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise LGTM.
src/operator/tensor/diag_op-inl.h
Outdated
struct DiagParam : public dmlc::Parameter<DiagParam> { | ||
int32_t k; | ||
DMLC_DECLARE_PARAMETER(DiagParam) { | ||
DMLC_DECLARE_FIELD(k) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indentation looks weird, I think the lint suggest 2 spaces or 4 spaces
r = mx.nd.diag(a) | ||
|
||
for i in range(r.shape[0]): | ||
assert r[i] == a[i][i] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can use numpy for consistency check. It's more stable than manually setting the desired values
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, will do
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Incorporated the changes.
b764ef9
to
d2ef62e
Compare
Any idea why it does not pass CI? |
@ifeherva Try locally simulate the CI with the docker build script. Currently we can get nothing from the log. |
Ran on my p3.2x.large
without issues |
871f510
to
f54b7ac
Compare
Operator is now passing all tests, I consider this work finished. |
@eric-haibin-lin seems like there should be a sparse version for this too. Thoughts? |
LGTM now. @eric-haibin-lin Can you do a quick review? |
@@ -1302,6 +1302,14 @@ def flip(self, *args, **kwargs): | |||
""" | |||
return op.flip(self, *args, **kwargs) | |||
|
|||
def diag(self, k=0, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also update /~https://github.com/apache/incubator-mxnet/blob/master/docs/api/python/ndarray/ndarray.md and symbol/symbol.md (probably add it to the section of array creation routines like https://docs.scipy.org/doc/numpy/reference/routines.array-creation.html)
@reminisce maybe we should also mention adding documentation in the operator tutorial ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, added it.
@ifeherva thanks for the contribution! Good work. @szha I think |
@eric-haibin-lin Let's have the sparse feature in a separate PR. |
I am happy to do the sparse array support in a future PR. |
Done: 2d input forward pass Missing: 1d input forward all backward
Fixed small typos as well
Finished function documentation Added unit tests
Issues were extra white spaces and include order
Added matching test case
a57b8d5
to
702b416
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM pending one minor comment for doc
@@ -131,6 +131,7 @@ The `ndarray` package provides several classes: | |||
NDArray.flatten | |||
NDArray.expand_dims | |||
NDArray.split | |||
NDArray.diag | |||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I didn't make it clear - there're two places to add per file. For ndarray.md
, One is NDArray.diag (fluent method) and the other is (ndarray.)diag at line 360.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
* Added np.diag as mxnet operator, WIP Done: 2d input forward pass Missing: 1d input forward all backward * Added a simple gradient transfer backwards operator for diag Fixed small typos as well * Finished backward operation * Added full support for k * Finished added the 1D case to the diag operator Finished function documentation Added unit tests * Fixed cpplinter errors in the diag operator Issues were extra white spaces and include order * Fixed indentation in diag_op-inl.h * Changed diag operator tests to use np.diag() as comparison * Fixed kernel bug in gpu diag operator * Replaced the min operator with an inline if statement. * Added diag to ndarray and symbol * Replaced the type of parameter k from int32 to nnvm::dim * Added default argument to k in ndarray and symbol * Fixed ndarray and symbol diag calls * Fixed the optional k parameter * Fixed cpp linting error * Changed test data datatype to float32 * K values resulting into 0-sized diagonals will now throw an exception. Added matching test case * Fixed unittest * Added diag to NDArray and Symbol api doc * Added missing api doc
* Added np.diag as mxnet operator, WIP Done: 2d input forward pass Missing: 1d input forward all backward * Added a simple gradient transfer backwards operator for diag Fixed small typos as well * Finished backward operation * Added full support for k * Finished added the 1D case to the diag operator Finished function documentation Added unit tests * Fixed cpplinter errors in the diag operator Issues were extra white spaces and include order * Fixed indentation in diag_op-inl.h * Changed diag operator tests to use np.diag() as comparison * Fixed kernel bug in gpu diag operator * Replaced the min operator with an inline if statement. * Added diag to ndarray and symbol * Replaced the type of parameter k from int32 to nnvm::dim * Added default argument to k in ndarray and symbol * Fixed ndarray and symbol diag calls * Fixed the optional k parameter * Fixed cpp linting error * Changed test data datatype to float32 * K values resulting into 0-sized diagonals will now throw an exception. Added matching test case * Fixed unittest * Added diag to NDArray and Symbol api doc * Added missing api doc
Description
Added a new tensor operator called diag() replicating numpy.diag().
The only difference to numpy.diag() is that for invalid k numpy diag returns an empty array while mxnet is going to LOG FATAL.
Checklist
Essentials
Comments