Skip to content

Commit

Permalink
Add api doc and update unittest. (PaddlePaddle#43)
Browse files Browse the repository at this point in the history
* Add doc strings.
* Update overlap_add op unittest
  • Loading branch information
KPatr1ck authored Sep 15, 2021
1 parent d2eebba commit c0289d1
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 26 deletions.
8 changes: 8 additions & 0 deletions paddle/fluid/operators/overlap_add_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class OverlapAddOp : public framework::OperatorWithKernel {
"Attribute(hop_length) of OverlapAddOp should be greater "
"than 0, but got %s.",
hop_length));

PADDLE_ENFORCE_EQ(
(axis == 0 || axis == -1), true,
platform::errors::InvalidArgument(
Expand All @@ -68,6 +69,13 @@ class OverlapAddOp : public framework::OperatorWithKernel {
end_axis = x_rank - 3;
}

PADDLE_ENFORCE_LE(
hop_length, frame_length,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be less or equal "
"than frame_length, but got hop_length(%s) > frame_length(%s).",
hop_length, frame_length));

const int seq_length = (n_frames - 1) * hop_length + frame_length;

// It won't go into for loop when x_rank == 2U.
Expand Down
24 changes: 12 additions & 12 deletions python/paddle/fluid/tests/unittests/test_overlap_add_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ def setUp(self):
self.outputs = {'Out': overlap_add(x=self.inputs['X'], **self.attrs)}

def initTestCase(self):
input_shape = (150, 30)
input_shape = (50, 3)
input_type = 'float64'
attrs = {
'hop_length': 20,
'hop_length': 4,
'axis': -1,
}
return input_shape, input_type, attrs
Expand All @@ -100,54 +100,54 @@ def test_check_grad_normal(self):

class TestCase1(TestOverlapAddOp):
def initTestCase(self):
input_shape = (30, 150)
input_shape = (3, 50)
input_type = 'float64'
attrs = {
'hop_length': 15,
'hop_length': 4,
'axis': 0,
}
return input_shape, input_type, attrs


class TestCase2(TestOverlapAddOp):
def initTestCase(self):
input_shape = (2, 250, 10)
input_shape = (2, 40, 5)
input_type = 'float64'
attrs = {
'hop_length': 50,
'hop_length': 10,
'axis': -1,
}
return input_shape, input_type, attrs


class TestCase3(TestOverlapAddOp):
def initTestCase(self):
input_shape = (10, 250, 2)
input_shape = (5, 40, 2)
input_type = 'float64'
attrs = {
'hop_length': 30,
'hop_length': 10,
'axis': 0,
}
return input_shape, input_type, attrs


class TestCase4(TestOverlapAddOp):
def initTestCase(self):
input_shape = (3, 5, 70, 20)
input_shape = (3, 5, 12, 8)
input_type = 'float64'
attrs = {
'hop_length': 27,
'hop_length': 5,
'axis': -1,
}
return input_shape, input_type, attrs


class TestCase5(TestOverlapAddOp):
def initTestCase(self):
input_shape = (20, 70, 5, 3)
input_shape = (8, 12, 5, 3)
input_type = 'float64'
attrs = {
'hop_length': 33,
'hop_length': 5,
'axis': 0,
}
return input_shape, input_type, attrs
Expand Down
Loading

0 comments on commit c0289d1

Please sign in to comment.