(rnn)=

Recurrent Neural Networks (RNNs)#

A Recurrent Neural Network (RNN) is a neural architecture designed for sequential data. Unlike feedforward networks, it contains recurrent connections in the hidden layer: the hidden state \(h_t\) at time \(t\) depends on both the current input \(x_t\) and the previous hidden state \(h_{t-1}\). Because the same weights are reused at every step, the model can, in principle, handle sequences of arbitrary length without increasing the parameter count.

Components of an RNN#

  1. Input layer – receives a sequence of vectors \((x_1,\dots,x_T)\), e.g., word embeddings.

  2. Recurrent (hidden) layer(s) – update rule below; weights are shared across all \(t\).

  3. Output layer – either emits a prediction \(y_t\) at each step, or reads the final state \(h_T\) to classify the whole sequence.

  4. Recurrent connection – the link \(h_{t-1}\rightarrow h_t\) that carries memory forward.

Update Rule#

\[ h_t = \sigma(W_{xh} x_t + W_{hh} h_{t-1} + b_h) \]
\[ y_t = g(W_{hy} h_t + b_y) \]

where:

  • \(x_t\in\mathbb{R}^{d_{\text{in}}}\) – input at time \(t\).

  • \(h_t\in\mathbb{R}^{d_h}\) – hidden state (memory).

  • \(y_t\) – output; in sequence classification we often emit a single output \(y = y_T\) at the final step.

  • \(W_{xh}\), \(W_{hh}\), \(W_{hy}\) and biases \(b_h\), \(b_y\) – shared, learnable parameters.

  • \(\sigma\) – typically \(\tanh\) or \(\text{ReLU}\).

  • \(g\) – task‑dependent: softmax (language modelling), sigmoid (binary label), or identity (feature for another layer).
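As a concrete illustration of the two equations above, here is a minimal NumPy sketch of the forward pass with randomly initialised, untrained weights; the names `W_xh`, `W_hh`, `b_h` mirror the symbols in the update rule, and the dimensions are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_h, T = 3, 5, 7                  # input dim, hidden dim, sequence length

# Shared, learnable parameters (here: random placeholders, not trained)
W_xh = rng.normal(scale=0.1, size=(d_h, d_in))
W_hh = rng.normal(scale=0.1, size=(d_h, d_h))
b_h  = np.zeros(d_h)

def rnn_forward(xs, h0=None):
    """Run the recurrence h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)."""
    h = np.zeros(d_h) if h0 is None else h0
    states = []
    for x_t in xs:                      # the SAME weights are reused at every step
        h = np.tanh(W_xh @ x_t + W_hh @ h + b_h)
        states.append(h)
    return np.stack(states)             # all hidden states, shape (T, d_h)

xs = rng.normal(size=(T, d_in))
H  = rnn_forward(xs)
print(H.shape)                          # (7, 5)
```

A per-step output \(y_t\) would simply apply `g(W_hy @ h + b_y)` inside the loop; sequence classification instead reads only the final row `H[-1]`.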

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

tf.random.set_seed(1); np.random.seed(1)

# ==========================================================
# Preview: a short, noise-free version of the series
# (frequency switches every 100 steps), just for plotting
# ==========================================================
SEQ_LEN       = 50
PRED_HORIZON  = 50
N_STEPS       = 200

def regime_wave(n, switch_every=100):
    f1, f2 = 0.20, 0.05   # fast / slow frequency regimes
    wave   = [np.sin((f1 if (i // switch_every) % 2 == 0 else f2) * i) for i in range(n)]
    return np.array(wave)

series = regime_wave(N_STEPS)


plt.figure(figsize=(12, 4))
plt.plot(series, label='series')
plt.title('Synthetic data: sine wave with changing frequency')
plt.xlabel('Time step')
plt.ylabel('Value')
plt.show()
[figure: synthetic sine wave with changing frequency]
# ==========================================================
# 0. Imports and seeds
# ==========================================================
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

tf.random.set_seed(1); np.random.seed(1)

# ==========================================================
# 1. Synthetic data: sine wave that changes frequency every 400 steps
#    + a little noise
# ==========================================================
SEQ_LEN       = 50
PRED_HORIZON  = 50      
N_STEPS       = 4_000

def regime_wave(n, switch_every=400, noise=0.05):
    f1, f2 = 0.20, 0.05
    wave   = [np.sin((f1 if (i//switch_every)%2==0 else f2) * i) for i in range(n)]
    return np.array(wave) + noise*np.random.randn(n)

series = regime_wave(N_STEPS)

def make_windows(data, win, horiz):
    X, y = [], []
    for start in range(len(data) - win - horiz):
        X.append(data[start:start+win])
        y.append(data[start+win:start+win+horiz])
    return np.array(X)[..., None], np.array(y)[..., None]

X, y = make_windows(series, SEQ_LEN, PRED_HORIZON)   # (3900, 50, 1)  (3900, 50, 1)

cut            = 2800                                # first 2800 windows for training
X_tr, y_tr     = X[:cut], y[:cut]
X_val, y_val   = X[cut:], y[cut:]
y_tr_vec       = y_tr.squeeze(-1)                    # Dense/Conv need (batch, 50)
y_val_vec      = y_val.squeeze(-1)

# ==========================================================
# 2. Three architectures: Dense, Conv1D, SimpleRNN
# ==========================================================
def dense_model():
    m = tf.keras.Sequential([
        tf.keras.Input(shape=(SEQ_LEN, 1)),   # declare the input shape up front
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(PRED_HORIZON)
    ])
    m.compile("adam", "mse")
    return m

def conv_model():
    m = tf.keras.Sequential([
        tf.keras.Input(shape=(SEQ_LEN, 1)),
        tf.keras.layers.Conv1D(64, 5, activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(PRED_HORIZON)
    ])
    m.compile("adam", "mse")
    return m

def rnn_model():
    m = tf.keras.Sequential([
        tf.keras.Input(shape=(SEQ_LEN, 1)),
        tf.keras.layers.SimpleRNN(64),
        tf.keras.layers.Dense(PRED_HORIZON)
    ])
    m.compile("adam", "mse")
    return m

models = dict(Dense=dense_model(), Conv1D=conv_model(), RNN=rnn_model())
hists  = {}

# ==========================================================
# 3. Train all three
# ==========================================================
for name, mdl in models.items():
    print(f"Training {name:5s}")
    h = mdl.fit(X_tr, y_tr_vec,
                validation_data=(X_val, y_val_vec),
                epochs=40, batch_size=64, verbose=0)
    hists[name] = h.history["val_loss"]
    print(f"   final val MSE: {h.history['val_loss'][-1]:.4f}")

# ==========================================================
# 4. Plot validation loss (log scale)
# ==========================================================
plt.figure(figsize=(6,4))
for name, losses in hists.items():
    plt.plot(losses, label=name)
plt.yscale("log")
plt.xlabel("epoch")
plt.ylabel("val MSE (log)")
plt.title("SimpleRNN outperforms Dense & Conv1D")
plt.legend()
plt.show()
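As a quick sanity check on the models above, the SimpleRNN layer's parameter count follows directly from the update rule: \(d_h d_{\text{in}}\) input weights, \(d_h^2\) recurrent weights, and \(d_h\) biases. A back-of-envelope calculation (the totals should agree with what `model.summary()` reports for the `rnn_model` defined above):

```python
# Parameter count of SimpleRNN(64) on 1-dimensional inputs, plus the
# Dense(50) readout, derived from the shapes of W_xh, W_hh, b_h, W_hy, b_y.
d_in, d_h, horizon = 1, 64, 50

rnn_params   = d_h * d_in + d_h * d_h + d_h   # W_xh + W_hh + b_h
dense_params = d_h * horizon + horizon        # W_hy + b_y

print(rnn_params, dense_params)   # 4224 3250
```

Note that the count is independent of `SEQ_LEN`: the same 4 224 recurrent parameters process a window of 50 steps or 5 000.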
Training Dense
   final val MSE: 0.0856
Training Conv1D
   final val MSE: 0.0804
Training RNN  
   final val MSE: 0.0715
[figure: validation MSE (log scale) for Dense, Conv1D and SimpleRNN]
# ==========================================================
# 5. Visualise one forecast window
# ==========================================================
def forecast_trace(model, X_one, y_true, label, color):
    past  = X_one.squeeze()                 # 50‑step context
    truth = y_true.squeeze()                # 50‑step ground truth
    pred  = model.predict(X_one[None, ...], verbose=0).squeeze()

    t_past = np.arange(len(past))
    t_fut  = np.arange(len(past), len(past)+len(truth))

    plt.plot(t_past, past, color="black")
    plt.plot(t_fut, truth, marker="o", color="green",
             label="ground truth" if label=="Dense" else None)
    plt.plot(t_fut, pred,  marker="x", color=color, label=label)

# Pick any validation example
idx = 100
plt.figure(figsize=(8,3))
forecast_trace(models["Dense"],  X_val[idx], y_val[idx], "Dense",  "red")
forecast_trace(models["Conv1D"], X_val[idx], y_val[idx], "Conv1D", "orange")
forecast_trace(models["RNN"],    X_val[idx], y_val[idx], "RNN",    "blue")
plt.title("Next-50-step forecast on validation window")
plt.legend()
sns.despine()
plt.show()
[figure: next-50-step forecasts of all three models on one validation window]
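The sliding-window construction in `make_windows` is easiest to see on a toy array. This standalone copy of the function pairs each length-`win` input window with the `horiz` values that follow it; note that consecutive windows overlap by all but one step.

```python
import numpy as np

def make_windows(data, win, horiz):
    # Slide a window of length `win` over the series; the target is the
    # `horiz` values immediately after each window.
    X, y = [], []
    for start in range(len(data) - win - horiz):
        X.append(data[start:start + win])
        y.append(data[start + win:start + win + horiz])
    return np.array(X)[..., None], np.array(y)[..., None]

data = np.arange(10)
X, y = make_windows(data, win=3, horiz=2)

print(X.shape, y.shape)               # (5, 3, 1) (5, 2, 1)
print(X[0].ravel(), y[0].ravel())     # [0 1 2] [3 4]
```

The trailing `[..., None]` adds the feature axis that Keras expects, turning `(n_windows, win)` into `(n_windows, win, 1)`.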

Example - Sentiment Analysis of Movie Reviews#

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# ---------------- data ----------------
vocab_size = 10_000
maxlen     = 200        # trim / pad reviews to 200 tokens

(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=vocab_size)
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen, padding="post")
x_test  = keras.preprocessing.sequence.pad_sequences(x_test,  maxlen=maxlen, padding="post")

# ---------------- model ----------------
embed_dim    = 64
hidden_units = 32

model = keras.Sequential([
    layers.Embedding(vocab_size, embed_dim, mask_zero=True),  # word → vector
    layers.SimpleRNN(hidden_units),                           # processes the sequence
    layers.Dense(1, activation="sigmoid")                     # one sentiment score
])

model.compile(optimizer="adam",
              loss="binary_crossentropy",
              metrics=["accuracy"])

model.fit(x_train, y_train,
          epochs=2,
          batch_size=64,
          validation_split=0.2)

print("Test accuracy:", model.evaluate(x_test, y_test, verbose=0)[1])
Epoch 1/2
313/313 ━━━━━━━━━━━━━━━━━━━━ 11s 32ms/step - accuracy: 0.5750 - loss: 0.6649 - val_accuracy: 0.7674 - val_loss: 0.5044
Epoch 2/2
265/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8278 - loss: 0.4003

267/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8280 - loss: 0.4000

269/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8281 - loss: 0.3997

271/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8283 - loss: 0.3994

273/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8285 - loss: 0.3991

275/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8286 - loss: 0.3989

277/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8288 - loss: 0.3986

279/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8289 - loss: 0.3983

281/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8291 - loss: 0.3980

283/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8293 - loss: 0.3977

285/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8294 - loss: 0.3975

287/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8296 - loss: 0.3972

289/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8297 - loss: 0.3969

291/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8299 - loss: 0.3967

293/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8300 - loss: 0.3964

295/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8302 - loss: 0.3961

297/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8303 - loss: 0.3959

299/313 ━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8305 - loss: 0.3956

301/313 ━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8306 - loss: 0.3954

303/313 ━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8307 - loss: 0.3951

305/313 ━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8309 - loss: 0.3948

307/313 ━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8310 - loss: 0.3946

309/313 ━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8312 - loss: 0.3943

311/313 ━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8313 - loss: 0.3941

313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8314 - loss: 0.3938

313/313 ━━━━━━━━━━━━━━━━━━━━ 10s 31ms/step - accuracy: 0.8315 - loss: 0.3937 - val_accuracy: 0.8384 - val_loss: 0.3978
Test accuracy: 0.8351600170135498
units = 64                          # 4 × 64 = 256 bias values
gate_bias = np.concatenate([
    np.full(units,  2.0),           # input‑gate bias  (open)
    np.full(units,  1.0),           # forget‑gate bias (keep)
    np.zeros(units),                # cell candidate   (neutral)
    np.full(units,  2.0)            # output‑gate bias (open)  ← NEW
]).astype("float32")

good_lstm = models.Sequential([          # `models` is imported above
    layers.Embedding(2, 4),
    layers.LSTM(
        units,
        unit_forget_bias=False,                 # we initialise all 4 gates ourselves
        bias_initializer=tf.keras.initializers.Constant(gate_bias)
    ),
    layers.Dense(1, activation="sigmoid")
])
good_lstm.compile(
    tf.keras.optimizers.legacy.RMSprop(3e-3, clipnorm=1.0),
    tf.keras.losses.BinaryCrossentropy(),
    ["accuracy"],
)
good_lstm.fit(
    x_tr, y_tr,
    batch_size=256,
    epochs=20,
    validation_split=0.2,
    verbose=2
)
print("LSTM test acc:", good_lstm.evaluate(x_te, y_te, verbose=0)[1])
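To see what the constant bias vector above actually does, here is a small NumPy sketch (an illustration, not the Keras implementation). Keras stores the LSTM bias as one flat vector in the order `[input, forget, candidate, output]`, which is exactly how `gate_bias` is built; with a zero input and zero recurrent contribution, each gate's activation is driven purely by its bias slice:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

units = 4  # small for illustration; the model above uses 64
# Same layout as the Keras LSTM bias: [input, forget, candidate, output]
gate_bias = np.concatenate([
    np.full(units, 2.0),   # input gate
    np.full(units, 1.0),   # forget gate
    np.zeros(units),       # cell candidate
    np.full(units, 2.0),   # output gate
])

# Split the flat bias back into the four per-gate slices
b_i, b_f, b_c, b_o = np.split(gate_bias, 4)

# Gate activations when the weighted input/recurrent terms are zero:
print(sigmoid(b_i)[0])  # input gate  ≈ 0.88 → mostly open
print(sigmoid(b_f)[0])  # forget gate ≈ 0.73 → mostly keeping the cell state
print(np.tanh(b_c)[0])  # candidate   = 0.0  → neutral
print(sigmoid(b_o)[0])  # output gate ≈ 0.88 → mostly open
```

So at initialisation the network starts out remembering (forget gate biased towards 1) and letting information flow in and out (input and output gates biased open), rather than starting from the saturated-closed regime that all-zero biases would give.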