Skip to content

Commit

Permalink
Fix an incorrect usage of is_local argument.
Browse files Browse the repository at this point in the history
  • Loading branch information
xinghai-sun committed Aug 14, 2017
1 parent dd92a02 commit 7dfcdb0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
6 changes: 5 additions & 1 deletion deep_speech_2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def train(self,
gradient_clipping,
num_passes,
output_model_dir,
is_local=True,
num_iterations_print=100):
"""Train the model.
Expand All @@ -65,6 +66,8 @@ def train(self,
:param num_iterations_print: Number of training iterations for printing
a training loss.
:type rnn_iteratons_print: int
:param is_local: Set to False if running with pserver with multi-nodes.
:type is_local: bool
:param output_model_dir: Directory for saving the model (every pass).
:type output_model_dir: basestring
"""
Expand Down Expand Up @@ -117,7 +120,8 @@ def event_handler(event):
reader=train_batch_reader,
event_handler=event_handler,
num_passes=num_passes,
feeding=feeding_dict)
feeding=feeding_dict,
is_local=is_local)

def infer_loss_batch(self, infer_data):
"""Model inference. Infer the ctc loss for a batch of speech
Expand Down
8 changes: 3 additions & 5 deletions deep_speech_2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,13 @@ def train():
gradient_clipping=400,
num_passes=args.num_passes,
num_iterations_print=args.num_iterations_print,
output_model_dir=args.output_model_dir)
output_model_dir=args.output_model_dir,
is_local=args.is_local)


def main():
utils.print_arguments(args)
paddle.init(
use_gpu=args.use_gpu,
trainer_count=args.trainer_count,
is_local=args.is_local)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
train()


Expand Down

0 comments on commit 7dfcdb0

Please sign in to comment.