-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.py
66 lines (49 loc) · 1.85 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Building GPT architecture
"""
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
import maximal
from maximal.layers import PositionalEmbedding, GPTLayer
from config import config
def build_model() -> tf.keras.models.Model:
    """
    Assembles the GPT network with Maximal layers on top of TensorFlow/Keras.

    Every hyperparameter is read from `config`: INPUT_LENGTH, VOCAB_SIZE,
    DEPTH, HEADS, FF_NODES and N_LAYERS.

    Args: / (just needs config params)

    Returns: GPT model (tf.keras.models.Model). Its input is an int32 tensor
        with INPUT_LENGTH entries per sample; its output is produced by a
        Dense layer with VOCAB_SIZE units and no activation specified.
    """
    # Create every layer first, in the order in which it will be applied
    token_ids = Input(shape=(config.INPUT_LENGTH,), dtype=tf.int32)
    embed = PositionalEmbedding(
        config.INPUT_LENGTH, config.VOCAB_SIZE, config.DEPTH
    )
    gpt_stack = []
    for _ in range(config.N_LAYERS):
        gpt_stack.append(
            GPTLayer(
                depth=config.DEPTH,
                heads=config.HEADS,
                ff_nodes=config.FF_NODES,
            )
        )
    vocab_projection = Dense(config.VOCAB_SIZE)

    # Wire the layers together: embedding -> GPT layers -> vocabulary projection
    hidden = embed(token_ids)
    for gpt_block in gpt_stack:
        hidden = gpt_block(hidden)
    logits = vocab_projection(hidden)

    return Model(inputs=token_ids, outputs=logits)
def load_or_build_model(verbose: bool = False) -> tf.keras.models.Model:
    """
    Returns a GPT model, reusing a previously saved one when it exists.

    Looks for an entry named exactly `config.MODEL_NAME` inside the
    `saved_models` folder of the current working directory. If present, loads
    the existing model (to train it further). If not, it builds a new one with
    `build_model()`. A missing `saved_models` folder is treated as "nothing
    saved yet" instead of raising.

    Args:
        verbose (bool): print the model summary or not - defaults to False

    Returns: the loaded or newly built model (tf.keras.models.Model)
    """
    models_dir = os.path.join(os.getcwd(), "saved_models")
    try:
        filenames = os.listdir(models_dir)
    except FileNotFoundError:
        # First run: the folder has not been created yet, so there is nothing
        # to load and we fall through to building a fresh model.
        filenames = []

    # NOTE(review): the lookup and the load both use MODEL_NAME as-is, while
    # the messages below append ".h5" - confirm which name the training code
    # actually saves under.
    if config.MODEL_NAME in filenames:
        print(f"Loading existing model: {config.MODEL_NAME}.h5")
        gpt = tf.keras.models.load_model(
            os.path.join(models_dir, config.MODEL_NAME)
        )
    else:
        print(f"Creating a new model: {config.MODEL_NAME}.h5")
        gpt = build_model()

    if verbose:
        # summary() prints the table itself and returns None; wrapping it in
        # print() would emit a stray "None" line after the table.
        gpt.summary()

    return gpt