Merge pull request #114 from tensorops/encoder_block_example
[PR] Encoder-only classifier example resolves #113
soran-ghaderi authored Jan 27, 2024
2 parents 80343bb + b601c6b commit 2fea0ac
Showing 6 changed files with 190 additions and 22 deletions.
69 changes: 69 additions & 0 deletions .gitignore
@@ -142,3 +142,72 @@ tase.toml
Makefile
source/
/docs/source/*
/docs/build/html/.buildinfo
/docs/build/html/_static/basic.css
/docs/build/html/_static/debug.css
/docs/build/html/_static/doctools.js
/docs/build/html/_static/documentation_options.js
/docs/build/html/_static/file.png
/docs/build/html/_static/styles/furo.css
/docs/build/html/_static/styles/furo.css.map
/docs/build/html/_static/scripts/furo.js
/docs/build/html/_static/scripts/furo.js.LICENSE.txt
/docs/build/html/_static/scripts/furo.js.map
/docs/build/html/_static/styles/furo-extensions.css
/docs/build/html/_static/styles/furo-extensions.css.map
/docs/build/html/_static/scripts/furo-extensions.js
/docs/build/html/genindex.html
/docs/build/html/index.html
/docs/build/html/_sources/index.rst.txt
/docs/build/html/_static/language_data.js
/docs/build/html/_static/minus.png
/docs/build/html/modules.html
/docs/build/html/_sources/modules.rst.txt
/docs/build/html/objects.inv
/docs/build/html/_static/plus.png
/docs/build/html/py-modindex.html
/docs/build/html/_static/pygments.css
/docs/build/html/search.html
/docs/build/html/searchindex.js
/docs/build/html/_static/searchtools.js
/docs/build/html/_static/skeleton.css
/docs/build/html/_static/sphinx_highlight.js
/examples/temp.py
/docs/build/html/transformerx.data_loader.html
/docs/build/html/_sources/transformerx.data_loader.rst.txt
/docs/build/html/transformerx.html
/docs/build/html/transformerx.layers.addnorm.html
/docs/build/html/_sources/transformerx.layers.addnorm.rst.txt
/docs/build/html/transformerx.layers.dot_product_attention.html
/docs/build/html/_sources/transformerx.layers.dot_product_attention.rst.txt
/docs/build/html/transformerx.layers.html
/docs/build/html/transformerx.layers.masks.global_attention_mask.html
/docs/build/html/_sources/transformerx.layers.masks.global_attention_mask.rst.txt
/docs/build/html/transformerx.layers.masks.html
/docs/build/html/_sources/transformerx.layers.masks.rst.txt
/docs/build/html/transformerx.layers.multihead_attention.html
/docs/build/html/_sources/transformerx.layers.multihead_attention.rst.txt
/docs/build/html/transformerx.layers.positional_encoding.html
/docs/build/html/_sources/transformerx.layers.positional_encoding.rst.txt
/docs/build/html/transformerx.layers.positionwise_ffn.html
/docs/build/html/_sources/transformerx.layers.positionwise_ffn.rst.txt
/docs/build/html/_sources/transformerx.layers.rst.txt
/docs/build/html/transformerx.layers.transformer_decoder.html
/docs/build/html/_sources/transformerx.layers.transformer_decoder.rst.txt
/docs/build/html/transformerx.layers.transformer_decoder_block.html
/docs/build/html/_sources/transformerx.layers.transformer_decoder_block.rst.txt
/docs/build/html/transformerx.layers.transformer_encoder.html
/docs/build/html/_sources/transformerx.layers.transformer_encoder.rst.txt
/docs/build/html/transformerx.layers.transformer_encoder_block.html
/docs/build/html/_sources/transformerx.layers.transformer_encoder_block.rst.txt
/docs/build/html/_sources/transformerx.rst.txt
/docs/build/html/transformerx.training.base.html
/docs/build/html/_sources/transformerx.training.base.rst.txt
/docs/build/html/transformerx.training.html
/docs/build/html/_sources/transformerx.training.rst.txt
/docs/build/html/transformerx.txplot.html
/docs/build/html/transformerx.txplot.plot_pe.html
/docs/build/html/_sources/transformerx.txplot.plot_pe.rst.txt
/docs/build/html/_sources/transformerx.txplot.rst.txt
/docs/build/html/transformerx.utils.html
/docs/build/html/_sources/transformerx.utils.rst.txt
85 changes: 85 additions & 0 deletions examples/encoder_only_classifier.py
@@ -0,0 +1,85 @@
import tensorflow as tf
from transformerx.layers.transformer_encoder import TransformerEncoder
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing.sequence import pad_sequences


# A tiny encoder-only classifier built on TransformerEncoder
def encoder_only_model(
    vocab_size, d_model, num_heads, n_blocks, max_seq_length, dropout_rate=0.1
):
    inputs = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32)

    # Initialize the TransformerEncoder
    transformer_encoder = TransformerEncoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        n_blocks=n_blocks,
        maxlen_position_encoding=max_seq_length,
        dropout_rate=dropout_rate,
    )

    # Apply the transformer encoder to the input sequence
    encoder_output, _ = transformer_encoder(inputs)

    # Global average pooling to obtain a fixed-size representation
    pooled_output = tf.keras.layers.GlobalAveragePooling1D()(encoder_output)

    # Dense layer for classification (modify as needed for your specific task)
    outputs = tf.keras.layers.Dense(units=1, activation="sigmoid")(pooled_output)

    # Build and compile the model
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, clipnorm=1.0)
    model.compile(
        optimizer=optimizer,
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )

    return model


def main():
    # Load the IMDb dataset: 25,000 movie reviews labeled by sentiment
    # (positive/negative). The reviews are preprocessed; each review is
    # represented as a list of word indexes (integers).

    # Customize the following hyperparameters
    vocab_size = 1000  # Replace with the actual vocabulary size.
    d_model = 8
    num_heads = 4
    n_blocks = 2
    max_seq_length = 100
    dropout_rate = 0.1

    (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=vocab_size)
    x_train = pad_sequences(x_train, maxlen=max_seq_length)
    x_train = tf.convert_to_tensor(x_train)

    x_test = pad_sequences(x_test, maxlen=max_seq_length)
    x_test = tf.convert_to_tensor(x_test)

    y_train = tf.convert_to_tensor(y_train, dtype=tf.int32)
    y_test = tf.convert_to_tensor(y_test, dtype=tf.int32)

    # Build the model
    bert_like_model = encoder_only_model(
        vocab_size=vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        n_blocks=n_blocks,
        max_seq_length=max_seq_length,
        dropout_rate=dropout_rate,
    )

    # Model summary
    bert_like_model.summary()

    # Train the model
    bert_like_model.fit(x_train, y_train, epochs=5, batch_size=32)

    # Evaluate the model on the test set
    test_loss, test_accuracy = bert_like_model.evaluate(x_test, y_test, verbose=2)
    print(f"\nTest Accuracy: {test_accuracy * 100:.2f}%")


if __name__ == "__main__":
    main()
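The example's classification head is binary: a single sigmoid unit trained with binary cross-entropy. As a hedged sketch of a multi-class variant (illustrative only, not part of this PR; `num_classes` and the stand-in input are hypothetical), the final layer and loss could be swapped like this:

```python
import tensorflow as tf

num_classes = 5  # hypothetical number of target classes
d_model = 8      # matches the d_model used in the example above

# Stand-in for the pooled encoder output produced by GlobalAveragePooling1D
pooled_output = tf.keras.layers.Input(shape=(d_model,))
outputs = tf.keras.layers.Dense(units=num_classes, activation="softmax")(pooled_output)

model = tf.keras.Model(inputs=pooled_output, outputs=outputs)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss="sparse_categorical_crossentropy",  # expects integer class labels
    metrics=["accuracy"],
)
```

Everything else in the example (padding, pooling, the training and evaluation calls) would stay the same.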
18 changes: 9 additions & 9 deletions transformerx/layers/dot_product_attention.py
@@ -90,7 +90,8 @@ def __init__(
scaled: bool = True,
kernel_initializer: str = "ones",
kernel_regularizer: str = None,
causal_mask: bool = None,
causal_mask: bool = False,
padding_mask: bool = False,
mask_type="dilated",
mask_prob=0.0,
dilation_rate=1,
@@ -104,7 +105,7 @@
self.kernel_initializer = kernel_initializer
self.kernel_regularizer = kernel_regularizer
self.causal_mask = causal_mask

self.padding_mask = padding_mask
self.mask_type = mask_type
self.mask_prob = mask_prob
self.dilation_rate = dilation_rate
@@ -138,19 +139,18 @@ def call(

# apply causal mask
if self.causal_mask:
# New version of masking
look_ahead_mask = LookAheadMask()
scores = look_ahead_mask(scores)
# todo: accept the different masks as a single Callable/str or a list of them and invoke them in a loop
# todo: for performance, first generate the boolean masks, combine them at the end, and apply them
#  once, instead of generating each mask, multiplying by a large negative constant, and adding repeatedly

# todo: now add the newly implemented padding mask here
# todo: pass the padding mask object or a string denoting it to the __init__()
if self.padding_mask:
padding_mask = PaddingMask()
scores = padding_mask(scores)

# New version of masking
look_ahead_mask = LookAheadMask()
scores = look_ahead_mask(scores)


# to be uncommented later
# apply global mask
# gmask = self.global_mask.get_mask(keys.shape)
@@ -192,7 +192,7 @@ def main():
[[[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]], dtype=tf.float32
)

# assert mask.shape == expected_mask_shape
assert expected_mask_values.shape == expected_mask_shape
# assert tf.reduce_all(tf.equal(mask, expected_mask_values))

print("Global attention mask test passed successfully!")
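The todos in the call() hunk above describe a planned refactor: generate all boolean masks first, combine them once, and apply a single additive penalty to the attention scores instead of masking repeatedly. A minimal sketch of that idea (hypothetical helper name, assuming masks use True for positions to keep; not the library's implementation):

```python
import tensorflow as tf

def apply_combined_masks(scores, boolean_masks, penalty=-1e9):
    # AND all boolean masks (True = keep) into one, then push masked
    # positions toward -inf in a single pass over the scores.
    combined = tf.reduce_all(tf.stack(boolean_masks, axis=0), axis=0)
    return tf.where(combined, scores, tf.cast(penalty, scores.dtype))

# Illustrative usage with a causal (look-ahead) mask and a padding mask
scores = tf.random.normal((1, 4, 4))  # (batch, query_len, key_len)
causal = tf.cast(tf.linalg.band_part(tf.ones((4, 4)), -1, 0), tf.bool)[None, :, :]
padding = tf.broadcast_to(tf.constant([[[True, True, True, False]]]), (1, 4, 4))
masked_scores = apply_combined_masks(scores, [causal, padding])
```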
3 changes: 3 additions & 0 deletions transformerx/layers/positionwise_ffn.py
@@ -193,6 +193,9 @@ def call(self, x, **kwargs):
x = self.dropout(x)
if self.non_linear_proj == "glu":
gate = tf.keras.activations.sigmoid(self.glu(x))
if x.shape[-1] != self.input_hidden_units:
raise Exception(
f"Please ensure that input_hidden_units ({self.input_hidden_units}) matches "
f"the last dimension of the input tensor {x.shape}"
)
return x * gate
elif self.non_linear_proj == "selu":
gate = tf.keras.activations.sigmoid(self.selu(x))
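For context, the GLU branch gates the projected input with a sigmoid, which is why the last dimension of the input has to match input_hidden_units. A standalone sketch of that gating (hypothetical sizes, not the library's exact implementation):

```python
import tensorflow as tf

input_hidden_units = 8  # hypothetical hidden size
x = tf.random.normal((2, 5, input_hidden_units))  # (batch, seq_len, hidden)

glu_proj = tf.keras.layers.Dense(input_hidden_units)  # gate projection
gate = tf.keras.activations.sigmoid(glu_proj(x))

# Elementwise gating: shapes must agree on the last axis, hence the new check
gated = x * gate
print(gated.shape)  # (2, 5, 8)
```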
8 changes: 8 additions & 0 deletions transformerx/layers/transformer_encoder.py
@@ -231,6 +231,14 @@ def call(self, queries, attention_mask=None, **kwargs):
embeddings, attention_mask=attention_mask, **kwargs
)
self.attention_weights[i] = attn_weights

# if attn_weights is not None:
# mean_attention = tf.reduce_mean(attn_weights)
# std_attention = tf.math.reduce_std(attn_weights)
# print(
# f"Mean Attention Weights: {mean_attention}, Std Attention Weights: {std_attention}"
# )

return embeddings, self.attention_weights


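The commented-out block above logs summary statistics of the attention weights; since call() returns them alongside the embeddings, callers can inspect them directly. A hedged sketch (reusing the constructor arguments from the new example and assuming each block's weights can be indexed by block number, as in the loop above):

```python
import tensorflow as tf
from transformerx.layers.transformer_encoder import TransformerEncoder

n_blocks = 2
encoder = TransformerEncoder(
    vocab_size=1000,
    d_model=8,
    num_heads=4,
    n_blocks=n_blocks,
    maxlen_position_encoding=100,
)
tokens = tf.random.uniform((2, 100), maxval=1000, dtype=tf.int32)  # dummy token ids
embeddings, attention_weights = encoder(tokens)

for i in range(n_blocks):
    weights = attention_weights[i]
    if weights is not None:
        mean_attention = float(tf.reduce_mean(weights))
        std_attention = float(tf.math.reduce_std(weights))
        print(f"block {i}: mean={mean_attention:.4f}, std={std_attention:.4f}")
```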
29 changes: 16 additions & 13 deletions transformerx/layers/transformer_encoder_block.py
@@ -222,30 +222,33 @@ def __init__(
Callable
] = None, # Learning rate schedule function
use_bias: bool = False, # Whether to include bias terms in the attention computation
contextualized_embeddings: bool = None, # incorporate pre-trained language models such as BERT or GPT-2 into the model (feedforward networks)
contextualized_embeddings: bool = None,
# incorporate pre-trained language models such as BERT or GPT-2 into the model (feedforward networks)
name: str = "transformer_encoder_block",
dtype: Optional[tf.dtypes.DType] = None,
**kwargs,
):
super(TransformerEncoderBlock, self).__init__(name=name, dtype=dtype, **kwargs)
assert isinstance(d_model, int) and d_model > 0, "Invalid d_model: {}".format(
d_model
)
assert (
isinstance(d_model, int) and d_model > 0
), f"Invalid d_model: {d_model}. It must be integer and greater than 0."
assert (
isinstance(input_hidden_units_ffn, int) and input_hidden_units_ffn > 0
), "Invalid ffn_num_hiddens: {}".format(input_hidden_units_ffn)
), f"Invalid ffn_num_hiddens: {input_hidden_units_ffn}. It must be integer and greater than 0."
assert (
isinstance(num_heads, int) and num_heads > 0 and d_model % num_heads == 0
), "Invalid num_heads: {}".format(num_heads)
), f"Invalid num_heads: {num_heads}. The d_model {d_model} must be divisible by num_heads"
assert (
isinstance(dropout_rate, float) and 0.0 <= dropout_rate <= 1.0
), "Invalid dropout rate: {}".format(dropout_rate)
), f"Invalid dropout rate: {dropout_rate}. It must be between 0.0 and 1.0."
assert norm_type in [
"layer",
"batch",
"instance",
], "Invalid norm_type: {}".format(norm_type)
assert isinstance(use_bias, bool), "Invalid bias: {}".format(use_bias)
], f"Invalid norm_type: {norm_type}. Valid types: 'layer', 'batch', and 'instance'."
assert isinstance(
use_bias, bool
), f"Invalid bias: {use_bias}. It must be a boolean."
if residual_connections is not None:
assert (
len(residual_connections) == 2
@@ -255,7 +258,7 @@ def __init__(
if clip_norm is not None:
assert (
isinstance(clip_norm, float) and clip_norm > 0.0
), "Invalid clip_norm: {}".format(clip_norm)
), f"Invalid clip_norm: {clip_norm}"
if kernel_initializer is not None:
assert callable(
kernel_initializer
@@ -324,9 +327,9 @@ def call(self, queries, attention_mask=None, **kwargs):
assert (
len(queries.shape) == 3
), f"Input tensor should have rank 3, got {len(queries.shape)}, {queries.shape}"
assert (
queries.shape[-1] == self.d_model
), "Last dimension of input tensor should be equal to d_model"
# assert (
# queries.shape[-1] == self.d_model
# ), f"Last dimension of input tensor {queries.shape} should be equal to d_model {self.d_model}"
# if attention_mask is not None:
# attention_mask = tf.cast(attention_mask, tf.int32)
# assert (
