Skip to content

Commit

Permalink
follow comments
Browse files Browse the repository at this point in the history
  • Loading branch information
QiJune committed Sep 11, 2017
1 parent 436fbb0 commit 477b23c
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions python/paddle/v2/framework/tests/op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def create_op(scope, op_type, inputs, outputs, attrs):
kwargs[in_name] = []
if in_dup:
sub_in = inputs[in_name]
for sub_in_name, arr in sub_in:
for sub_in_name, _ in sub_in:
var = scope.new_var(sub_in_name)
kwargs[in_name].append(sub_in_name)
else:
Expand All @@ -29,7 +29,7 @@ def create_op(scope, op_type, inputs, outputs, attrs):
kwargs[out_name] = []
if out_dup:
sub_in = outputs[out_name]
for sub_in_name, arr in sub_in:
for sub_in_name, _ in sub_in:
var = scope.new_var(sub_in_name)
kwargs[out_name].append(sub_in_name)
else:
Expand All @@ -47,11 +47,11 @@ def set_input(scope, op, inputs, place):
if in_name in inputs:
if in_dup:
sub_in = inputs[in_name]
for sub_in_name, arr in sub_in:
for sub_in_name, sub_in_array in sub_in:
var = scope.find_var(sub_in_name)
tensor = var.get_tensor()
tensor.set_dims(arr.shape)
tensor.set(arr, place)
tensor.set_dims(sub_in_array.shape)
tensor.set(sub_in_array, place)
else:
var = scope.find_var(in_name)
tensor = var.get_tensor()
Expand All @@ -65,7 +65,7 @@ def set_output_grad(scope, op, outputs, place):
if out_name in outputs:
if out_dup:
sub_out = outputs[out_name]
for sub_out_name, arr in sub_out:
for sub_out_name, sub_out_grad in sub_out:
out_tensor = scope.find_var(sub_out_name).get_tensor()
grad_tensor = scope.new_var(grad_var_name(
sub_out_name)).get_tensor()
Expand Down Expand Up @@ -169,9 +169,8 @@ class OpTest(unittest.TestCase):
def check_output_with_place(self, place):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
self.op = create_op(self.scope, self.op_type, op_inputs, self.outputs,
op_attrs)
if isinstance(place, core.GPUPlace) and not self.op.support_gpu():
return
Expand Down Expand Up @@ -232,9 +231,8 @@ def check_grad(self,
max_relative_error=0.005):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
self.op = create_op(self.scope, self.op_type, op_inputs, self.outputs,
op_attrs)
if no_grad_set is None:
no_grad_set = set()
Expand Down

0 comments on commit 477b23c

Please sign in to comment.