0.1.5 + minor fixes (#106)
* Modified parameter order of DecoderRNN.forward (#85)

* Updated TopKDecoder (#86)

* Fixed topk decoder.

* Use torchtext from PyPI (#87)

* Use torchtext from PyPI.

* Fixed torchtext sorting order.

* Attention is not required when only using teacher forcing in decoder (#90)

* Attention is not required when only using teacher forcing in decoder

* Updated docs and version.

* Fixed code style.

* bugfix (#92)

Fixed field arguments validation.

* Removed `initial_lr` when resuming optimizer with scheduler. (#95)

* Shuffle the training data (#97)

* 0.1.5 (#91)

* Modified parameter order of DecoderRNN.forward (#85)

* Updated TopKDecoder (#86)

* Fixed topk decoder.

* Use torchtext from PyPI (#87)

* Use torchtext from PyPI.

* Fixed torchtext sorting order.

* Attention is not required when only using teacher forcing in decoder (#90)

* Attention is not required when only using teacher forcing in decoder

* Updated docs and version.

* Fixed code style.

* Shuffle the training data

* Fix example of inflate function in TopKDecoder.py (#98)

* Fix example of inflate function in TopKDecoder.py

* Fix hidden_layer size for one-directional decoder (#99)

* Fix hidden_layer size for one-directional decoder

The hidden layer size of the decoder was given as `hidden_size * 2 if bidirectional else 1`, resulting in a dimensionality error for non-bidirectional decoders.
Changed `1` to `hidden_size`.

* Adapt load to allow CPU loading of GPU models (#100)

* Adapt load to allow CPU loading of GPU models

Pass a map_location argument to torch.load so that models trained on the GPU
can be loaded on a CPU, depending on CUDA availability.

* Fix wrong parameter use on DecoderRNN (#103)

* Fix wrong parameter use on DecoderRNN
kylegao91 authored Dec 4, 2017
1 parent e8250fb commit 96c6033
Showing 7 changed files with 20 additions and 15 deletions.
examples/sample.py (2 changes: 1 addition & 1 deletion)
@@ -100,7 +100,7 @@ def len_filter(example):
 bidirectional = True
 encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                      bidirectional=bidirectional, variable_lengths=True)
-decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2 if bidirectional else 1,
+decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2 if bidirectional else hidden_size,
                      dropout_p=0.2, use_attention=True, bidirectional=bidirectional,
                      eos_id=tgt.eos_id, sos_id=tgt.sos_id)
 seq2seq = Seq2seq(encoder, decoder)
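
The shape arithmetic behind this fix, as a standalone illustration (a plain nn.GRU stands in for the project's EncoderRNN; nothing below is repository code):

# With a bidirectional encoder, the outputs carry 2 * hidden_size features, so
# an attention decoder consuming them must be built with hidden_size * 2; the
# unidirectional case needs plain hidden_size, not the literal 1.
import torch
import torch.nn as nn

hidden_size = 16
encoder_rnn = nn.GRU(input_size=8, hidden_size=hidden_size,
                     batch_first=True, bidirectional=True)
outputs, _ = encoder_rnn(torch.zeros(4, 10, 8))   # (batch, seq_len, features)
assert outputs.size(-1) == hidden_size * 2        # decoder size for the bidirectional case
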
seq2seq/dataset/fields.py (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ def __init__(self, **kwargs):
 if kwargs.get('batch_first') is False:
     logger.warning("Option batch_first has to be set to use pytorch-seq2seq. Changed to True.")
 kwargs['batch_first'] = True
-if kwargs.get('batch_first') is False:
+if kwargs.get('include_lengths') is False:
     logger.warning("Option include_lengths has to be set to use pytorch-seq2seq. Changed to True.")
 kwargs['include_lengths'] = True
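
For context, a sketch of the corrected validation as a whole (the class wrapper, imports, and exact indentation are assumed, since the diff above shows only the changed guard): the second check now inspects include_lengths instead of re-checking batch_first.

import logging
import torchtext

logger = logging.getLogger(__name__)

class SourceField(torchtext.data.Field):
    # Sketch only: forces the two options pytorch-seq2seq relies on.
    def __init__(self, **kwargs):
        if kwargs.get('batch_first') is False:
            logger.warning("Option batch_first has to be set to use pytorch-seq2seq. Changed to True.")
        kwargs['batch_first'] = True
        if kwargs.get('include_lengths') is False:
            logger.warning("Option include_lengths has to be set to use pytorch-seq2seq. Changed to True.")
        kwargs['include_lengths'] = True
        super(SourceField, self).__init__(**kwargs)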

seq2seq/models/DecoderRNN.py (2 changes: 1 addition & 1 deletion)
@@ -131,7 +131,7 @@ def decode(step, step_output, step_attn):
 eos_batches = symbols.data.eq(self.eos_id)
 if eos_batches.dim() > 0:
     eos_batches = eos_batches.cpu().view(-1).numpy()
-    update_idx = ((lengths > di) & eos_batches) != 0
+    update_idx = ((lengths > step) & eos_batches) != 0
     lengths[update_idx] = len(sequence_symbols)
 return symbols
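
A simplified, self-contained illustration of the length bookkeeping this hunk touches (eos_id, the lengths array, and decode_step are stand-ins, not repository code): the comparison now uses the helper's own step argument, so a sequence's length is frozen at the step where it emits the end-of-sequence symbol.

import numpy as np
import torch

eos_id = 2
lengths = np.array([10, 10, 10])   # start at max_length for three sequences
sequence_symbols = []

def decode_step(step, symbols):
    sequence_symbols.append(symbols)
    eos_batches = symbols.eq(eos_id).view(-1).numpy()
    # Only sequences that have not finished yet and emit <eos> now get updated.
    update_idx = ((lengths > step) & eos_batches) != 0
    lengths[update_idx] = len(sequence_symbols)

decode_step(0, torch.LongTensor([1, 2, 3]))   # the second sequence emits <eos>
print(lengths)                                # [10  1 10]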

seq2seq/models/TopKDecoder.py (13 changes: 6 additions & 7 deletions)
@@ -9,7 +9,7 @@ def _inflate(tensor, times, dim):
 Args:
     tensor: A :class:`Tensor` to inflate
     times: number of repetitions
-    dimension: axis for inflation (default=0)
+    dim: axis for inflation (default=0)
 Returns:
     A :class:`Tensor`
@@ -20,17 +20,16 @@ def _inflate(tensor, times, dim):
 1 2
 3 4
 [torch.LongTensor of size 2x2]
->> decoder = TopKDecoder(nn.RNN(10, 20, 2), 3)
->> b = decoder._inflate(a, 1, dimension=1)
+>> b = ._inflate(a, 2, dim=1)
 >> b
-1 1 2 2
-3 3 4 4
+1 2 1 2
+3 4 3 4
 [torch.LongTensor of size 2x4]
->> c = decoder._inflate(a, 1, dimension=0)
+>> c = _inflate(a, 2, dim=0)
 >> c
 1 2
-1 2
 3 4
+1 2
 3 4
 [torch.LongTensor of size 4x2]
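
A minimal sketch of the helper the corrected example documents, assuming it is implemented with Tensor.repeat (the repository's actual function body is not shown in this diff):

import torch

def _inflate(tensor, times, dim):
    # Tile the tensor `times` times along `dim`; whole copies are stacked,
    # which is why inflating along dim=0 yields [[1,2],[3,4],[1,2],[3,4]].
    repeat_dims = [1] * tensor.dim()
    repeat_dims[dim] = times
    return tensor.repeat(*repeat_dims)

a = torch.LongTensor([[1, 2], [3, 4]])
print(_inflate(a, 2, dim=1))   # rows become [1, 2, 1, 2] and [3, 4, 3, 4]
print(_inflate(a, 2, dim=0))   # [[1, 2], [3, 4], [1, 2], [3, 4]]
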
seq2seq/models/attention.py (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@ class Attention(nn.Module):
 .. math::
     \begin{array}{ll}
     x = context*output \\
-    attn = exp(x_i - max_i x_i) / sum_j exp(x_j - max_i x_i) \\
+    attn = exp(x_i) / sum_j exp(x_j) \\
     output = \tanh(w * (attn * context) + b * output)
     \end{array}
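
A sketch of the softmax-based attention the docstring describes, written as a plain dot-product attention module (illustrative code only, not the repository's Attention class):

import torch
import torch.nn as nn
import torch.nn.functional as F

class DotAttention(nn.Module):
    def __init__(self, dim):
        super(DotAttention, self).__init__()
        self.linear_out = nn.Linear(dim * 2, dim)

    def forward(self, output, context):
        # output: (batch, out_len, dim), context: (batch, in_len, dim)
        scores = torch.bmm(output, context.transpose(1, 2))    # x = context*output
        attn = F.softmax(scores, dim=-1)                        # attn = exp(x_i) / sum_j exp(x_j)
        mix = torch.bmm(attn, context)                          # attn * context
        combined = torch.cat((mix, output), dim=2)
        return torch.tanh(self.linear_out(combined)), attn      # tanh(w * [mix; output] + b)
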
seq2seq/trainer/supervised_trainer.py (4 changes: 3 additions & 1 deletion)
@@ -75,7 +75,8 @@ def _train_epoches(self, data, model, n_epochs, start_epoch, start_step,
 device = None if torch.cuda.is_available() else -1
 batch_iterator = torchtext.data.BucketIterator(
     dataset=data, batch_size=self.batch_size,
-    sort=True, sort_key=lambda x: len(x.src),
+    sort=False, sort_within_batch=True,
+    sort_key=lambda x: len(x.src),
     device=device, repeat=False)

 steps_per_epoch = len(batch_iterator)
@@ -166,6 +167,7 @@ def train(self, model, data, num_epochs=5,
 resume_optim = self.optimizer.optimizer
 defaults = resume_optim.param_groups[0]
 defaults.pop('params', None)
+defaults.pop('initial_lr', None)
 self.optimizer.optimizer = resume_optim.__class__(model.parameters(), **defaults)

 start_epoch = resume_checkpoint.epoch
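
The second hunk guards against resuming with a scheduler attached. A small self-contained illustration (SGD and StepLR are arbitrary choices) of why 'initial_lr' has to be dropped before rebuilding the optimizer from a param group:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
StepLR(optim, step_size=1)         # the scheduler stamps 'initial_lr' into each param group

defaults = dict(optim.param_groups[0])
defaults.pop('params', None)
defaults.pop('initial_lr', None)   # without this, SGD(**defaults) raises a TypeError
optim = optim.__class__(model.parameters(), **defaults)
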
seq2seq/util/checkpoint.py (10 changes: 7 additions & 3 deletions)
@@ -91,9 +91,13 @@ def load(cls, path):
 Returns:
     checkpoint (Checkpoint): checkpoint object with fields copied from those stored on disk
 """
-print("Loading checkpoints from {}".format(path))
-resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME))
-model = torch.load(os.path.join(path, cls.MODEL_NAME))
+if torch.cuda.is_available():
+    resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME))
+    model = torch.load(os.path.join(path, cls.MODEL_NAME))
+else:
+    resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME), map_location=lambda storage, loc: storage)
+    model = torch.load(os.path.join(path, cls.MODEL_NAME), map_location=lambda storage, loc: storage)
+
 model.flatten_parameters() # make RNN parameters contiguous
 with open(os.path.join(path, cls.INPUT_VOCAB_FILE), 'rb') as fin:
     input_vocab = dill.load(fin)
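
A small round-trip illustration of the same map_location pattern (the file name is arbitrary, and a state_dict is saved here instead of the whole model object the repository pickles):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
torch.save(model.state_dict(), 'tmp_model.pt')

# On a CPU-only machine, map_location remaps CUDA storages to the CPU,
# mirroring the branch added above.
map_location = None if torch.cuda.is_available() else (lambda storage, loc: storage)
state = torch.load('tmp_model.pt', map_location=map_location)
model.load_state_dict(state)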
