(rnn)=
Recurrent Neural Networks (RNNs)#
A Recurrent Neural Network (RNN) is a neural architecture designed for sequential data. Unlike feedforward networks, it contains recurrent connections in the hidden layer: the hidden state \(h_t\) at time \(t\) depends on both the current input \(x_t\) and the previous hidden state \(h_{t-1}\). Because the same weights are reused at every step, the model can, in principle, handle sequences of arbitrary length without increasing the parameter count.
Components of an RNN#
Input layer – receives a sequence of vectors \((x_1,\dots,x_T)\), e.g., word embeddings.
Recurrent (hidden) layer(s) – update rule below; weights are shared across all \(t\).
Output layer – either emits a prediction \(y_t\) at each step, or reads the final state \(h_T\) to classify the whole sequence.
Recurrent connection – the link \(h_{t-1}\rightarrow h_t\) that carries memory forward.
Update Rule#
\[
h_t = \sigma\left(W_{xh}\,x_t + W_{hh}\,h_{t-1} + b_h\right), \qquad
y_t = g\left(W_{hy}\,h_t + b_y\right)
\]
where:
\(x_t\in\mathbb{R}^{d_{\text{in}}}\) – input at time \(t\).
\(h_t\in\mathbb{R}^{d_h}\) – hidden state (memory).
\(y_t\) – output; in sequence classification we often set \(y_t = y\) only for \(t=T\).
\(W_{xh}\), \(W_{hh}\), \(W_{hy}\) and biases \(b_h\), \(b_y\) – shared, learnable parameters.
\(\sigma\) – typically \(\tanh\) or \(\text{ReLU}\).
\(g\) – task‑dependent: softmax (language modelling), sigmoid (binary label), or identity (feature for another layer).
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import seaborn as sns
2025-05-08 15:05:46.107189: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-08 15:05:46.110437: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-08 15:05:46.118949: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1746716746.133132 71957 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746716746.137251 71957 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746716746.148819 71957 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746716746.148831 71957 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746716746.148833 71957 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746716746.148834 71957 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
2025-05-08 15:05:46.153017: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# Fix both TF and NumPy RNG seeds so the demo is reproducible.
tf.random.set_seed(1); np.random.seed(1)
# ==========================================================
# 1. Synthetic data: sine wave that changes frequency every
#    `switch_every` steps (100 in this short demo cell; noise is
#    disabled here so the regime switches are easy to see).
# ==========================================================
SEQ_LEN = 50       # length of the input window fed to a model
PRED_HORIZON = 50  # number of future steps to predict
N_STEPS = 200      # short series, only for the visualisation below
def regime_wave(n, switch_every=100, noise=0.05):
    """Return a length-``n`` sine series whose frequency alternates
    between 0.20 and 0.05 every ``switch_every`` samples.

    Note: ``noise`` is accepted but deliberately not applied in this
    variant, so the regime switches stay clearly visible in the plot.
    """
    freq_fast, freq_slow = 0.20, 0.05
    samples = np.empty(n)
    for t in range(n):
        in_slow_regime = (t // switch_every) % 2 == 1
        freq = freq_slow if in_slow_regime else freq_fast
        samples[t] = np.sin(freq * t)
    return samples  # noise intentionally omitted in this demo version
# Generate the short demo series and plot it; the frequency switch at
# i = 100 (the switch_every default above) should be visible.
series = regime_wave(N_STEPS)
plt.figure(figsize=(12, 4))
plt.plot(series, label='series')
plt.title('Synthetic data: sine wave with changing frequency')
plt.xlabel('Time step')
plt.ylabel('Value')
Text(0, 0.5, 'Value')
# ==========================================================
# 0. Imports and seeds
# ==========================================================
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# Re-seed both TF and NumPy so the full experiment is reproducible.
tf.random.set_seed(1); np.random.seed(1)
# ==========================================================
# 1. Synthetic data: sine wave that changes frequency every 400 steps
#    + a little Gaussian noise
# ==========================================================
SEQ_LEN = 50        # input window length
PRED_HORIZON = 50   # number of future steps to forecast
N_STEPS = 4_000     # total length of the generated series
def regime_wave(n, switch_every=400, noise=0.05):
    """Return a length-``n`` sine series that alternates between
    frequency 0.20 and 0.05 every ``switch_every`` samples, plus
    i.i.d. Gaussian noise with standard deviation ``noise``.
    """
    freq_fast, freq_slow = 0.20, 0.05
    clean = []
    for t in range(n):
        in_slow_regime = (t // switch_every) % 2 == 1
        freq = freq_slow if in_slow_regime else freq_fast
        clean.append(np.sin(freq * t))
    # One randn(n) draw keeps the RNG stream identical to the original.
    return np.array(clean) + noise * np.random.randn(n)
# The 4 000-step noisy series used for the actual training experiment.
series = regime_wave(N_STEPS)
def make_windows(data, win, horiz):
    """Slice a 1-D series into overlapping (input, target) windows.

    Parameters
    ----------
    data : 1-D array-like time series.
    win : int, length of each input window.
    horiz : int, number of future steps in each target window.

    Returns
    -------
    X : ndarray of shape (num_windows, win, 1)
    y : ndarray of shape (num_windows, horiz, 1)
    """
    X, y = [], []
    # Fix: the original `range(len(data) - win - horiz)` dropped the last
    # valid window; `+ 1` includes every start position where the full
    # input+target span still fits inside `data`.
    for start in range(len(data) - win - horiz + 1):
        X.append(data[start:start+win])
        y.append(data[start+win:start+win+horiz])
    return np.array(X)[..., None], np.array(y)[..., None]
# Build (input, target) windows: X is (num_windows, 50, 1), y likewise.
X, y = make_windows(series, SEQ_LEN, PRED_HORIZON)
# Chronological split: earliest windows train, the rest validate.
cut = 2800 # first 2800 windows for training
X_tr, y_tr = X[:cut], y[:cut]
X_val, y_val = X[cut:], y[cut:]
# The Dense/Conv heads emit flat (batch, 50) vectors, so drop the
# trailing channel axis from the targets to match.
y_tr_vec = y_tr.squeeze(-1) # Dense/Conv need (batch, 50)
y_val_vec = y_val.squeeze(-1)
# ==========================================================
# 2. Three architectures: Dense, Conv1D, SimpleRNN
# ==========================================================
def dense_model():
    """MLP baseline: flattens the 50-step window and maps it to the
    next PRED_HORIZON values through one hidden layer."""
    m = tf.keras.Sequential([
        # Explicit Input layer replaces the deprecated `input_shape`
        # argument (Keras emits a UserWarning for the old form).
        tf.keras.Input(shape=(SEQ_LEN, 1)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(PRED_HORIZON)
    ])
    m.compile("adam", "mse")
    return m
def conv_model():
    """1-D convolutional baseline: one Conv1D layer over the window,
    then a linear head predicting the next PRED_HORIZON values."""
    m = tf.keras.Sequential([
        # Explicit Input layer replaces the deprecated `input_shape`
        # argument (Keras emits a UserWarning for the old form).
        tf.keras.Input(shape=(SEQ_LEN, 1)),
        tf.keras.layers.Conv1D(64, 5, activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(PRED_HORIZON)
    ])
    m.compile("adam", "mse")
    return m
def rnn_model():
    """Recurrent model: SimpleRNN consumes the window step by step and
    its final hidden state feeds a linear head of PRED_HORIZON outputs."""
    m = tf.keras.Sequential([
        # Explicit Input layer replaces the deprecated `input_shape`
        # argument (Keras emits a UserWarning for the old form).
        tf.keras.Input(shape=(SEQ_LEN, 1)),
        tf.keras.layers.SimpleRNN(64),
        tf.keras.layers.Dense(PRED_HORIZON)
    ])
    m.compile("adam", "mse")
    return m
# NOTE(review): this rebinds `models`, shadowing the `models` module
# imported from tensorflow.keras in the first cell — consider renaming.
models = dict(Dense=dense_model(), Conv1D=conv_model(), RNN=rnn_model())
hists = {}  # name -> per-epoch validation-loss curve
# ==========================================================
# 3. Train all three on identical data and record val loss
# ==========================================================
for name, mdl in models.items():
    print(f"Training {name:5s}")
    # Fix: the original ternary `y_tr_vec if name!="RNN" else y_tr_vec`
    # returned the same value in both branches; all three models take
    # the flat (batch, 50) targets, so pass y_tr_vec unconditionally.
    h = mdl.fit(X_tr, y_tr_vec,
                validation_data=(X_val, y_val_vec),
                epochs=40, batch_size=64, verbose=0)
    hists[name] = h.history["val_loss"]
    print(f" final val MSE: {h.history['val_loss'][-1]:.4f}")
# ==========================================================
# 4. Plot validation loss (log scale)
# ==========================================================
# Compare the three validation-loss curves; log scale makes the gap
# between architectures easier to read.
plt.figure(figsize=(6,4))
for name, losses in hists.items():
    plt.plot(losses, label=name)
plt.yscale("log")
plt.xlabel("epoch")
plt.ylabel("val MSE (log)")
plt.title("SimpleRNN outperforms Dense & Conv1D")
plt.legend()
plt.show()
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/keras/src/layers/reshaping/flatten.py:37: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
super().__init__(**kwargs)
2025-05-08 15:05:48.792282: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/keras/src/layers/convolutional/base_conv.py:107: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
super().__init__(activity_regularizer=activity_regularizer, **kwargs)
/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/keras/src/layers/rnn/rnn.py:200: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
super().__init__(**kwargs)
Training Dense
final val MSE: 0.0856
Training Conv1D
final val MSE: 0.0804
Training RNN
final val MSE: 0.0715
# ==========================================================
# 6. Visualise one forecast window
# ==========================================================
def forecast_trace(model, X_one, y_true, label, color):
    """Plot one context window (black), its ground-truth continuation
    (green circles) and ``model``'s forecast (``color`` crosses) on the
    current matplotlib axes."""
    context = X_one.squeeze()        # 50-step input context
    future_true = y_true.squeeze()   # 50-step ground truth
    future_pred = model.predict(X_one[None, ...], verbose=0).squeeze()
    n_ctx = len(context)
    t_context = np.arange(n_ctx)
    t_future = np.arange(n_ctx, n_ctx + len(future_true))
    plt.plot(t_context, context, color="black")
    # Label the truth curve only on the first call ("Dense") so the
    # legend does not repeat it three times.
    truth_label = "ground truth" if label == "Dense" else None
    plt.plot(t_future, future_true, marker="o", color="green",
             label=truth_label)
    plt.plot(t_future, future_pred, marker="x", color=color, label=label)
# Pick any validation example and overlay all three models' forecasts.
idx = 100
plt.figure(figsize=(8,3))
forecast_trace(models["Dense"], X_val[idx], y_val[idx], "Dense", "red")
forecast_trace(models["Conv1D"], X_val[idx], y_val[idx], "Conv1D", "orange")
forecast_trace(models["RNN"], X_val[idx], y_val[idx], "RNN", "blue")
plt.title("Next-50-step forecast on validation window")
plt.legend()
sns.despine()
plt.show()
Example - Sentiment Analysis of Movie Reviews#
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# ---------------- data ----------------
vocab_size = 10_000  # keep only the 10k most frequent words
maxlen = 200 # trim / pad reviews to 200 tokens
# IMDB reviews arrive as lists of integer word indices, labels are 0/1.
(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=vocab_size)
# Pad at the end ("post") so every review is exactly `maxlen` tokens.
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen, padding="post")
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=maxlen, padding="post")
# ---------------- model ----------------
embed_dim = 64     # size of each learned word vector
hidden_units = 32  # RNN hidden-state size
model = keras.Sequential([
    # mask_zero=True makes downstream layers ignore the padding token 0.
    layers.Embedding(vocab_size, embed_dim, mask_zero=True), # word → vector
    layers.SimpleRNN(hidden_units), # processes the sequence
    layers.Dense(1, activation="sigmoid") # one sentiment score
])
# Binary sentiment label -> binary cross-entropy with a sigmoid output.
model.compile(optimizer="adam",
              loss="binary_crossentropy",
              metrics=["accuracy"])
# Hold out 20% of the training reviews for validation during fitting.
model.fit(x_train, y_train,
          epochs=2,
          batch_size=64,
          validation_split=0.2)
# evaluate() returns [loss, accuracy]; index 1 is the accuracy metric.
print("Test accuracy:", model.evaluate(x_test, y_test, verbose=0)[1])
Epoch 1/2
1/313 ━━━━━━━━━━━━━━━━━━━━ 7:12 1s/step - accuracy: 0.4375 - loss: 0.7079
3/313 ━━━━━━━━━━━━━━━━━━━━ 9s 31ms/step - accuracy: 0.4540 - loss: 0.7042
5/313 ━━━━━━━━━━━━━━━━━━━━ 9s 31ms/step - accuracy: 0.4599 - loss: 0.7052
7/313 ━━━━━━━━━━━━━━━━━━━━ 9s 30ms/step - accuracy: 0.4616 - loss: 0.7056
9/313 ━━━━━━━━━━━━━━━━━━━━ 9s 30ms/step - accuracy: 0.4693 - loss: 0.7044
11/313 ━━━━━━━━━━━━━━━━━━━━ 9s 30ms/step - accuracy: 0.4747 - loss: 0.7034
13/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.4795 - loss: 0.7025
15/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.4849 - loss: 0.7016
17/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.4893 - loss: 0.7008
19/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.4931 - loss: 0.7001
21/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.4963 - loss: 0.6994
23/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.4989 - loss: 0.6990
25/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.5012 - loss: 0.6986
27/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.5032 - loss: 0.6982
29/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.5049 - loss: 0.6979
31/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.5063 - loss: 0.6976
33/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.5075 - loss: 0.6974
35/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.5086 - loss: 0.6971
37/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.5093 - loss: 0.6969
39/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.5101 - loss: 0.6967
41/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.5108 - loss: 0.6965
43/313 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - accuracy: 0.5114 - loss: 0.6963
45/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5119 - loss: 0.6961
47/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5125 - loss: 0.6960
49/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5132 - loss: 0.6958
51/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5138 - loss: 0.6956
53/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5145 - loss: 0.6954
55/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5151 - loss: 0.6952
57/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5158 - loss: 0.6950
59/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5164 - loss: 0.6949
61/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5170 - loss: 0.6947
63/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5176 - loss: 0.6945
65/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5183 - loss: 0.6943
67/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5190 - loss: 0.6941
69/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5198 - loss: 0.6939
71/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5204 - loss: 0.6937
73/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5211 - loss: 0.6936
75/313 ━━━━━━━━━━━━━━━━━━━━ 7s 30ms/step - accuracy: 0.5216 - loss: 0.6934
77/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5222 - loss: 0.6933
79/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5227 - loss: 0.6931
81/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5231 - loss: 0.6930
83/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5235 - loss: 0.6929
85/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5239 - loss: 0.6928
87/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5243 - loss: 0.6927
89/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5246 - loss: 0.6926
91/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5250 - loss: 0.6925
93/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5254 - loss: 0.6924
95/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5257 - loss: 0.6923
97/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5261 - loss: 0.6922
99/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5264 - loss: 0.6921
101/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5268 - loss: 0.6920
103/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5271 - loss: 0.6919
105/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5275 - loss: 0.6918
107/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5278 - loss: 0.6917
109/313 ━━━━━━━━━━━━━━━━━━━━ 6s 30ms/step - accuracy: 0.5282 - loss: 0.6915
111/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5285 - loss: 0.6914
113/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5289 - loss: 0.6913
115/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5293 - loss: 0.6912
117/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5297 - loss: 0.6911
119/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5300 - loss: 0.6910
121/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5304 - loss: 0.6909
123/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5307 - loss: 0.6908
125/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5310 - loss: 0.6907
127/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5314 - loss: 0.6906
129/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5317 - loss: 0.6905
131/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5321 - loss: 0.6904
133/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5325 - loss: 0.6902
135/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5328 - loss: 0.6901
137/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5332 - loss: 0.6900
139/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5336 - loss: 0.6899
141/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5340 - loss: 0.6898
143/313 ━━━━━━━━━━━━━━━━━━━━ 5s 30ms/step - accuracy: 0.5344 - loss: 0.6896
145/313 ━━━━━━━━━━━━━━━━━━━━ 4s 30ms/step - accuracy: 0.5348 - loss: 0.6895
147/313 ━━━━━━━━━━━━━━━━━━━━ 4s 30ms/step - accuracy: 0.5352 - loss: 0.6894
149/313 ━━━━━━━━━━━━━━━━━━━━ 4s 30ms/step - accuracy: 0.5356 - loss: 0.6892
151/313 ━━━━━━━━━━━━━━━━━━━━ 4s 30ms/step - accuracy: 0.5360 - loss: 0.6891
153/313 ━━━━━━━━━━━━━━━━━━━━ 4s 30ms/step - accuracy: 0.5364 - loss: 0.6890
155/313 ━━━━━━━━━━━━━━━━━━━━ 4s 30ms/step - accuracy: 0.5368 - loss: 0.6888
157/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.5373 - loss: 0.6886
159/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.5377 - loss: 0.6885
161/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.5381 - loss: 0.6883
163/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.5385 - loss: 0.6881
165/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.5390 - loss: 0.6880
167/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.5393 - loss: 0.6878
169/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.5397 - loss: 0.6876
171/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.5401 - loss: 0.6875
173/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.5405 - loss: 0.6873
175/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.5409 - loss: 0.6872
177/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.5413 - loss: 0.6870
179/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5417 - loss: 0.6868
181/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5421 - loss: 0.6866
183/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5425 - loss: 0.6865
185/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5429 - loss: 0.6863
187/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5433 - loss: 0.6861
189/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5438 - loss: 0.6859
191/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5442 - loss: 0.6857
193/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5446 - loss: 0.6855
195/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5451 - loss: 0.6853
197/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5455 - loss: 0.6851
199/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5460 - loss: 0.6848
201/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5465 - loss: 0.6846
203/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5469 - loss: 0.6844
205/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5474 - loss: 0.6841
207/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5479 - loss: 0.6838
209/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5484 - loss: 0.6836
211/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.5489 - loss: 0.6833
213/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5494 - loss: 0.6830
215/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5499 - loss: 0.6828
217/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5504 - loss: 0.6825
219/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5508 - loss: 0.6822
221/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5513 - loss: 0.6819
223/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5518 - loss: 0.6816
225/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5523 - loss: 0.6813
227/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5528 - loss: 0.6810
229/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5534 - loss: 0.6807
231/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5539 - loss: 0.6803
233/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5544 - loss: 0.6800
235/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5549 - loss: 0.6796
237/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5554 - loss: 0.6793
239/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5559 - loss: 0.6790
241/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5564 - loss: 0.6786
243/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5569 - loss: 0.6783
245/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5574 - loss: 0.6780
247/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5579 - loss: 0.6776
249/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5584 - loss: 0.6773
251/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5589 - loss: 0.6770
253/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5594 - loss: 0.6766
255/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5599 - loss: 0.6763
257/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5604 - loss: 0.6759
259/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5609 - loss: 0.6756
261/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5614 - loss: 0.6752
263/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5619 - loss: 0.6748
265/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5625 - loss: 0.6745
267/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5630 - loss: 0.6741
269/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5635 - loss: 0.6738
271/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5640 - loss: 0.6734
273/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5645 - loss: 0.6730
275/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5650 - loss: 0.6727
277/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5655 - loss: 0.6723
279/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.5660 - loss: 0.6719
281/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5665 - loss: 0.6715
283/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5670 - loss: 0.6711
285/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5675 - loss: 0.6708
287/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5680 - loss: 0.6704
289/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5686 - loss: 0.6700
291/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5691 - loss: 0.6696
293/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5696 - loss: 0.6692
295/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5701 - loss: 0.6688
297/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5706 - loss: 0.6684
299/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5711 - loss: 0.6680
301/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5716 - loss: 0.6676
303/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5722 - loss: 0.6672
305/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5727 - loss: 0.6668
307/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5732 - loss: 0.6664
309/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5737 - loss: 0.6660
311/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5742 - loss: 0.6655
313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.5748 - loss: 0.6651
313/313 ━━━━━━━━━━━━━━━━━━━━ 11s 32ms/step - accuracy: 0.5750 - loss: 0.6649 - val_accuracy: 0.7674 - val_loss: 0.5044
Epoch 2/2
1/313 ━━━━━━━━━━━━━━━━━━━━ 12s 42ms/step - accuracy: 0.7969 - loss: 0.4343
3/313 ━━━━━━━━━━━━━━━━━━━━ 9s 29ms/step - accuracy: 0.8021 - loss: 0.4295
5/313 ━━━━━━━━━━━━━━━━━━━━ 9s 30ms/step - accuracy: 0.7950 - loss: 0.4412
7/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.7950 - loss: 0.4420
9/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.7962 - loss: 0.4414
11/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.7985 - loss: 0.4395
13/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.7993 - loss: 0.4387
15/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8011 - loss: 0.4367
17/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8030 - loss: 0.4343
19/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8045 - loss: 0.4324
21/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8059 - loss: 0.4310
23/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8071 - loss: 0.4299
25/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8081 - loss: 0.4289
27/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8090 - loss: 0.4279
29/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8097 - loss: 0.4269
31/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8102 - loss: 0.4262
33/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8105 - loss: 0.4256
35/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8108 - loss: 0.4252
37/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8111 - loss: 0.4248
39/313 ━━━━━━━━━━━━━━━━━━━━ 8s 29ms/step - accuracy: 0.8113 - loss: 0.4243
41/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8117 - loss: 0.4238
43/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8120 - loss: 0.4232
45/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8123 - loss: 0.4226
47/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8125 - loss: 0.4221
49/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8129 - loss: 0.4214
51/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8132 - loss: 0.4207
53/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8135 - loss: 0.4201
55/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8138 - loss: 0.4195
57/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8142 - loss: 0.4189
59/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8145 - loss: 0.4182
61/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8148 - loss: 0.4178
63/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8149 - loss: 0.4175
65/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8151 - loss: 0.4172
67/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8153 - loss: 0.4169
69/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8155 - loss: 0.4166
71/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8157 - loss: 0.4162
73/313 ━━━━━━━━━━━━━━━━━━━━ 7s 29ms/step - accuracy: 0.8160 - loss: 0.4159
75/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8162 - loss: 0.4155
77/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8165 - loss: 0.4151
79/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8168 - loss: 0.4148
81/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8170 - loss: 0.4144
83/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8173 - loss: 0.4141
85/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8175 - loss: 0.4137
87/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8177 - loss: 0.4135
89/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8179 - loss: 0.4133
91/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8180 - loss: 0.4133
93/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8180 - loss: 0.4133
95/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8180 - loss: 0.4135
97/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8179 - loss: 0.4136
99/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8178 - loss: 0.4138
101/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8177 - loss: 0.4140
103/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8176 - loss: 0.4142
105/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8175 - loss: 0.4144
107/313 ━━━━━━━━━━━━━━━━━━━━ 6s 29ms/step - accuracy: 0.8174 - loss: 0.4146
109/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8173 - loss: 0.4148
111/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8173 - loss: 0.4149
113/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8172 - loss: 0.4151
115/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8172 - loss: 0.4152
117/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8172 - loss: 0.4153
119/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8173 - loss: 0.4154
121/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8173 - loss: 0.4155
123/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8174 - loss: 0.4156
125/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8174 - loss: 0.4156
127/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8175 - loss: 0.4157
129/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8175 - loss: 0.4157
131/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8176 - loss: 0.4157
133/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8177 - loss: 0.4157
135/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8178 - loss: 0.4157
137/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8178 - loss: 0.4156
139/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8179 - loss: 0.4156
141/313 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.8180 - loss: 0.4155
143/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8181 - loss: 0.4154
145/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8182 - loss: 0.4153
147/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8183 - loss: 0.4152
149/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8185 - loss: 0.4151
151/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8186 - loss: 0.4150
153/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8187 - loss: 0.4148
155/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8188 - loss: 0.4147
157/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8190 - loss: 0.4145
159/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8191 - loss: 0.4143
161/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8192 - loss: 0.4141
163/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8194 - loss: 0.4139
165/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8195 - loss: 0.4137
167/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8197 - loss: 0.4135
169/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8198 - loss: 0.4133
171/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8200 - loss: 0.4131
173/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8201 - loss: 0.4129
175/313 ━━━━━━━━━━━━━━━━━━━━ 4s 29ms/step - accuracy: 0.8203 - loss: 0.4126
177/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8205 - loss: 0.4124
179/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8206 - loss: 0.4121
181/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8208 - loss: 0.4119
183/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8209 - loss: 0.4117
185/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8211 - loss: 0.4115
187/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8213 - loss: 0.4112
189/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8214 - loss: 0.4110
191/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8216 - loss: 0.4107
193/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8218 - loss: 0.4105
195/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8219 - loss: 0.4102
197/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8221 - loss: 0.4099
199/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8223 - loss: 0.4097
201/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8224 - loss: 0.4094
203/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8226 - loss: 0.4091
205/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8228 - loss: 0.4089
207/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8229 - loss: 0.4086
209/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8231 - loss: 0.4083
211/313 ━━━━━━━━━━━━━━━━━━━━ 3s 29ms/step - accuracy: 0.8233 - loss: 0.4080
213/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8234 - loss: 0.4077
215/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8236 - loss: 0.4074
217/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8238 - loss: 0.4072
219/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8240 - loss: 0.4069
221/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8241 - loss: 0.4066
223/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8243 - loss: 0.4063
225/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8245 - loss: 0.4060
227/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8247 - loss: 0.4057
229/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8248 - loss: 0.4054
231/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8250 - loss: 0.4051
233/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8252 - loss: 0.4048
235/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8254 - loss: 0.4045
237/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8255 - loss: 0.4042
239/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8257 - loss: 0.4039
241/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8259 - loss: 0.4036
243/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8261 - loss: 0.4033
245/313 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.8262 - loss: 0.4031
247/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8264 - loss: 0.4028
249/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8265 - loss: 0.4025
251/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8267 - loss: 0.4022
253/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8269 - loss: 0.4019
255/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8270 - loss: 0.4017
257/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8272 - loss: 0.4014
259/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8273 - loss: 0.4011
261/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8275 - loss: 0.4008
263/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8277 - loss: 0.4005
265/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8278 - loss: 0.4003
267/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8280 - loss: 0.4000
269/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8281 - loss: 0.3997
271/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8283 - loss: 0.3994
273/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8285 - loss: 0.3991
275/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8286 - loss: 0.3989
277/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8288 - loss: 0.3986
279/313 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - accuracy: 0.8289 - loss: 0.3983
281/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8291 - loss: 0.3980
283/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8293 - loss: 0.3977
285/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8294 - loss: 0.3975
287/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8296 - loss: 0.3972
289/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8297 - loss: 0.3969
291/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8299 - loss: 0.3967
293/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8300 - loss: 0.3964
295/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8302 - loss: 0.3961
297/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8303 - loss: 0.3959
299/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8305 - loss: 0.3956
301/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8306 - loss: 0.3954
303/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8307 - loss: 0.3951
305/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8309 - loss: 0.3948
307/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8310 - loss: 0.3946
309/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8312 - loss: 0.3943
311/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8313 - loss: 0.3941
313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.8314 - loss: 0.3938
313/313 ━━━━━━━━━━━━━━━━━━━━ 10s 31ms/step - accuracy: 0.8315 - loss: 0.3937 - val_accuracy: 0.8384 - val_loss: 0.3978
Test accuracy: 0.8351600170135498
units = 64  # LSTM width; bias vector therefore has 4 gates × 64 = 256 entries

# Keras packs the single LSTM bias vector as
# [input gate | forget gate | cell candidate | output gate], each `units` long.
gate_bias = np.concatenate([
    np.full(units, 2.0),   # input-gate bias (starts open)
    np.full(units, 1.0),   # forget-gate bias (starts keeping state)
    np.zeros(units),       # cell-candidate bias (neutral)
    np.full(units, 2.0),   # output-gate bias (starts open)
]).astype("float32")

# FIX: the original cell raised `NameError: name 'Sequential' is not defined`
# (it also referenced undefined `keras` and `losses`). Use the names that are
# actually imported at the top of the notebook: `models`, `layers`, and `tf`.
good_lstm = models.Sequential([
    # NOTE(review): input_dim=2 only embeds token ids {0, 1} — presumably the
    # inputs are binary sequences; confirm against how x_tr was built.
    layers.Embedding(2, 4),
    layers.LSTM(
        units,
        unit_forget_bias=False,  # we initialise all 4 gate biases ourselves
        bias_initializer=tf.keras.initializers.Constant(gate_bias),
    ),
    layers.Dense(1, activation="sigmoid"),  # binary classification head
])
good_lstm.compile(
    optimizer=tf.keras.optimizers.legacy.RMSprop(3e-3, clipnorm=1.0),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=["accuracy"],
)
good_lstm.fit(
    x_tr, y_tr,
    batch_size=256,
    epochs=20,
    validation_split=0.2,  # hold out 20% of training data for validation
    verbose=2,
)
print("LSTM test acc:", good_lstm.evaluate(x_te, y_te, verbose=0)[1])
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[6], line 9
1 units = 64 # 4 × 64 = 256 bias values
2 gate_bias = np.concatenate([
3 np.full(units, 2.0), # input‑gate bias (open)
4 np.full(units, 1.0), # forget‑gate bias (keep)
5 np.zeros(units), # cell candidate (neutral)
6 np.full(units, 2.0) # output‑gate bias (open) ← NEW
7 ]).astype("float32")
----> 9 good_lstm = Sequential([
10 layers.Embedding(2, 4),
11 layers.LSTM(
12 units,
13 unit_forget_bias=False, # we initialise all 4 gates ourselves
14 bias_initializer=keras.initializers.Constant(gate_bias)
15 ),
16 layers.Dense(1, activation="sigmoid")
17 ])
18 good_lstm.compile(
19 tf.keras.optimizers.legacy.RMSprop(3e-3, clipnorm=1.0),
20 losses.BinaryCrossentropy(),
21 ["accuracy"],
22 )
23 good_lstm.fit(
24 x_tr, y_tr,
25 batch_size=256,
(...) 28 verbose=2
29 )
NameError: name 'Sequential' is not defined
# Report held-out accuracy of the bias-initialised LSTM (metric index 1).
lstm_test_accuracy = good_lstm.evaluate(x_te, y_te, verbose=0)[1]
print(" LSTM:", lstm_test_accuracy)