Skip to content

Commit

Permalink
Merge pull request #249 from ranqiu92/mt_with_external_memory2
Browse files Browse the repository at this point in the history
Fix a bug in external_memory.
  • Loading branch information
lcy-seso authored Sep 14, 2017
2 parents 848bb8a + 16075ce commit 8b5c739
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
11 changes: 5 additions & 6 deletions mt_with_external_memory/external_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class ExternalMemory(object):
sequence layer has sequence length indicating the number
of memory slots, and size as memory slot size.
:type boot_layer: LayerOutput
:param initial_weight: Initializer for addressing weights.
:type initial_weight: LayerOutput
:param readonly: If true, the memory is read-only, and write function cannot
be called. Default is false.
:type readonly: bool
Expand All @@ -49,6 +51,7 @@ def __init__(self,
name,
mem_slot_size,
boot_layer,
initial_weight,
readonly=False,
enable_interpolation=True):
self.name = name
Expand All @@ -57,11 +60,7 @@ def __init__(self,
self.enable_interpolation = enable_interpolation
self.external_memory = paddle.layer.memory(
name=self.name, size=self.mem_slot_size, boot_layer=boot_layer)
# prepare a constant (zero) initializer for addressing weights
self.zero_addressing_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=boot_layer, size=1),
slope=0.0,
intercept=0.0)
self.initial_weight = initial_weight
# set memory to constant when readonly=True
if self.readonly:
self.updated_external_memory = paddle.layer.mixed(
Expand Down Expand Up @@ -111,7 +110,7 @@ def _interpolation(self, head_name, key_vector, addressing_weight):
last_addressing_weight = paddle.layer.memory(
name=self.name + "_addressing_weight_" + head_name,
size=1,
boot_layer=self.zero_addressing_init)
boot_layer=self.initial_weight)
interpolated_weight = paddle.layer.interpolation(
name=self.name + "_addressing_weight_" + head_name,
input=[addressing_weight, addressing_weight],
Expand Down
12 changes: 11 additions & 1 deletion mt_with_external_memory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,15 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
bounded_memory_perturbation
],
act=paddle.activation.Linear())
bounded_memory_weight_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=bounded_memory_init, size=1),
slope=0.0,
intercept=0.0)
unbounded_memory_init = source_context
unbounded_memory_weight_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=unbounded_memory_init, size=1),
slope=0.0,
intercept=0.0)

# prepare step function for recurrent group
def recurrent_decoder_step(cur_embedding):
Expand All @@ -136,12 +144,14 @@ def recurrent_decoder_step(cur_embedding):
name="bounded_memory",
mem_slot_size=size,
boot_layer=bounded_memory_init,
initial_weight=bounded_memory_weight_init,
readonly=False,
enable_interpolation=True)
unbounded_memory = ExternalMemory(
name="unbounded_memory",
mem_slot_size=size * 2,
boot_layer=unbounded_memory_init,
initial_weight=unbounded_memory_weight_init,
readonly=True,
enable_interpolation=False)
# write bounded memory
Expand All @@ -154,7 +164,7 @@ def recurrent_decoder_step(cur_embedding):
size=size,
act=paddle.activation.Tanh(),
bias_attr=False)
# read unbounded memory (i.e. attention mechanism)
# read unbounded memory (i.e. attention mechanism)
context = unbounded_memory.read(key_for_unbounded_memory)
# gated recurrent unit
gru_inputs = paddle.layer.fc(
Expand Down

0 comments on commit 8b5c739

Please sign in to comment.