diff --git a/examples/sample.py b/examples/sample.py index 889bfcb..8db6847 100644 --- a/examples/sample.py +++ b/examples/sample.py @@ -100,7 +100,7 @@ def len_filter(example): bidirectional = True encoder = EncoderRNN(len(src.vocab), max_len, hidden_size, bidirectional=bidirectional, variable_lengths=True) - decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2 if bidirectional else 1, + decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2 if bidirectional else hidden_size, dropout_p=0.2, use_attention=True, bidirectional=bidirectional, eos_id=tgt.eos_id, sos_id=tgt.sos_id) seq2seq = Seq2seq(encoder, decoder) diff --git a/seq2seq/dataset/fields.py b/seq2seq/dataset/fields.py index 8ee01be..a844000 100644 --- a/seq2seq/dataset/fields.py +++ b/seq2seq/dataset/fields.py @@ -11,7 +11,7 @@ def __init__(self, **kwargs): if kwargs.get('batch_first') is False: logger.warning("Option batch_first has to be set to use pytorch-seq2seq. Changed to True.") kwargs['batch_first'] = True - if kwargs.get('batch_first') is False: + if kwargs.get('include_lengths') is False: logger.warning("Option include_lengths has to be set to use pytorch-seq2seq. Changed to True.") kwargs['include_lengths'] = True diff --git a/seq2seq/models/DecoderRNN.py b/seq2seq/models/DecoderRNN.py index b46e198..7915f1e 100644 --- a/seq2seq/models/DecoderRNN.py +++ b/seq2seq/models/DecoderRNN.py @@ -131,7 +131,7 @@ def decode(step, step_output, step_attn): eos_batches = symbols.data.eq(self.eos_id) if eos_batches.dim() > 0: eos_batches = eos_batches.cpu().view(-1).numpy() - update_idx = ((lengths > di) & eos_batches) != 0 + update_idx = ((lengths > step) & eos_batches) != 0 lengths[update_idx] = len(sequence_symbols) return symbols diff --git a/seq2seq/models/TopKDecoder.py b/seq2seq/models/TopKDecoder.py index 626d27c..ae0d465 100644 --- a/seq2seq/models/TopKDecoder.py +++ b/seq2seq/models/TopKDecoder.py @@ -9,7 +9,7 @@ def _inflate(tensor, times, dim): Args: tensor: A :class:`Tensor` to inflate times: number of repetitions - dimension: axis for inflation (default=0) + dim: axis for inflation (default=0) Returns: A :class:`Tensor` @@ -20,17 +20,16 @@ def _inflate(tensor, times, dim): 1 2 3 4 [torch.LongTensor of size 2x2] - >> decoder = TopKDecoder(nn.RNN(10, 20, 2), 3) - >> b = decoder._inflate(a, 1, dimension=1) + >> b = ._inflate(a, 2, dim=1) >> b - 1 1 2 2 - 3 3 4 4 + 1 2 1 2 + 3 4 3 4 [torch.LongTensor of size 2x4] - >> c = decoder._inflate(a, 1, dimension=0) + >> c = _inflate(a, 2, dim=0) >> c 1 2 - 1 2 3 4 + 1 2 3 4 [torch.LongTensor of size 4x2] diff --git a/seq2seq/models/attention.py b/seq2seq/models/attention.py index 376896f..0f06916 100644 --- a/seq2seq/models/attention.py +++ b/seq2seq/models/attention.py @@ -10,7 +10,7 @@ class Attention(nn.Module): .. math:: \begin{array}{ll} x = context*output \\ - attn = exp(x_i - max_i x_i) / sum_j exp(x_j - max_i x_i) \\ + attn = exp(x_i) / sum_j exp(x_j) \\ output = \tanh(w * (attn * context) + b * output) \end{array} diff --git a/seq2seq/trainer/supervised_trainer.py b/seq2seq/trainer/supervised_trainer.py index 57dae64..68c2711 100644 --- a/seq2seq/trainer/supervised_trainer.py +++ b/seq2seq/trainer/supervised_trainer.py @@ -75,7 +75,8 @@ def _train_epoches(self, data, model, n_epochs, start_epoch, start_step, device = None if torch.cuda.is_available() else -1 batch_iterator = torchtext.data.BucketIterator( dataset=data, batch_size=self.batch_size, - sort=True, sort_key=lambda x: len(x.src), + sort=False, sort_within_batch=True, + sort_key=lambda x: len(x.src), device=device, repeat=False) steps_per_epoch = len(batch_iterator) @@ -166,6 +167,7 @@ def train(self, model, data, num_epochs=5, resume_optim = self.optimizer.optimizer defaults = resume_optim.param_groups[0] defaults.pop('params', None) + defaults.pop('initial_lr', None) self.optimizer.optimizer = resume_optim.__class__(model.parameters(), **defaults) start_epoch = resume_checkpoint.epoch diff --git a/seq2seq/util/checkpoint.py b/seq2seq/util/checkpoint.py index d0bf482..f28a401 100644 --- a/seq2seq/util/checkpoint.py +++ b/seq2seq/util/checkpoint.py @@ -91,9 +91,13 @@ def load(cls, path): Returns: checkpoint (Checkpoint): checkpoint object with fields copied from those stored on disk """ - print("Loading checkpoints from {}".format(path)) - resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME)) - model = torch.load(os.path.join(path, cls.MODEL_NAME)) + if torch.cuda.is_available(): + resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME)) + model = torch.load(os.path.join(path, cls.MODEL_NAME)) + else: + resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME), map_location=lambda storage, loc: storage) + model = torch.load(os.path.join(path, cls.MODEL_NAME), map_location=lambda storage, loc: storage) + model.flatten_parameters() # make RNN parameters contiguous with open(os.path.join(path, cls.INPUT_VOCAB_FILE), 'rb') as fin: input_vocab = dill.load(fin)