From 1fa4663962bb782fe949489c50efbe3fff81778c Mon Sep 17 00:00:00 2001 From: LiuChiaChi <709153940@qq.com> Date: Sat, 20 Feb 2021 02:03:46 +0000 Subject: [PATCH] fix seq2seq args bug, replace use_gpu with select_device --- examples/machine_translation/seq2seq/predict.py | 2 +- examples/machine_translation/seq2seq/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/machine_translation/seq2seq/predict.py b/examples/machine_translation/seq2seq/predict.py index 818dca16d5e911..0f724b66723fa3 100644 --- a/examples/machine_translation/seq2seq/predict.py +++ b/examples/machine_translation/seq2seq/predict.py @@ -41,7 +41,7 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False): def do_predict(args): - device = paddle.set_device("gpu" if args.use_gpu else "cpu") + device = paddle.set_device(args.select_device) test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_infer_loader( args) diff --git a/examples/machine_translation/seq2seq/train.py b/examples/machine_translation/seq2seq/train.py index 834fcf19d51c87..b15ebadc33733f 100644 --- a/examples/machine_translation/seq2seq/train.py +++ b/examples/machine_translation/seq2seq/train.py @@ -23,7 +23,7 @@ def do_train(args): - device = paddle.set_device("gpu" if args.use_gpu else "cpu") + device = paddle.set_device(args.select_device) # Define dataloader train_loader, eval_loader, src_vocab_size, tgt_vocab_size, eos_id = create_train_loader(