Mini GPT#

This notebook builds a mini GPT-style language model trained on the IMDB movie-review dataset.

import re, html, random, numpy as np, tensorflow as tf, keras
import keras_nlp as knlp
from pathlib import Path
#tf.keras.utils.set_random_seed(42)

Create the dataset#

  1. load_imdb(): Downloads the IMDB dataset once; each review is decoded from bytes and HTML entities are resolved.

  2. basic_clean(): Removes HTML tags and any characters outside a small whitelist (letters, digits, spaces, and basic punctuation), lowercases the text, collapses runs of whitespace into single spaces, and drops empty lines.

  3. Shuffle the cleaned list so that the train/validation split is random.

  4. compute_word_piece_vocabulary(): Learns a WordPiece vocabulary from the cleaned text, growing subword units from characters until the requested vocabulary size is reached.

  5. WordPieceTokenizer(...): Uses the vocabulary to convert text to fixed-length ID sequences. The extra + 1 token reserves room for the “next token” label.

  6. make_ds(): Builds a TensorFlow dataset from the cleaned text. It tokenizes each review, discards sequences shorter than two tokens, splits each sequence into input and target (shifted by one position), shuffles, pads to a uniform length, batches, and prefetches for efficient training.

  7. Train/val split: 90% of the shuffled corpus feeds train_ds; the remaining 10% feeds val_ds.

def load_imdb():
    import tensorflow_datasets as tfds
    raw = tfds.load("imdb_reviews", split="train+test", shuffle_files=True)
    text = [html.unescape(x["text"].numpy().decode()) for x in raw]
    return text

def basic_clean(lines):
    out = []
    pattern = re.compile(r"<[^>]*>|[^A-Za-z0-9 ,.!?'\n]")
    for ln in lines:
        ln = pattern.sub(" ", ln).lower()
        ln = re.sub(r"\s+", " ", ln).strip()
        if ln: 
            out.append(ln)
    return out

raw_text = basic_clean(load_imdb())
random.shuffle(raw_text)  
#raw_text   = raw_text[:200_000] # trim
print(f"Corpus: {len(raw_text):,} lines")


# ---------------------------------------------------------------------
# 1.  Tokeniser  
# ---------------------------------------------------------------------
vocab_size = 8_000
SEQ_LEN    = 256      

text_ds = tf.data.Dataset.from_tensor_slices(raw_text)

vocab = knlp.tokenizers.compute_word_piece_vocabulary(
    data=text_ds,
    vocabulary_size=vocab_size,
    lowercase=True,
)

tokenizer = knlp.tokenizers.WordPieceTokenizer(
    vocabulary=vocab,
    sequence_length=SEQ_LEN + 1,     
    lowercase=True,
    oov_token="[UNK]",
)

def make_ds(texts, batch=64):     
    toks = tokenizer(texts)
    toks = [t[:SEQ_LEN+1] for t in toks if len(t) > 1]

    ds = tf.data.Dataset.from_tensor_slices(toks)

    def xy(tokens):
        return {"tokens": tokens[:-1]}, tokens[1:]

    return (ds.map(xy, num_parallel_calls=tf.data.AUTOTUNE)
              .shuffle(50_000)
              .padded_batch(batch, drop_remainder=True)
              .prefetch(tf.data.AUTOTUNE))

train_split = int(0.9*len(raw_text))
train_ds = make_ds(raw_text[:train_split])
val_ds   = make_ds(raw_text[train_split:])
Corpus: 50,000 lines
for txt in [
    "movie review", 
    "Transformer", 
    "xqzj", 
    "abcdefghijklmnopqrstuvwxyz"]:

    ids = tokenizer(txt)
    subtokens = [tokenizer.vocabulary[i] 
                 for i in ids.numpy() if i != 0]
    print(f"{txt:15} → {subtokens}")
movie review    → ['movie', 'review']
Transformer     → ['t', '##ran', '##s', '##form', '##er']
xqzj            → ['x', '##q', '##z', '##j']
abcdefghijklmnopqrstuvwxyz → ['abc', '##de', '##f', '##gh', '##i', '##j', '##k', '##lm', '##no', '##p', '##q', '##rst', '##u', '##v', '##w', '##x', '##y', '##z']
print("Vocab size:", len(vocab))                 # at most 8,000; may come out smaller
ids = tokenizer("This movie was great!")
print("Token IDs :", ids)
print("Back to txt:", tokenizer.detokenize(ids))
Vocab size: 7905
Token IDs : tf.Tensor(
[ 53  57  55 125   5   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0], shape=(257,), dtype=int32)
Back to txt: this movie was great ! [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

Define the model#

  1. Input tokens

    • Each IMDB review is already tokenised into WordPiece IDs, giving an input tensor of shape (batch, T) with T = SEQ_LEN.

  2. Positional Embedding layer

    • Token embeddings learn a d_model-dimensional vector for every vocabulary entry.

    • Position embeddings learn a d_model-dimensional vector for positions 0 … SEQ_LEN − 1.

    • The layer adds the two vectors so the model receives both the token identity (what) and its location (where).

  3. Causal mask

    • A lower-triangular (T, T) mask ensures that, at time step t, the model attends only to positions ≤ t.

    • This enforces the left-to-right, next-token-prediction objective.

  4. GPTBlock  — repeated DEPTH = 8 times
    Each block contains two residual sub-layers, both preceded by Layer Normalisation and followed by dropout p = 0.1.

    | Sub-layer | Purpose | Key details |
    |---|---|---|
    | Multi-Head Self-Attention | Lets each token look back at earlier tokens and weigh their relevance. | HEADS = 8, key/query size d_model / HEADS = 32, masked attention, dropout 0.1. |
    | Feed-Forward Network | Refines each token representation independently. | Two dense layers: width expands to 4 × d_model with GELU activation, then projects back to d_model; dropout inside and after. |

  5. Language-model head

    • A final dense layer of size vocab_size projects each d_model vector to logits over the vocabulary.

    • Training uses Sparse Categorical Cross-Entropy to predict the next token at every position.

  6. Hyper-parameters

    • d_model = 256 keeps memory ≈ 6 GB (FP16).

    • DEPTH = 8 gives enough depth without long runtimes.

    • HEADS = 8: each head processes 32-dimensional keys and queries.

    • Dropout p = 0.1 mitigates over-fitting on the small IMDB corpus.

Together these pieces implement the core ideas behind GPT-style language models: position-aware token embeddings, masked self-attention for autoregression, and stacked attention/MLP blocks that build hierarchical representations while preserving gradient flow through residual connections.
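As a sanity check on the sizes above, the parameter counts that model.summary() reports further down can be reproduced by hand from the hyper-parameters alone (vocab_size = 8,000, d_model = 256, SEQ_LEN = 256, DEPTH = 8); this is a plain-Python back-of-envelope sketch, not part of the model code:

```python
# Back-of-envelope parameter count for the mini GPT defined in this notebook.
vocab, d_model, seq_len, depth = 8_000, 256, 256, 8

embed = vocab * d_model + seq_len * d_model      # token + position embedding tables
mha   = 4 * (d_model * d_model + d_model)        # Q, K, V, and output projections (kernel + bias)
ffn   = (d_model * 4 * d_model + 4 * d_model) + (4 * d_model * d_model + d_model)
norms = 2 * 2 * d_model                          # two LayerNorms, each with gamma + beta
block = mha + ffn + norms
head  = d_model * vocab + vocab                  # final Dense to logits

total = embed + depth * block + head
print(f"per block: {block:,}")                   # 789,760
print(f"total    : {total:,}")                   # 10,487,616
```

Both figures match the summary table exactly, which is a useful check that the blocks are wired as intended.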

class PositionalEmbedding(keras.layers.Layer):
    def __init__(self, vocab, d_model, max_len):
        super().__init__()
        self.tok = keras.layers.Embedding(vocab, d_model)
        self.pos = keras.layers.Embedding(max_len, d_model)

    def call(self, x):
        idx = tf.range(tf.shape(x)[-1])[None]
        return self.tok(x) + self.pos(idx)

## A version for debugging and illustration
class PositionalEmbeddingDebug(PositionalEmbedding):
    def call(self, x, return_parts=False):
        idx = tf.range(tf.shape(x)[-1])[None]
        tok = self.tok(x)
        pos = self.pos(idx)
        return (tok, pos) if return_parts else tok + pos
# --- Toy demo wrapped in its own scope -------------------------------
def show_positional_demo():

    vocab_size, d_model, max_len = 6, 4, 10
    layer  = PositionalEmbeddingDebug(vocab_size, d_model, max_len)

    tokens = tf.constant([[1, 4, 1, 3, 0]])      # shape (batch=1, T=5)

    print("Apply PositionalEmbedding: \n", layer(tokens))  

    # 1) the two embedding components separately
    token_embs, pos_embs = layer(tokens, return_parts=True)   # each (1, 5, 4)

    # 2) combined input
    combined = token_embs + pos_embs             # identical to layer(tokens)

    # Nicely formatted printout
    print("Token IDs:        ", tokens.numpy())
    print("\nE_token (word vectors):")
    print(token_embs.numpy())
    print("\nE_pos (position vectors):")
    print(pos_embs.numpy())
    print("\nSum fed to Transformer:")
    print(combined.numpy())

show_positional_demo()
Apply PositionalEmbedding: 
 tf.Tensor(
[[[ 0.09493951 -0.09283178  0.01114397 -0.03076396]
  [ 0.03453743 -0.00263811 -0.02565154 -0.01494424]
  [ 0.0019802  -0.09410468  0.00568523  0.00466434]
  [ 0.00854168  0.06907481 -0.00520728  0.04899843]
  [-0.02489221 -0.01527636 -0.03607261 -0.09318761]]], shape=(1, 5, 4), dtype=float32)
Token IDs:         [[1 4 1 3 0]]

E_token (word vectors):
[[[ 0.04694815 -0.04938451  0.03270758  0.00196652]
  [-0.00088028  0.01174162 -0.03452749  0.01346072]
  [ 0.04694815 -0.04938451  0.03270758  0.00196652]
  [-0.03093035  0.0200041  -0.00651758  0.03676522]
  [-0.0347316   0.00906425 -0.02840428 -0.04868952]]]

E_pos (position vectors):
[[[ 0.04799136 -0.04344727 -0.02156361 -0.03273048]
  [ 0.03541771 -0.01437973  0.00887595 -0.02840496]
  [-0.04496795 -0.04472017 -0.02702235  0.00269781]
  [ 0.03947203  0.0490707   0.0013103   0.01223321]
  [ 0.00983939 -0.02434061 -0.00766833 -0.04449809]]]

Sum fed to Transformer:
[[[ 0.09493951 -0.09283178  0.01114397 -0.03076396]
  [ 0.03453743 -0.00263811 -0.02565154 -0.01494424]
  [ 0.0019802  -0.09410468  0.00568523  0.00466434]
  [ 0.00854168  0.06907481 -0.00520728  0.04899843]
  [-0.02489221 -0.01527636 -0.03607261 -0.09318761]]]
from keras import layers

def causal_mask(n):
    i = tf.range(n)[:, None]
    j = tf.range(n)[None, :]
    m = i >= j
    return tf.cast(m, tf.int32)[None, None]


class GPTBlock(layers.Layer):
    def __init__(self, d_model, heads, p_drop=0.1):
        super().__init__()
        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn  = layers.MultiHeadAttention(
            num_heads=heads,
            key_dim=d_model // heads,
            dropout=p_drop       
        )
        self.drop1 = layers.Dropout(p_drop) 

        self.norm2 = layers.LayerNormalization(epsilon=1e-5)
        self.ff    = keras.Sequential([
            layers.Dense(4*d_model, activation="gelu"),
            layers.Dropout(p_drop),        
            layers.Dense(d_model),
        ])
        self.drop2 = layers.Dropout(p_drop)   

    def call(self, x, training=False):
        mask = causal_mask(tf.shape(x)[1])

        # attention branch
        h = self.norm1(x)                 # compute the pre-norm once, reuse for Q/K/V
        attn_out = self.attn(
            h, h,
            attention_mask=mask,
            training=training,
        )
        x = x + self.drop1(attn_out, training=training)

        # feed-forward branch
        ffn_out = self.ff(self.norm2(x), training=training)
        x = x + self.drop2(ffn_out, training=training)
        return x

# ---------------------------------------------------------------------
# 2.  Model  
# ---------------------------------------------------------------------
d_model = 256          # keep width so RAM stays ~6 GB fp16
DEPTH   = 8            
HEADS   = 8            # d_model // HEADS = 32 dims per head

inp = keras.Input(shape=(SEQ_LEN,), dtype="int32", name="tokens")
x   = PositionalEmbedding(vocab_size, d_model, SEQ_LEN)(inp)
for _ in range(DEPTH):
    x = GPTBlock(d_model, HEADS)(x)
logits = keras.layers.Dense(vocab_size)(x)

model = keras.Model(inp, logits, name="mini_gpt")
model.summary()
Model: "mini_gpt"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ tokens (InputLayer)             │ (None, 256)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ positional_embedding            │ (None, 256, 256)       │     2,113,536 │
│ (PositionalEmbedding)           │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ gpt_block (GPTBlock)            │ (None, 256, 256)       │       789,760 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ gpt_block_1 (GPTBlock)          │ (None, 256, 256)       │       789,760 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ gpt_block_2 (GPTBlock)          │ (None, 256, 256)       │       789,760 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ gpt_block_3 (GPTBlock)          │ (None, 256, 256)       │       789,760 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ gpt_block_4 (GPTBlock)          │ (None, 256, 256)       │       789,760 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ gpt_block_5 (GPTBlock)          │ (None, 256, 256)       │       789,760 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ gpt_block_6 (GPTBlock)          │ (None, 256, 256)       │       789,760 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ gpt_block_7 (GPTBlock)          │ (None, 256, 256)       │       789,760 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_16 (Dense)                │ (None, 256, 8000)      │     2,056,000 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 10,487,616 (40.01 MB)
 Trainable params: 10,487,616 (40.01 MB)
 Non-trainable params: 0 (0.00 B)
mask = causal_mask(10)   # shape (1, 1, 10, 10)
mask[0, 0]               # drop the leading singleton dims for display
<tf.Tensor: shape=(10, 10), dtype=int32, numpy=
array([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)>
def show_gpt_layer_demo():
    import numpy as np

    #np.random.seed(0)

    batch, T, d_model = 1, 4, 6          # 4-token sentence, 6-dim features
    keep_prob_attn    = 0.9              # dropout inside attention weights
    keep_prob_residual= 0.9              # dropout on attn_out
    scale             = d_model ** -0.5  # 1/√d  for dot-product

    # 1) fake input: token representations coming from previous layer
    x = np.random.randn(batch, T, d_model).round(3)
    print("Input x:\n", x, "\n")

    # 2) LayerNorm (per token)
    mu  = x.mean(-1, keepdims=True)
    var = x.var (-1, keepdims=True)
    x_norm = (x - mu) / np.sqrt(var + 1e-5)
    print("LayerNorm x:\n", x_norm.round(3), "\n")

    # 3) linear projections Q, K, V (random weights for the demo)
    W_q, W_k, W_v = [np.random.randn(d_model, d_model) for _ in range(3)]
    Q = x_norm @ W_q
    K = x_norm @ W_k
    V = x_norm @ W_v

    # 4) causal mask: forbid looking right
    mask = np.tril(np.ones((T, T), dtype=bool))   # (T,T) lower-triangle
    scores = (Q @ K.transpose(0,2,1)) * scale     # (B,T,T)
    scores[:, ~mask]  = -1e9                          # set future positions to −∞

    # 5) soft-max to get attention weights
    def softmax(a, axis=-1):
        a_exp = np.exp(a - a.max(axis=axis, keepdims=True))
        return a_exp / a_exp.sum(axis=axis, keepdims=True)
    weights = softmax(scores, axis=-1)

    # 6) **Attention-weight dropout** (drop some links)
    drop_mask_attn = (np.random.rand(*weights.shape) < keep_prob_attn)
    weights_drop   = weights * drop_mask_attn / keep_prob_attn   # rescale to keep expectation

    # 7) weighted sum → attn_out
    attn_out = weights_drop @ V
    print("attn_out BEFORE residual dropout:\n", attn_out.round(3), "\n")

    # 8) **Residual-dropout** on attn_out
    drop_mask_res = (np.random.rand(*attn_out.shape) < keep_prob_residual)
    attn_out_rd   = attn_out * drop_mask_res / keep_prob_residual

    # 9) residual add
    y = x + attn_out_rd
    print("attn_out AFTER residual dropout:\n", attn_out_rd.round(3), "\n")
    print("Result y = x + dropout(attn_out):\n", y.round(3))

show_gpt_layer_demo()
Input x:
 [[[ 1.266  0.526 -0.397 -0.039  0.399 -1.112]
  [ 0.009 -0.185 -0.702  0.195  1.029 -0.787]
  [ 0.544 -0.053 -0.347  0.102 -1.157  0.522]
  [ 0.066  0.275 -0.756  0.473  0.309  0.013]]] 

LayerNorm x:
 [[[ 1.547  0.559 -0.673 -0.195  0.39  -1.628]
  [ 0.136 -0.184 -1.035  0.442  1.816 -1.175]
  [ 1.05   0.02  -0.487  0.288 -1.884  1.013]
  [ 0.007  0.533 -2.062  1.031  0.618 -0.127]]] 

attn_out BEFORE residual dropout:
 [[[ 0.4   -0.294 -0.728 -2.582  1.524  1.285]
  [ 0.383 -0.284 -0.788 -2.569  1.508  1.252]
  [ 3.497  0.889 -1.177  3.68  -1.785  0.524]
  [ 0.346 -0.161 -1.306 -2.412  1.388  1.15 ]]] 

attn_out AFTER residual dropout:
 [[[ 0.444 -0.327 -0.808 -2.868  1.693  1.428]
  [ 0.425 -0.316 -0.875 -2.855  1.675  1.391]
  [ 3.886  0.988 -1.308  4.089 -1.984  0.583]
  [ 0.384 -0.179 -1.452 -2.68   1.542  1.277]]] 

Result y = x + dropout(attn_out):
 [[[ 1.71   0.199 -1.205 -2.907  2.092  0.316]
  [ 0.434 -0.501 -1.577 -2.66   2.704  0.604]
  [ 4.43   0.935 -1.655  4.191 -3.141  1.105]
  [ 0.45   0.096 -2.208 -2.207  1.851  1.29 ]]]

Compile the model#

We use the AdamW optimizer with weight decay 1e-4 and a cosine-decay-with-restarts learning-rate schedule (initial rate 3e-4, two-epoch cycles). The model is trained for up to 8 epochs with a batch size of 64, with early stopping on the validation loss.

The training loss is monitored using the sparse categorical cross-entropy loss function:

| Position t | True next token | Model’s soft-max probability | Token-level loss |
|---|---|---|---|
| 0 | "I" | 0.40 | \(-\log 0.40 \approx 0.92\) |
| 1 | "love" | 0.05 | \(-\log 0.05 \approx 2.99\) |
| 2 | "this" | 0.60 | \(-\log 0.60 \approx 0.51\) |
| 3 | "movie" | 0.10 | \(-\log 0.10 \approx 2.30\) |

Average loss:

\[ \frac{0.92 + 2.99 + 0.51 + 2.30}{4} \;\approx\; 1.68 \]

Keras reports this 1.68 during training; the optimiser tries to push it lower by assigning higher probability to the correct next token at each position.
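The arithmetic in the worked example above can be checked in a couple of lines (the probabilities are the illustrative values from the table, not model outputs):

```python
import numpy as np

# Each token-level loss is -log(p) of the true next token's probability;
# the reported loss is their average over positions.
probs  = [0.40, 0.05, 0.60, 0.10]
losses = -np.log(probs)          # ≈ [0.92, 3.00, 0.51, 2.30]
avg    = losses.mean()
print(round(avg, 2))             # 1.68
```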

ModelCheckpoint writes mini_gpt.keras whenever the validation loss improves, and early stopping restores the best weights if validation loss stops improving for two epochs.

# Cosine-decay restart LR schedule (two-epoch cycles)
steps_per_epoch = len(train_ds)
lr_sched = keras.optimizers.schedules.CosineDecayRestarts(
    initial_learning_rate=3e-4,
    first_decay_steps=steps_per_epoch*2,
)
opt = keras.optimizers.AdamW(lr_sched, weight_decay=1e-4)

loss  = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(opt, loss, metrics=["accuracy"])

Train the model#

EPOCHS = 8
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[
        keras.callbacks.EarlyStopping(patience=2,
                                      restore_best_weights=True),
        keras.callbacks.ModelCheckpoint("mini_gpt.keras",
                                        save_best_only=True),
    ],
)
Epoch 1/8
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/keras/src/models/functional.py:238: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: tokens
Received: inputs=['Tensor(shape=(64, 256))']
  warnings.warn(msg)
703/703 ━━━━━━━━━━━━━━━━━━━━ 3690s 5s/step - accuracy: 0.2922 - loss: 4.7974 - val_accuracy: 0.3369 - val_loss: 4.0557
Epoch 2/8
703/703 ━━━━━━━━━━━━━━━━━━━━ 3656s 5s/step - accuracy: 0.3345 - loss: 4.0513 - val_accuracy: 0.3444 - val_loss: 3.9677
Epoch 3/8
703/703 ━━━━━━━━━━━━━━━━━━━━ 3659s 5s/step - accuracy: 0.3415 - loss: 3.9703 - val_accuracy: 0.3625 - val_loss: 3.7509
Epoch 4/8
703/703 ━━━━━━━━━━━━━━━━━━━━ 3704s 5s/step - accuracy: 0.3575 - loss: 3.7548 - val_accuracy: 0.3740 - val_loss: 3.6263
Epoch 5/8
703/703 ━━━━━━━━━━━━━━━━━━━━ 3647s 5s/step - accuracy: 0.3683 - loss: 3.6355 - val_accuracy: 0.3796 - val_loss: 3.5723
Epoch 6/8
703/703 ━━━━━━━━━━━━━━━━━━━━ 3644s 5s/step - accuracy: 0.3735 - loss: 3.5783 - val_accuracy: 0.3809 - val_loss: 3.5615
Epoch 7/8
703/703 ━━━━━━━━━━━━━━━━━━━━ 3557s 5s/step - accuracy: 0.3708 - loss: 3.6077 - val_accuracy: 0.3848 - val_loss: 3.5232
Epoch 8/8
703/703 ━━━━━━━━━━━━━━━━━━━━ 3551s 5s/step - accuracy: 0.3792 - loss: 3.5232 - val_accuracy: 0.3918 - val_loss: 3.4593

Run the model#

PAD  = tokenizer.token_to_id("[PAD]")
UNK  = tokenizer.token_to_id("[UNK]")          # handy for debugging
MASK = tokenizer.token_to_id("[MASK]")         # not used here but nice to have
vocab_size = tokenizer.vocabulary_size()

print("PAD id :", PAD,  "UNK id :", UNK, "vocab :", vocab_size)
PAD id : 0 UNK id : 3 vocab : 7905
import numpy as np, tensorflow as tf

# ------------------------------------------------------------------
# 0.  Constants  (define once, reuse everywhere)
# ------------------------------------------------------------------
SEQ_LEN    = 256
VOCAB_DIM  = 8_000                          # logit dimension of the model's Dense head
REAL_VOCAB = tokenizer.vocabulary_size()    # actual vocabulary size (7,905 here)
PAD_ID     = tokenizer.token_to_id("[PAD]")
UNK_ID     = tokenizer.token_to_id("[UNK]")

def top_k_logits(logits, k=40):
    vals, _   = tf.math.top_k(logits, k=k)
    min_vals  = vals[..., -1, None]
    return tf.where(logits < min_vals, tf.float32.min, logits)

import re, textwrap

def tidy(text, width = 80):
    """Fix spacing, punctuation, sentence caps, and wrap to `width` columns."""
    
    # remove space before punctuation  –>  "word , ..."  → "word, ..."
    text = re.sub(r"\s+([.,!?;:])", r"\1", text)

    # ensure single space after punctuation
    text = re.sub(r"([.,!?;:])([^\s])", r"\1 \2", text)

    # fix contractions
    text = re.sub(r"\b(\w+)\s+'\s*([sSdDmMnt]|re|ve|ll)\b", r"\1'\2", text)

    # capitalise " i " → " I "  (pronoun) and sentence starts
    def cap_sentence(m):
        return m.group(1) + m.group(2).upper()
    text = re.sub(r"(^|[.!?]\s+)([a-z])", cap_sentence, text)
    text = re.sub(r"\bi\b", "I", text)

    # collapse multiple spaces, strip ends
    text = re.sub(r"\s{2,}", " ", text).strip()

    # wrap into neat paragraphs
    return textwrap.fill(text, width)
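The contraction-repair rule inside tidy() is the least obvious regex, so here is a standalone sketch exercising just that rule (same pattern, isolated into a small helper for illustration):

```python
import re

# Rejoin contractions the tokenizer split apart, e.g. "it 's" -> "it's".
def fix_contractions(s):
    return re.sub(r"\b(\w+)\s+'\s*([sSdDmMnt]|re|ve|ll)\b", r"\1'\2", s)

print(fix_contractions("it 's a film they 're proud of"))
# it's a film they're proud of
```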

def sample(prompt, max_new=80, temperature=1.0, k=40):
    ids = tokenizer(prompt).numpy().tolist()
    ids = ids if isinstance(ids[0], int) else ids[0]

    while ids and ids[-1] == PAD_ID:
        ids.pop()

    for _ in range(max_new):
        ctx = ids[-SEQ_LEN:]
        x   = np.array(ctx + [PAD_ID]*(SEQ_LEN-len(ctx)))[None]

        logits  = model.predict(x, verbose=0)[0, len(ctx)-1]
        if temperature == 0.0:                 # ★ greedy decode
            next_id = int(np.argmax(logits))
        else:                                  # ★ stochastic decode
            logits  = top_k_logits(logits / temperature, k).numpy()
            next_id = np.random.choice(VOCAB_DIM, p=tf.nn.softmax(logits).numpy())

        if next_id >= REAL_VOCAB:
            next_id = UNK_ID

        ids.append(next_id)
        if next_id == PAD_ID:
            break

    return tidy(tokenizer.detokenize([ids])[0].strip())

print(sample("the movie was"))
The movie was a wonderful mixture of characters, and most of the actors really
were believable. The script is not bad but a great film. I would say that for
another person involved should make this movie. Well then we should see why the
actors are such a truly bad movie. But because I'd put the soundtrack by someone
who actually made it very believable. I think it's a little too. [PAD]
print(sample("the movie was ", max_new=200, temperature=1.0, k=40))
The movie was great. If this movie was produced, a copy of the original vhs
version, the dvd will appeal to your dvd collection and it. It may be very
effective, but as I can say, the whole thing at the end was that it was so real
for me was too much to find. .. But a real piece of work of view. .. And if you
are willing to enjoy the whole movie to go further then this should be done
better. It would have been true to watch the original, it will be true to the
time and so much more accurate that that the movie has been made by a scummer.
This movie is a shame for the director. It has my heart that shows an excellent
book of the movie. . So I could not recommend this film to everyone, but I wish
to watch it. [PAD]