Skip to content

Commit

Permalink
fix error in dropout of hybrid_model
Browse files Browse the repository at this point in the history
  • Loading branch information
heavyrain-lzy committed Oct 18, 2023
1 parent af0d26a commit e8247b5
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def forward(self, tgt, memory=None, tgt_mask=None, use_cache=False, cache=None):

with get_rng_state_tracker().rng_state(current_seed):
if not self.use_fused_dropout_add:
tgt = residual + self.linear2(F.gelu(self.linear1(tgt), approximate=True))
tgt = residual + self.dropout2(self.linear2(F.gelu(self.linear1(tgt), approximate=True)))
else:
tgt = self.fused_dropout_add2(self.linear2(F.gelu(self.linear1(tgt), approximate=True)), residual)

Expand Down

0 comments on commit e8247b5

Please sign in to comment.