Skip to content

Commit

Permalink
Merge pull request #3 from qingqing01/api_loss
Browse files Browse the repository at this point in the history
Refine Loss in Model
  • Loading branch information
qingqing01 authored Mar 16, 2020
2 parents 1a2d3b5 + 14a5737 commit 358f785
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 180 deletions.
27 changes: 27 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
- repo: /~https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: /~https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*paddle)^.*$
- id: end-of-file-fixer
files: \.(md|yml)$
- id: trailing-whitespace
files: \.(md|yml)$
- repo: /~https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1
hooks:
- id: forbid-crlf
files: \.(md|yml)$
- id: remove-crlf
files: \.(md|yml)$
- id: forbid-tabs
files: \.(md|yml)$
- id: remove-tabs
files: \.(md|yml)$
47 changes: 27 additions & 20 deletions mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from paddle.fluid.optimizer import Momentum
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear

from model import Model, CrossEntropy
from model import Model, CrossEntropy, Input


class SimpleImgConvPool(fluid.dygraph.Layer):
Expand Down Expand Up @@ -78,7 +78,6 @@ def forward(self, inputs):
class MNIST(Model):
def __init__(self):
super(MNIST, self).__init__()

self._simple_img_conv_pool_1 = SimpleImgConvPool(
1, 20, 5, 2, 2, act="relu")

Expand All @@ -88,12 +87,13 @@ def __init__(self):
pool_2_shape = 50 * 4 * 4
SIZE = 10
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
self._fc = Linear(800,
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")
self._fc = Linear(
800,
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")

def forward(self, inputs):
x = self._simple_img_conv_pool_1(inputs)
Expand Down Expand Up @@ -137,13 +137,15 @@ def null_guard():
paddle.batch(paddle.dataset.mnist.test(),
batch_size=FLAGS.batch_size, drop_last=True), 1, 1)

device_ids = list(range(FLAGS.num_devices))

with guard:
model = MNIST()
optim = Momentum(learning_rate=FLAGS.lr, momentum=.9,
parameter_list=model.parameters())
model.prepare(optim, CrossEntropy())
optim = Momentum(
learning_rate=FLAGS.lr,
momentum=.9,
parameter_list=model.parameters())
inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model.prepare(optim, CrossEntropy(), inputs, labels)
if FLAGS.resume is not None:
model.load(FLAGS.resume)

Expand All @@ -154,8 +156,7 @@ def null_guard():
val_acc = 0.0
print("======== train epoch {} ========".format(e))
for idx, batch in enumerate(train_loader()):
outputs, losses = model.train(batch[0], batch[1], device='gpu',
device_ids=device_ids)
outputs, losses = model.train(batch[0], batch[1])

acc = accuracy(outputs[0], batch[1])[0]
train_loss += np.sum(losses)
Expand All @@ -166,8 +167,7 @@ def null_guard():

print("======== eval epoch {} ========".format(e))
for idx, batch in enumerate(val_loader()):
outputs, losses = model.eval(batch[0], batch[1], device='gpu',
device_ids=device_ids)
outputs, losses = model.eval(batch[0], batch[1])

acc = accuracy(outputs[0], batch[1])[0]
val_loss += np.sum(losses)
Expand All @@ -185,14 +185,21 @@ def null_guard():
parser.add_argument(
"-e", "--epoch", default=100, type=int, help="number of epoch")
parser.add_argument(
'--lr', '--learning-rate', default=1e-3, type=float, metavar='LR',
'--lr',
'--learning-rate',
default=1e-3,
type=float,
metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch_size", default=128, type=int, help="batch size")
parser.add_argument(
"-n", "--num_devices", default=4, type=int, help="number of devices")
"-n", "--num_devices", default=1, type=int, help="number of devices")
parser.add_argument(
"-r", "--resume", default=None, type=str,
"-r",
"--resume",
default=None,
type=str,
help="checkpoint path to resume")
FLAGS = parser.parse_args()
main()
Loading

0 comments on commit 358f785

Please sign in to comment.