diff --git a/python/paddle/v2/fluid/tests/book_distribute/notest_understand_sentiment_lstm_dist.py b/python/paddle/v2/fluid/tests/book_distribute/notest_understand_sentiment_lstm_dist.py
new file mode 100644
index 00000000000000..639d1a7affc090
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/book_distribute/notest_understand_sentiment_lstm_dist.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import numpy as np
+import os
+import paddle.v2 as paddle
+import paddle.v2.fluid as fluid
+from paddle.v2.fluid.layer_helper import LayerHelper
+
+BATCH_SIZE = 100
+PASS_NUM = 5
+
+
+def lstm(x, c_pre_init, hidden_dim, forget_bias=None):
+    """
+    This function creates an operator for the LSTM cell
+    that can be used inside an RNN.
+    """
+    helper = LayerHelper('lstm_unit', **locals())
+    rnn = fluid.layers.StaticRNN()
+    with rnn.step():
+        c_pre = rnn.memory(init=c_pre_init)
+        x_t = rnn.step_input(x)
+
+        before_fc = fluid.layers.concat(input=[x_t, c_pre], axis=1)
+        after_fc = fluid.layers.fc(input=before_fc, size=hidden_dim * 4)
+
+        dtype = x.dtype
+        c = helper.create_tmp_variable(dtype)
+        h = helper.create_tmp_variable(dtype)
+
+        helper.append_op(
+            type='lstm_unit',
+            inputs={"X": after_fc,
+                    "C_prev": c_pre},
+            outputs={"C": c,
+                     "H": h},
+            attrs={"forget_bias": forget_bias})
+
+        rnn.update_memory(c_pre, c)
+        rnn.output(h)
+
+    return rnn()
+
+
+def lstm_net(dict_dim, class_dim=2, emb_dim=32, seq_len=80, batch_size=50):
+    data = fluid.layers.data(
+        name="words",
+        shape=[seq_len * batch_size, 1],
+        append_batch_size=False,
+        dtype="int64",
+        lod_level=1)
+    label = fluid.layers.data(
+        name="label",
+        shape=[batch_size, 1],
+        append_batch_size=False,
+        dtype="int64")
+
+    emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
+    emb = fluid.layers.reshape(x=emb, shape=[batch_size, seq_len, emb_dim])
+    emb = fluid.layers.transpose(x=emb, perm=[1, 0, 2])
+
+    c_pre_init = fluid.layers.fill_constant(
+        dtype=emb.dtype, shape=[batch_size, emb_dim], value=0.0)
+    c_pre_init.stop_gradient = False
+    layer_1_out = lstm(emb, c_pre_init=c_pre_init, hidden_dim=emb_dim)
+    layer_1_out = fluid.layers.transpose(x=layer_1_out, perm=[1, 0, 2])
+
+    prediction = fluid.layers.fc(input=layer_1_out,
+                                 size=class_dim,
+                                 act="softmax")
+    cost = fluid.layers.cross_entropy(input=prediction, label=label)
+
+    avg_cost = fluid.layers.mean(x=cost)
+    adam_optimizer = fluid.optimizer.Adam(learning_rate=0.002)
+    optimize_ops, params_grads = adam_optimizer.minimize(avg_cost)
+    acc = fluid.layers.accuracy(input=prediction, label=label)
+
+    return avg_cost, acc, optimize_ops, params_grads
+
+
+def to_lodtensor(data, place):
+    seq_lens = [len(seq) for seq in data]
+    cur_len = 0
+    lod = [cur_len]
+    for l in seq_lens:
+        cur_len += l
+        lod.append(cur_len)
+    flattened_data = np.concatenate(data, axis=0).astype("int64")
+    flattened_data = flattened_data.reshape([len(flattened_data), 1])
+    res = fluid.LoDTensor()
+    res.set(flattened_data, place)
+    res.set_lod([lod])
+    return res
+
+
+def chop_data(data, chop_len=80, batch_size=50):
+    data = [(x[0][:chop_len], x[1]) for x in data if len(x[0]) >= chop_len]
+    return data[:batch_size]
+
+
+def prepare_feed_data(data, place):
+    tensor_words = to_lodtensor(map(lambda x: x[0], data), place)
+
+    label = np.array(map(lambda x: x[1], data)).astype("int64")
+    label = label.reshape([len(label), 1])
+    tensor_label = fluid.LoDTensor()
+    tensor_label.set(label, place)
+    return tensor_words, tensor_label
+
+
+def main():
+
+    word_dict = paddle.dataset.imdb.word_dict()
+    dict_dim = len(word_dict)
+    class_dim = 2
+
+    cost, acc, optimize_ops, params_grads = lstm_net(
+        dict_dim=dict_dim, class_dim=class_dim)
+
+    train_data = paddle.batch(
+        paddle.reader.shuffle(
+            paddle.dataset.imdb.train(word_dict), buf_size=BATCH_SIZE * 10),
+        batch_size=BATCH_SIZE)
+    place = fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    t = fluid.DistributeTranspiler()
+
+    # all parameter server endpoints list for splitting parameters
+    pserver_endpoints = os.getenv("PSERVERS")
+    # server endpoint for current node
+    current_endpoint = os.getenv("SERVER_ENDPOINT")
+    # run as trainer or parameter server
+    training_role = os.getenv(
+        "TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+    t.transpile(
+        optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
+
+    if training_role == "PSERVER":
+        if not current_endpoint:
+            print("need env SERVER_ENDPOINT")
+            exit(1)
+        pserver_prog = t.get_pserver_program(current_endpoint)
+        pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
+        exe.run(pserver_startup)
+        exe.run(pserver_prog)
+    elif training_role == "TRAINER":
+        exe.run(fluid.default_startup_program())
+        trainer_prog = t.get_trainer_program()
+
+        for pass_id in xrange(PASS_NUM):
+            for data in train_data():
+                chopped_data = chop_data(data)
+                tensor_words, tensor_label = prepare_feed_data(chopped_data,
+                                                               place)
+
+                outs = exe.run(
+                    trainer_prog,
+                    feed={"words": tensor_words,
+                          "label": tensor_label},
+                    fetch_list=[cost, acc])
+
+                cost_val = np.array(outs[0])
+                acc_val = np.array(outs[1])
+
+                print("cost=" + str(cost_val) + " acc=" + str(acc_val))
+                if acc_val > 0.7:
+                    exit(0)
+    else:
+        print("environment var TRAINING_ROLE should be TRAINER or PSERVER")
+
+
+if __name__ == '__main__':
+    main()
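
For context, a minimal local launcher sketch, not part of the patch: it only sets the PSERVERS, SERVER_ENDPOINT and TRAINING_ROLE environment variables that the script reads above and spawns one process per pserver endpoint plus the two trainers implied by trainers=2 in t.transpile(). The file name, the 127.0.0.1 endpoints and the comma-separated PSERVERS format are illustrative assumptions.

# Hypothetical launcher; endpoint values and PSERVERS format are assumptions.
import os
import subprocess

SCRIPT = "notest_understand_sentiment_lstm_dist.py"
PSERVERS = "127.0.0.1:6174,127.0.0.1:6175"

procs = []
for ep in PSERVERS.split(","):
    # one parameter-server process per endpoint in the PSERVERS list
    env = dict(os.environ, PSERVERS=PSERVERS,
               SERVER_ENDPOINT=ep, TRAINING_ROLE="PSERVER")
    procs.append(subprocess.Popen(["python", SCRIPT], env=env))
for _ in range(2):
    # two trainer processes, matching trainers=2 passed to t.transpile()
    env = dict(os.environ, PSERVERS=PSERVERS, TRAINING_ROLE="TRAINER")
    procs.append(subprocess.Popen(["python", SCRIPT], env=env))
for p in procs:
    p.wait()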