-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathwave_transformer_3.py
166 lines (147 loc) · 5.57 KB
/
wave_transformer_3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from torch.nn import Module
from torch import Tensor
from typing import Optional, Union, Tuple, List
from torch import Tensor, zeros, cat as torch_cat
from torch.nn import Module, Linear, Softmax, Embedding
from torch.nn.functional import softmax
from modules import WaveBlock, WaveNetEncoder3
from modules.transformer import Transformer
from modules.transformer_block import TransformerBlock
from modules.positional_encoding import PositionalEncoding
import torch
import torch.nn.functional as F
from modules.decode_utils import greedy_decode, topk_sampling
from modules.beam import beam_decode
import gc
__author__ = 'An Tran'
__docformat__ = 'reStructuredText'
__all__ = ['WaveTransformer3']
class WaveTransformer3(Module):
    """
    WaveTransformer full model with only E_temp branch
    (denoted as WT_temp in paper).

    Wraps a WaveNet-style audio encoder and a Transformer decoder;
    training uses teacher forcing, inference uses greedy or beam decoding
    depending on ``beam_size``.
    """

    def __init__(self,
                 in_channels_encoder: int,
                 out_channels_encoder: List,
                 kernel_size_encoder: int,
                 dilation_rates_encoder: List,
                 last_dim_encoder: int,
                 num_layers_decoder: int,
                 num_heads_decoder: int,
                 n_features_decoder: int,
                 n_hidden_decoder: int,
                 nb_classes: int,
                 dropout_decoder: float,
                 beam_size: int,
                 ) \
            -> None:
        """WaveTransformer3 model.

        :param in_channels_encoder: Input channels.
        :type in_channels_encoder: int
        :param out_channels_encoder: Output channels for the wave blocks
        :type out_channels_encoder: List
        :param kernel_size_encoder: Kernel shape/size for the wave blocks
        :type kernel_size_encoder: int
        :param dilation_rates_encoder: Dilation factors for the wave blocks
        :type dilation_rates_encoder: List
        :param last_dim_encoder: Output channels for Linear layer
        :type last_dim_encoder: int
        :param num_layers_decoder: Number of transformer blocks
        :type num_layers_decoder: int
        :param num_heads_decoder: Number of attention heads in each MHA
        :type num_heads_decoder: int
        :param n_features_decoder: number of features for transformer
        :type n_features_decoder: int
        :param n_hidden_decoder: hidden dimension of transformer
        :type n_hidden_decoder: int
        :param nb_classes: vocabulary size
        :type nb_classes: int
        :param dropout_decoder: dropout rate in decoder
        :type dropout_decoder: float
        :param beam_size: beam size (<=1: greedy, >1: beam search)
        :type beam_size: int
        """
        super().__init__()

        # Maximum output sequence length used by greedy decoding.
        self.max_length: int = 22
        self.nb_classes: int = nb_classes
        self.beam_size: int = beam_size

        self.encoder: Module = WaveNetEncoder3(
            in_channels=in_channels_encoder,
            out_channels=out_channels_encoder,
            kernel_size=kernel_size_encoder,
            dilation_rates=dilation_rates_encoder,
            last_dim=last_dim_encoder)

        # Single decoder layer, replicated `num_layers_decoder` times
        # inside the Transformer wrapper below.
        self.sublayer_decoder: Module = TransformerBlock(
            n_features=n_features_decoder,
            n_hidden=n_hidden_decoder,
            num_heads=num_heads_decoder,
            nb_classes=nb_classes,
            dropout_p=dropout_decoder
        )

        self.decoder: Module = Transformer(
            layer=self.sublayer_decoder,
            num_layers=num_layers_decoder,
            nb_classes=nb_classes,
            n_features=n_features_decoder,
            dropout_p=dropout_decoder)

        # Token embeddings and output projection back to the vocabulary.
        self.embeddings: Embedding = Embedding(
            num_embeddings=nb_classes,
            embedding_dim=n_features_decoder)
        self.classifier: Linear = Linear(
            in_features=n_features_decoder,
            out_features=nb_classes)

    def forward(self, x: Tensor, y: Optional[Tensor] = None):
        """Dispatch to training or inference pass.

        :param x: Input features.
        :param y: Ground-truth token ids (training); ``None`` selects
            inference decoding. Defaults to ``None`` so inference callers
            may omit it; existing ``forward(x, y)`` calls are unaffected.
        :return: Class logits (training) or decoded output (inference).
        """
        if y is None:
            return self._inference(x)
        return self._training_pass(x, y)

    def _training_pass(self,
                       x: Tensor,
                       y: Tensor,
                       ) \
            -> Tensor:
        """Teacher-forced forward pass.

        :param x: Input features.
        :type x: torch.Tensor
        :param y: Target token ids, shape (batch, seq_len)
            — inferred from the permute below; TODO confirm.
        :type y: torch.Tensor
        :return: Predicted logits over the vocabulary.
        :rtype: torch.Tensor
        """
        # NOTE(review): emptying the CUDA cache and forcing GC on every
        # step is costly and usually unnecessary — consider removing.
        torch.cuda.empty_cache()
        gc.collect()

        # Time-major targets, dropping the last token (teacher forcing:
        # the decoder predicts position t from positions < t).
        y = y.permute(1, 0)[:-1]

        encoder_output: Tensor = self.encoder(x)
        # Match the decoder's expected time-major layout.
        encoder_output = encoder_output.permute(1, 0, 2)

        word_embeddings: Tensor = self.embeddings(y)
        decoder_output: Tensor = self.decoder(
            word_embeddings,
            encoder_output,
            attention_mask=None
        )
        out: Tensor = self.classifier(decoder_output)
        return out

    def _inference(self, x: Tensor):
        """Decode an output sequence for input features ``x``.

        Uses beam search when ``self.beam_size > 1``, greedy decoding
        otherwise.
        """
        # NOTE(review): same per-call cache/GC concern as in
        # _training_pass — consider removing.
        torch.cuda.empty_cache()
        gc.collect()

        if self.beam_size > 1:
            return beam_decode(x,
                               self.encoder,
                               self.decoder,
                               self.embeddings,
                               self.classifier,
                               self.beam_size,
                               1)
        return greedy_decode(x,
                             self.encoder,
                             self.decoder,
                             self.embeddings,
                             self.classifier,
                             self.max_length,
                             )
# EOF