
Example: Classifying Digits#

This is the initial example discussed in the book “Deep Learning with Python” by François Chollet.

The MNIST dataset is a set of handwritten digits. It has a training set of 60,000 examples and a test set of 10,000 examples. The digits have been size-normalized and centered in a fixed-size image.

Step 1: Load the data#

from tensorflow.keras.datasets import mnist

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step

train_images and test_images: arrays of grayscale images of handwritten digits.

train_labels and test_labels: the label for each example is the digit it depicts, 0-9.
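The loaded arrays are plain NumPy arrays. A quick sketch of their shapes and dtype (using placeholder zero arrays here rather than re-downloading the data):

```python
import numpy as np

# Shapes and dtype of the arrays returned by mnist.load_data(),
# illustrated with placeholder zero arrays instead of the real download.
train_images = np.zeros((60000, 28, 28), dtype=np.uint8)  # 60,000 28x28 images
train_labels = np.zeros((60000,), dtype=np.uint8)         # one digit per image
test_images = np.zeros((10000, 28, 28), dtype=np.uint8)
test_labels = np.zeros((10000,), dtype=np.uint8)

print(train_images.shape, train_labels.shape)  # (60000, 28, 28) (60000,)
```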

digit_idx = 0
print(train_images[digit_idx])
print(train_labels[digit_idx])
[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   3  18  18  18 126 136
  175  26 166 255 247 127   0   0   0   0]
 [  0   0   0   0   0   0   0   0  30  36  94 154 170 253 253 253 253 253
  225 172 253 242 195  64   0   0   0   0]
 [  0   0   0   0   0   0   0  49 238 253 253 253 253 253 253 253 253 251
   93  82  82  56  39   0   0   0   0   0]
 [  0   0   0   0   0   0   0  18 219 253 253 253 253 253 198 182 247 241
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  80 156 107 253 253 205  11   0  43 154
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0  14   1 154 253  90   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0 139 253 190   2   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0  11 190 253  70   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  35 241 225 160 108   1
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0  81 240 253 253 119
   25   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0  45 186 253 253
  150  27   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  16  93 252
  253 187   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 249
  253 249  64   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0  46 130 183 253
  253 207   2   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  39 148 229 253 253 253
  250 182   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0  24 114 221 253 253 253 253 201
   78   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  23  66 213 253 253 253 253 198  81   2
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0  18 171 219 253 253 253 253 195  80   9   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0  55 172 226 253 253 253 253 244 133  11   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0 136 253 253 253 212 135 132  16   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]]
5
# display images
import matplotlib.pyplot as plt
import numpy as np

def display_image(image):
    plt.imshow(np.reshape(image, (28,28)), cmap='gray')
    plt.axis('off')
    plt.show()

digit_idx = 10
display_image(train_images[digit_idx])
print(f"The label is {train_labels[digit_idx]}")
(Output: a 28×28 grayscale image of a handwritten digit.)
The label is 3

Step 2: Preprocess the data#

We reshape the data into the shape that the network expects, and scale it so that all values are in the [0, 1] interval.

  • Reshaping flattens each 28×28 image into a 784-element vector so it can be fed to the dense layers below.

  • Scaling converts pixel values from [0, 255] to [0, 1], which helps the network train stably.

  • Converting to float32 ensures the scaled values keep their fractional precision.

train_images = train_images.reshape((60000, 28*28)) 
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((10000, 28*28))
test_images = test_images.astype('float32') / 255
print(len(train_images[0]))
784
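The same flatten-and-scale step can be checked on a small synthetic batch (the pixel values here are random, not MNIST; only the shapes and value ranges matter):

```python
import numpy as np

# A synthetic batch of four 8-bit grayscale "images", values in 0-255.
images = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Flatten each 28x28 image into a 784-element vector, then scale to [0, 1].
flat = images.reshape((len(images), 28 * 28)).astype('float32') / 255

print(flat.shape)                        # (4, 784)
print(flat.min() >= 0, flat.max() <= 1)  # True True
```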

Step 3: Define the model#

sparse_categorical_crossentropy is a loss function used for multi-class classification problems when the labels are integers rather than one-hot vectors.

It’s essentially identical to categorical_crossentropy, except that it accepts integer labels directly instead of requiring explicit one-hot encoding.
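The equivalence is easy to verify by hand with NumPy; this sketch computes both forms of the loss on made-up probabilities (not model output):

```python
import numpy as np

# Predicted class probabilities for three examples over four classes.
probs = np.array([[0.70, 0.10, 0.10, 0.10],
                  [0.05, 0.05, 0.80, 0.10],
                  [0.25, 0.25, 0.25, 0.25]])
labels = np.array([0, 2, 3])  # integer class labels

# Sparse form: pick out the probability of each true class directly.
sparse = -np.log(probs[np.arange(len(labels)), labels]).mean()

# Categorical form: one-hot encode the labels first, then sum.
one_hot = np.eye(4)[labels]
categorical = -(one_hot * np.log(probs)).sum(axis=1).mean()

print(np.isclose(sparse, categorical))  # True
```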

from tensorflow import keras 
from tensorflow.keras.layers import Dense

model = keras.Sequential([
    Dense(64, activation='relu'),
    Dense(256, activation='relu'),
    Dense(10, activation='softmax')
])

model.compile(
    optimizer='rmsprop',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
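As a side note, the size of this stack can be tallied by hand: each Dense layer has inputs × units weights plus units biases, with the 784 inputs coming from the flattened images:

```python
# (n_inputs, n_units) for each Dense layer in the model above.
layers = [(28 * 28, 64), (64, 256), (256, 10)]

# Each layer contributes n_inputs * n_units weights plus n_units biases.
total = sum(n_in + 1 and n_in * n_out + n_out for n_in, n_out in layers) if False else \
    sum(n_in * n_out + n_out for n_in, n_out in layers)
print(total)  # 69450
```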

Step 4: Train the model#

We will train the model for 10 epochs in mini-batches of 128 samples.
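With 60,000 training examples and this batch size, each epoch performs a fixed number of gradient updates, one per mini-batch:

```python
import math

n_examples, batch_size = 60000, 128

# One gradient update per mini-batch; the last batch may be smaller.
steps_per_epoch = math.ceil(n_examples / batch_size)
print(steps_per_epoch)  # 469
```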

history = model.fit(train_images, train_labels, epochs=10, batch_size=128, verbose=0)
import matplotlib.pyplot as plt 
import seaborn as sns

sns.set_style("whitegrid")
history_dict = history.history

loss_values = history_dict["loss"]

epochs = range(1, len(loss_values) + 1) 
plt.plot(epochs, loss_values, "o-", label="Training loss") 
plt.title("Training loss") 
plt.xticks(epochs)
plt.xlabel("Epochs") 
plt.ylabel("Loss") 
plt.legend() 
sns.despine()
plt.grid(False)
plt.show()
(Output: line plot of training loss per epoch.)
acc = history_dict["accuracy"]

epochs = range(1, len(loss_values) + 1) 
plt.plot(epochs, acc, "o-", label="Training acc")
plt.title("Training accuracy") 
plt.xticks(epochs)
plt.xlabel("Epochs") 
plt.ylabel("Accuracy") 
plt.legend() 
sns.despine()
plt.grid(False)
plt.show()
(Output: line plot of training accuracy per epoch.)
results = model.evaluate(test_images, test_labels)

print(f"The test loss is {results[0]}")
print(f"The test accuracy is {results[1]}")
print("The predictions are:")   
predictions = model.predict(test_images)
print(predictions)
313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 962us/step - accuracy: 0.9740 - loss: 0.0962
The test loss is 0.08320607244968414
The test accuracy is 0.9775999784469604
The predictions are:
313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 712us/step
[[4.6618698e-12 5.4504061e-11 6.9359729e-08 ... 9.9999982e-01
  3.0681863e-10 3.1769345e-10]
 [3.6057573e-14 3.7328530e-07 9.9999958e-01 ... 3.5972413e-13
  3.9626080e-09 3.3924853e-19]
 [1.8018078e-08 9.9956882e-01 2.5045214e-05 ... 4.6468926e-05
  3.5494170e-04 9.1214609e-08]
 ...
 [3.6129566e-16 1.5118119e-12 4.7771649e-14 ... 1.0168957e-07
  5.1780430e-10 1.1110571e-05]
 [3.7461545e-09 4.9541122e-12 4.2986995e-13 ... 1.2946375e-11
  7.1850550e-06 1.8296942e-12]
 [4.5440652e-12 8.8377114e-14 5.1600729e-12 ... 2.6945746e-17
  7.1381329e-13 3.6100231e-15]]
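Each row of predictions is a probability distribution over the ten digits, so the predicted class for an image is the index of the largest entry. A sketch on a made-up two-row array (hypothetical values, not the real model output):

```python
import numpy as np

# Hypothetical softmax outputs for two test images (each row sums to 1).
predictions = np.array(
    [[0.01, 0.02, 0.90, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01],
     [0.05, 0.80, 0.05, 0.01, 0.01, 0.02, 0.02, 0.02, 0.01, 0.01]])

# argmax along each row picks the most probable digit.
predicted_labels = predictions.argmax(axis=1)
print(predicted_labels)  # [2 1]
```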