Compare the gradient consistency between GPU and CPU calculations #3476

Merged · 8 commits · Aug 17, 2017
3 changes: 2 additions & 1 deletion paddle/operators/sigmoid_op.cc
@@ -44,7 +44,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
ctx.Output<Tensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("Y")->dims());
}
};

2 changes: 1 addition & 1 deletion paddle/operators/sigmoid_op.h
@@ -37,7 +37,7 @@ class SigmoidKernel : public framework::OpKernel {
auto Y = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();

Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp());
Y.device(place) = 1. / (1. + (-X).exp());
}
};

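Note: the forward kernel above computes Y = 1 / (1 + exp(-X)). The analytical sigmoid gradient that the new gradient check exercises is the standard elementwise identity dX = dY * Y * (1 - Y); the grad kernel itself is not part of this diff. A minimal numpy sketch of both formulas, for illustration only (the helper names below are not from the codebase):

import numpy as np

def sigmoid(x):
    # same expression as the Eigen kernel: 1. / (1. + exp(-x))
    return 1. / (1. + np.exp(-x))

def sigmoid_grad(y, dy):
    # elementwise chain rule: dX = dY * Y * (1 - Y)
    return dy * y * (1. - y)

x = np.random.uniform(0.1, 1, (11, 17)).astype("float32")
y = sigmoid(x)
dy = np.ones_like(y)          # the checker feeds an all-ones output gradient
dx = sigmoid_grad(y, dy)
assert dx.shape == y.shape    # matches SigmoidOpGrad::InferShape above, which resizes the X gradient to Y's dims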
1 change: 1 addition & 0 deletions python/paddle/v2/framework/tests/CMakeLists.txt
@@ -25,3 +25,4 @@ py_test(test_operator SRCS test_operator.py)
# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test(test_recurrent_op SRCS test_recurrent_op.py)
py_test(test_gradient_checker SRCS test_gradient_checker.py)
214 changes: 117 additions & 97 deletions python/paddle/v2/framework/tests/gradient_checker.py
@@ -1,13 +1,15 @@
import unittest

import numpy
import itertools
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator

__all__ = ['get_numeric_gradient']


def create_op(op_type):
# TODO need to set attrs
kwargs = dict()
for in_name in Operator.get_op_input_names(op_type):
kwargs[in_name] = in_name
@@ -66,7 +68,6 @@ def get_numeric_gradient(op,
local_scope.find_var(output).get_tensor().alloc_float(core.CPUPlace(
))

# TODO(yuyang18): Only CPU is support now.
cpu_ctx = core.DeviceContext.create(core.CPUPlace())

def get_output():
@@ -109,12 +110,110 @@ def product(dim):


class GradientChecker(unittest.TestCase):
def assert_is_close(self, numeric_grads, scope, max_relative_error,
msg_prefix):
for name in numeric_grads:
b = numpy.array(scope.find_var(grad_var_name(name)).get_tensor())
a = numeric_grads[name]
def __get_gradient(self, forward_op, backward_op, input_value, grad_names,
place):
"""Get the input gradients after running forward and backward operators
on the given places.

:param forward_op: forward operator
:type forward_op: Operator
:param backward_op: backward operator
:type backward_op: Operator
:param input_value: input values.
:type input_value: dict{string:numpy.array}
:param grad_names: the names of the returned input gradients.
:type grad_names: a list of string
:param place: the device type.
:type place: CPUPlace or GPUPlace
:return: the input gradients of the given grad_names.
:rtype: a list of numpy.array
"""
scope = core.Scope()
ctx = core.DeviceContext.create(place)

inputs = forward_op.inputs()
in_names = [item for k in inputs for item in inputs[k]]
outputs = forward_op.outputs()
out_names = [item for k in outputs for item in outputs[k]]

# create input var and set value
for name, value in input_value.iteritems():
if name not in in_names:
raise ValueError(name + "does not exist in Op's inputs.")
var = scope.new_var(name).get_tensor()
var.set_dims(value.shape)
var.set(value, place)

# run forward op
for out_name in out_names:
scope.new_var(out_name)
forward_op.infer_shape(scope)
forward_op.run(scope, ctx)

# set output var's shape
# set output grad to ones
for name in out_names:
out_tensor = scope.find_var(name).get_tensor()
grad_tensor = scope.new_var(grad_var_name(name)).get_tensor()
grad_tensor.set_dims(out_tensor.shape())
data = numpy.ones(out_tensor.shape(), dtype=numpy.float32)
grad_tensor.set(data, place)

# run backward op
for name in backward_op.outputs():
scope.new_var(name)
backward_op.infer_shape(scope)
backward_op.run(scope, ctx)

outs = [
numpy.array(scope.find_var(name).get_tensor())
for name in grad_names
]
return outs

def compare_grad(self, forward_op, input_value):
""" Compare the input gradients between CPU and GPU for the given forward
operator.

:param forward_op: forward operator
:type forward_op: Operator
:param input_value: input values.
:type input_value: dict{string:numpy.array}
:raises: AssertionError, if the CPU and GPU gradients are different.
"""
backward_op = core.Operator.backward(forward_op, set())
# return if not compiled with GPU or the GPU kernel is not implemented
if not (core.is_compile_gpu() and backward_op.support_gpu()):
return

outputs = backward_op.outputs()
out_names = [item for k in outputs for item in outputs[k]]
cpu_grads = self.__get_gradient(forward_op, backward_op, input_value,
out_names, core.CPUPlace())
gpu_grads = self.__get_gradient(forward_op, backward_op, input_value,
out_names, core.GPUPlace(0))

for c_grad, g_grad, name in itertools.izip(cpu_grads, gpu_grads,
out_names):
self.assertTrue(
numpy.allclose(
c_grad, g_grad, atol=1e-4),
"output name: " + name + " has diff")

def __assert_is_close(self, numeric_grads, analytic_grads, names,
max_relative_error, msg_prefix):
"""Use relative error for the comparison.

:param numeric_grads: the numerical gradients.
:type numeric_grads: a list of numpy.array
:param analytic_grads: the analytical gradients.
:type analytic_grads: a list of numpy.array
:param names: the names of the gradients, used for debug printing.
:type names: a list of string
:param msg_prefix: string info, used for debug printing.
:type msg_prefix: string
"""
for a, b, name in itertools.izip(numeric_grads, analytic_grads, names):
abs_a = numpy.abs(a)
# if abs_a is nearly zero, then use abs error for a, not relative
# error.
@@ -159,106 +258,27 @@ def check_grad(self,

inputs = forward_op.inputs()
in_names = [item for k in inputs for item in inputs[k]]
outputs = forward_op.outputs()
out_names = [item for k in outputs for item in outputs[k]]

for no_grad in no_grad_set:
if no_grad not in in_names:
raise ValueError("no_grad should be in in_names")

backward_op = core.Operator.backward(forward_op, no_grad_set)

bwd_outputs = backward_op.outputs()
bwd_out_names = [item for k in bwd_outputs for item in bwd_outputs[k]]

places = [core.CPUPlace()]
if not only_cpu and core.is_compile_gpu() and backward_op.support_gpu():
places.append(core.GPUPlace(0))

numeric_grad = dict()
# get numeric gradient
for check_name in inputs_to_check:
numeric_grad[check_name] = \
get_numeric_gradient(forward_op, input_vars, output_name,
check_name)
# get numerical gradients
numeric_grads = [
get_numeric_gradient(forward_op, input_vars, output_name, name)
for name in inputs_to_check
]

# get operator gradient according to different device
check_names = [grad_var_name(name) for name in inputs_to_check]
for place in places:
scope = core.Scope()
ctx = core.DeviceContext.create(place)

# create input var and set value
for name, value in input_vars.iteritems():
if name not in in_names:
raise ValueError(name + " not in op.inputs_")
var = scope.new_var(name).get_tensor()
var.set_dims(value.shape)
var.set(value, place)

# create output var
for out_name in out_names:
scope.new_var(out_name).get_tensor()

# infer the shape of output var and compute/set value of output var
forward_op.infer_shape(scope)
forward_op.run(scope, ctx)

# create output grad var
# set shape as the output var
# set value of this grad to ones
for name in out_names:
out_tensor = scope.find_var(name).get_tensor()
grad_tensor = scope.new_var(grad_var_name(name)).get_tensor()
grad_tensor.set_dims(out_tensor.shape())
data = 1.0 * numpy.ones(out_tensor.shape())
grad_tensor.set(data, place)

# create input grad var
for name in bwd_out_names:
scope.new_var(name).get_tensor()

# infer the shape of input gradient var and compute/set it's value
# with backward op
backward_op.infer_shape(scope)
backward_op.run(scope, ctx)

self.assert_is_close(numeric_grad, scope, max_relative_error,
"Gradient Check On %s" % str(place))


if __name__ == '__main__':

class GetNumericGradientTest(unittest.TestCase):
def test_add_op(self):
add_op = Operator('add_two', X="X", Y="Y", Out="Z")
x = numpy.random.random((10, 1)).astype("float32")
y = numpy.random.random((10, 1)).astype("float32")

arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X')
self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-2)

def test_softmax_op(self):
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - numpy.max(x)
exps = numpy.exp(shiftx)
return exps / numpy.sum(exps)

def label_softmax_grad(Y, dY):
dX = Y * 0.0
for i in range(Y.shape[0]):
d = numpy.dot(Y[i, :], dY[i, :])
dX[i, :] = Y[i, :] * (dY[i, :] - d)
return dX

softmax_op = Operator("softmax", X="X", Y="Y")

X = numpy.random.random((2, 2)).astype("float32")
Y = numpy.apply_along_axis(stable_softmax, 1, X)
dY = numpy.ones(Y.shape)
dX = label_softmax_grad(Y, dY)

arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X')
numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2)

unittest.main()
# get analytical gradients according to different device
analytic_grads = self.__get_gradient(forward_op, backward_op,
input_vars, check_names, place)
self.__assert_is_close(numeric_grads, analytic_grads, check_names,
max_relative_error,
"Gradient Check On %s" % str(place))
43 changes: 43 additions & 0 deletions python/paddle/v2/framework/tests/test_gradient_checker.py
@@ -0,0 +1,43 @@
import unittest
import numpy
from paddle.v2.framework.op import Operator
from gradient_checker import GradientChecker
from gradient_checker import get_numeric_gradient


class GetNumericGradientTest(unittest.TestCase):
def test_add_op(self):
add_op = Operator('add_two', X="X", Y="Y", Out="Z")
x = numpy.random.random((10, 1)).astype("float32")
y = numpy.random.random((10, 1)).astype("float32")

arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X')
self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4)

def test_softmax_op(self):
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - numpy.max(x)
exps = numpy.exp(shiftx)
return exps / numpy.sum(exps)

def label_softmax_grad(Y, dY):
dX = Y * 0.0
for i in range(Y.shape[0]):
d = numpy.dot(Y[i, :], dY[i, :])
dX[i, :] = Y[i, :] * (dY[i, :] - d)
return dX

softmax_op = Operator("softmax", X="X", Y="Y")

X = numpy.random.random((2, 2)).astype("float32")
Y = numpy.apply_along_axis(stable_softmax, 1, X)
dY = numpy.ones(Y.shape)
dX = label_softmax_grad(Y, dY)

arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X')
numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2)


if __name__ == '__main__':
unittest.main()
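The softmax test above relies on the fact that, for one row, the softmax Jacobian is diag(Y) - outer(Y, Y), so the backward pass collapses to dX = Y * (dY - dot(Y, dY)), which is exactly what label_softmax_grad computes. A small numpy check of that identity against the explicit Jacobian, for illustration only:

import numpy as np

def stable_softmax(x):
    # numerically stable softmax of a vector
    shiftx = x - np.max(x)
    exps = np.exp(shiftx)
    return exps / np.sum(exps)

x = np.random.random(5)
y = stable_softmax(x)
dy = np.random.random(5)

# explicit Jacobian: J = diag(y) - outer(y, y); J is symmetric, so J^T dY = J dY
jacobian = np.diag(y) - np.outer(y, y)
full = jacobian.dot(dy)

# shortcut used in label_softmax_grad: y * (dy - dot(y, dy))
short = y * (dy - y.dot(dy))

assert np.allclose(full, short)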
17 changes: 13 additions & 4 deletions python/paddle/v2/framework/tests/test_sigmoid_op.py
@@ -1,19 +1,28 @@
import unittest
from op_test_util import OpTestMeta
import numpy as np
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op


class TestSigmoidOp(unittest.TestCase):
__metaclass__ = OpTestMeta

def setUp(self):
self.type = "sigmoid"
self.inputs = {'X': np.random.random((32, 100)).astype("float32")}
self.inputs = {'X': np.random.random((15, 31)).astype("float32")}
self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}


#class TestSigmoidGradOp(unittest.TestCase):
#TODO(qingqing) add unit test
class TestSigmoidGradOp(GradientChecker):
def test_grad(self):
op = create_op("sigmoid")
inputs = {"X": np.random.uniform(0.1, 1, [11, 17]).astype("float32")}
# compare gpu and cpu results for backward op.
# this test will be skipped if compiling only the CPU version.
self.compare_grad(op, inputs)
# check gradients
self.check_grad(op, inputs, set("X"), "Y", max_relative_error=0.007)


if __name__ == '__main__':
unittest.main()