
Understanding the MNIST training data

At the top of our start.js file you will see a function called loadData, like so:

async function loadData() {
  DATA = new MnistData();
  await DATA.load();
}

MnistData is the class that contains all the code we use to gather and prepare the data for our MNIST application. All the code for MnistData is in the data.js file, which is almost an exact mirror of a similar file in the official TensorFlow.js MNIST demo application. We won’t go through the contents of that file in depth, but there are a few critical points that are worth understanding.

Features & Labels

The type of machine learning we are covering in this book is Supervised Machine Learning. This means we give our machine learning model a collection of training data (the features) together with the output we expect the model to produce for each example (the labels). From just those two pieces of information, it learns how to give the correct answers (labels) for a given set of inputs (features).

Imagine studying for school with just the exam papers and their solutions, and no actual lessons. We do the same here: we give the model the exam papers (features) and the exam solutions (labels), and it learns how to answer exam questions (assign the right label to a set of features). Later, when we give it an exam paper (a set of features) it has never seen before, it will hopefully provide a pretty good answer (a label).

Given that, the data we need to train our machine learning model is a set of images of hand-drawn digits (the features) as well as the actual digit drawn in the image (the label).


I’m being meticulous in using the words features and labels. That is the language used throughout machine learning, so get used to thinking in terms of the features of some data and the label you want to associate with those features.

Features, the hand-drawn digits

At the top of data.js we see two interesting variables: MNIST_IMAGES_SPRITE_PATH and MNIST_LABELS_PATH.


MNIST_IMAGES_SPRITE_PATH points to a PNG that contains all the hand-drawn digit images. You can open it up in your browser; it might take a second to load but should end up looking something like this:

Figure 1. What the MNIST feature data set looks like zoomed out

It’s a very long, thin image. If you zoom in, you will see something like this:

Figure 2. What the MNIST feature data set looks like zoomed in

It still doesn’t look like an image of a hand-drawn digit. That’s because each digit’s 28 by 28 pixel image has been spread like butter into a single row of 784 pixels. Each row of the PNG is, therefore, a single digit, and the entire image contains 65,000 rows, one per digit. So the PNG file is a 784 x 65,000 pixel image: 784 pixels wide and 65,000 pixels tall.
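To make the "spread like butter" idea concrete, here is a small plain-JavaScript sketch (no TensorFlow.js needed) of row-major flattening: pixel (row, col) of a 28 by 28 image lands at index row * 28 + col of its 784-pixel row, and digit n starts at offset n * 784 in one long flat array. The helper names flattenImage and pixelOffset are hypothetical and are not part of data.js.

```javascript
// Hypothetical helpers illustrating row-major flattening; not taken from data.js.
const IMAGE_H = 28;
const IMAGE_W = 28;
const IMAGE_SIZE = IMAGE_H * IMAGE_W; // 784 pixels per digit

// Flatten a 28x28 array-of-rows into a single 784-element row.
function flattenImage(image2d) {
  const flat = new Float32Array(IMAGE_SIZE);
  for (let row = 0; row < IMAGE_H; row++) {
    for (let col = 0; col < IMAGE_W; col++) {
      flat[row * IMAGE_W + col] = image2d[row][col];
    }
  }
  return flat;
}

// Offset of pixel (row, col) of digit n inside one long flat array
// where each digit occupies IMAGE_SIZE consecutive values.
function pixelOffset(n, row, col) {
  return n * IMAGE_SIZE + row * IMAGE_W + col;
}

// Usage: a blank image with a single "ink" pixel at row 1, column 2.
const image = Array.from({ length: IMAGE_H }, () => new Array(IMAGE_W).fill(0));
image[1][2] = 1;
const flat = flattenImage(image);
console.log(flat.indexOf(1));      // 30 (1 * 28 + 2)
console.log(pixelOffset(2, 1, 2)); // 1598 (2 * 784 + 30)
```

The last pixel of the last digit, pixelOffset(64999, 27, 27), is 784 x 65,000 - 1, which is why the whole data set fits in one array of 784 x 65,000 values.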

The load() function loads this file and, after some processing, stores all the data in a property called datasetImages, like so:

this.datasetImages = new Float32Array(datasetBytesBuffer);

The critical thing to realize is that even though the source is a 2D image, datasetImages is a one-dimensional array holding all the image data: one big long array of 784 x 65,000 = 50,960,000 numbers. Stored as 32-bit floats (4 bytes each), that takes about 204 MB of memory.
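The memory cost follows directly from the element size: a Float32Array stores every value in 4 bytes (Float32Array.BYTES_PER_ELEMENT). A quick plain-JavaScript check of the arithmetic, where NUM_IMAGES is just a name chosen for this sketch:

```javascript
// Size of the full MNIST image data when stored as 32-bit floats.
const IMAGE_SIZE = 784;   // 28 x 28 pixels per digit
const NUM_IMAGES = 65000; // number of digits in the sprite (name chosen for this sketch)

const numValues = IMAGE_SIZE * NUM_IMAGES;
const numBytes = numValues * Float32Array.BYTES_PER_ELEMENT; // 4 bytes per value

console.log(numValues); // 50960000
console.log(numBytes);  // 203840000, roughly 204 MB
```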

Further down the file, in getTrainData(), we can see how the shape parameter lets us load the data from a 1D source and then give it dimensions, like so:

const xs = tf.tensor4d(this.trainImages, [ (1)
  this.trainImages.length / IMAGE_SIZE, (2)
  IMAGE_H, (3)
  IMAGE_W, (3)
  1 (4)
]);
1 The source of all the training image data, as a 1D array
2 The number of images (each image takes up IMAGE_SIZE = 784 values)
3 The height and width of the images (IMAGE_H and IMAGE_W are 28)
4 The number of channels per pixel. These are grayscale (black and white) images, so only one value is needed for each pixel. If these were color images, we might have three here, one each for red, green, and blue.


It’s far easier in TensorFlow to store your data as a 1D source and then use the shape parameter when creating tensors to add dimensionality.
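Conceptually, a shape such as [N, 28, 28, 1] just tells TensorFlow how to step through the flat array. The sketch below uses a hypothetical makeAccessor helper in plain JavaScript; it is not how TensorFlow.js is implemented internally, only an illustration of row-major indexing. It derives N from the array length the same way getTrainData() does, then reads element [n, row, col, channel].

```javascript
// Hypothetical sketch: view a flat array through a 4D shape using row-major strides.
function makeAccessor(flat, shape) {
  const [n, h, w, c] = shape;
  if (n * h * w * c !== flat.length) {
    throw new Error(`shape [${shape}] does not match length ${flat.length}`);
  }
  // Stride = how many flat elements to skip to advance one step along each axis.
  const strides = [h * w * c, w * c, c, 1];
  return (i, row, col, ch) =>
    flat[i * strides[0] + row * strides[1] + col * strides[2] + ch * strides[3]];
}

const IMAGE_H = 28;
const IMAGE_W = 28;
const IMAGE_SIZE = IMAGE_H * IMAGE_W;

// Three fake "images" whose pixel values are simply their flat index.
const flatImages = Float32Array.from({ length: 3 * IMAGE_SIZE }, (_, k) => k);
const shape = [flatImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]; // [3, 28, 28, 1]
const at = makeAccessor(flatImages, shape);

console.log(shape[0]);       // 3
console.log(at(2, 1, 2, 0)); // 1598
```

Nothing is copied or rearranged: the same flat array is simply interpreted through the shape, which is why keeping the data 1D and adding the shape later is so convenient.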

Labels, the numerical digit of each image

What number is each image, though? That’s what the file at MNIST_LABELS_PATH contains. I wouldn’t open it, however, since it’s a binary file and will show junk on the page!

It contains 65,000 entries, one for each image, but the format might seem a little strange at first: it uses something called one-hot encoding.

Instead of using the numerical value 9 to represent an image of a 9, it represents the number as an array of 10 numbers, each of which can only be 0 or 1, like so:

Figure 3. 9 represented as a one-hot encoding format

This isn’t binary: only one of the elements can be 1, and all the others have to be 0. The number it represents is given by the position of the 1 in the array, like so:

The number 9 is represented as [0,0,0,0,0,0,0,0,0,1]

The number 8 is represented as [0,0,0,0,0,0,0,0,1,0]

The number 7 is represented as [0,0,0,0,0,0,0,1,0,0]
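Encoding and decoding one-hot labels takes only a few lines. This plain-JavaScript sketch uses hypothetical helpers, oneHot and fromOneHot; they are not part of data.js, which loads labels that are already encoded. Decoding simply finds the position of the largest value, which for a one-hot array is the position of the 1.

```javascript
const NUM_CLASSES = 10; // digits 0-9

// Encode a digit as NUM_CLASSES zeros with a single 1 at index `digit`.
function oneHot(digit, numClasses = NUM_CLASSES) {
  if (!Number.isInteger(digit) || digit < 0 || digit >= numClasses) {
    throw new RangeError(`digit must be an integer in [0, ${numClasses - 1}]`);
  }
  const encoded = new Array(numClasses).fill(0);
  encoded[digit] = 1;
  return encoded;
}

// Decode by finding the index of the largest value (the 1 in a one-hot array).
function fromOneHot(encoded) {
  let best = 0;
  for (let i = 1; i < encoded.length; i++) {
    if (encoded[i] > encoded[best]) best = i;
  }
  return best;
}

console.log(oneHot(9).join(","));   // 0,0,0,0,0,0,0,0,0,1
console.log(oneHot(7).join(","));   // 0,0,0,0,0,0,0,1,0,0
console.log(fromOneHot(oneHot(8))); // 8
```

Because fromOneHot picks the largest value, the same logic also works on a model’s output probabilities, returning the digit the model considers most likely.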

We use this approach because we are going to use a loss function called categoricalCrossEntropy.


A full dissection of categorical cross-entropy is beyond the scope of this introductory book. However, I can give a high-level overview. Just like with a car, you can achieve a lot by understanding how to use it, without knowing the details of how it works.

In our first demo application in chapter two, we discussed Classification Outputs when using the MobileNet model. With classification problems, models output a set of probabilities, and it’s up to you to decide whether a probability is high enough to accept that answer.

For the MNIST application, it’s the same: it outputs a probability for each of the ten possible labels, like so:

Figure 4. The output of the MNIST model as a probability distribution

The output is a probability distribution: ten values, one per digit, that add up to 1. In this example, the model predicted a 78% probability that the number in the image is 9.

categoricalCrossEntropy is a good loss function for probability distributions: given two probability distributions, it calculates how far apart they are.

When we use [0,0,0,0,0,0,0,0,0,1] to represent the digit 9, we are saying there is a 100% probability that it’s a 9 and a 0% probability that it’s any other number. In this use case, a one-hot encoding is itself a probability distribution.
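At its core, categorical cross-entropy compares the label distribution y with the predicted distribution p as -Σ y_i · log(p_i). Because a one-hot label is 0 everywhere except at the correct digit, this collapses to -log(probability assigned to the correct digit). The plain-JavaScript sketch below shows the formula only; the TensorFlow.js loss works on tensors and averages over a batch, and library implementations typically also guard against log(0), which this sketch approximates by clamping with a small epsilon (an assumption of this sketch). The prediction arrays are illustrative values, not real model output.

```javascript
// Categorical cross-entropy between a true distribution and a predicted one:
// loss = -sum_i yTrue[i] * log(yPred[i])
function categoricalCrossEntropy(yTrue, yPred, epsilon = 1e-7) {
  if (yTrue.length !== yPred.length) {
    throw new Error("distributions must have the same length");
  }
  let loss = 0;
  for (let i = 0; i < yTrue.length; i++) {
    // Clamp predictions away from 0 and 1 so log() never returns -Infinity.
    const p = Math.min(Math.max(yPred[i], epsilon), 1 - epsilon);
    loss -= yTrue[i] * Math.log(p);
  }
  return loss;
}

const label = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]; // one-hot "9"

// Illustrative predictions: 78% on 9 with the rest spread out, and a uniform guess.
const confident = [0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.06, 0.78];
const unsure = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1];

console.log(categoricalCrossEntropy(label, confident).toFixed(4)); // 0.2485
console.log(categoricalCrossEntropy(label, unsure).toFixed(4));    // 2.3026
```

The more probability the model puts on the correct digit, the smaller the loss, which is exactly what training tries to achieve.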

So when you build a classifier model, something that sorts things into buckets, the output is usually a probability distribution, and the labels for the training data need to be probability distributions too. One-hot encoding is how we express them.

Taking a look at the code in the load() function, we can see our labels are created as another 1D array, like so:

this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

This array holds 65,000 one-hot encoded labels, one per image, laid end to end, so its length is 65,000 x 10 = 650,000.

Later on in our getTrainData() function we can see the labels are loaded into a Tensor like so:

const labels = tf.tensor2d(this.trainLabels, [ (1)
  this.trainLabels.length / NUM_CLASSES, (2)
  NUM_CLASSES (3)
]);
1 The source data for our training labels, as a 1D array
2 The first parameter of our shape is the number of records; for the training labels, this works out to 55,000
3 The second parameter of our shape is the number of “things” our model can predict: the size of our probability distribution and of the corresponding one-hot encoded labels. NUM_CLASSES is 10 in this case.
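The [records, NUM_CLASSES] shape works because label n occupies the ten consecutive values starting at n * NUM_CLASSES. A plain-JavaScript sketch of that lookup, using a hypothetical labelAt helper and a tiny made-up label array rather than the real file:

```javascript
const NUM_CLASSES = 10;

// Return the digit stored as the n-th one-hot label in a flat Uint8Array.
function labelAt(flatLabels, n) {
  const start = n * NUM_CLASSES;
  const row = flatLabels.subarray(start, start + NUM_CLASSES); // a view, no copy
  return row.indexOf(1);
}

// Three labels laid end to end: 5, 0, 9 (made-up values for illustration).
const flatLabels = new Uint8Array(3 * NUM_CLASSES);
[5, 0, 9].forEach((digit, n) => {
  flatLabels[n * NUM_CLASSES + digit] = 1;
});

console.log(flatLabels.length / NUM_CLASSES); // 3
console.log(labelAt(flatLabels, 2));          // 9
```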

Test/Train Split

One thing to note in the file is that we have two functions to get formatted data, one called getTrainData() and one called getTestData().

When it comes to supervised machine learning, it’s good to hold back some of your data to validate that the model will work well in the real world and is not overfitted. We call the data we train our model with the training data, and the data we validate our model with the test data.

So we train with the training data, and once we think training is complete, we run the model on the test data to validate it. We don’t train the model with this test data; we are just double-checking that, given data it has never seen before, the model still performs well. If accuracy during training reaches 95% but drops to, say, 60% on fresh test data, we can conclude that the model has been overfitted. It knows the training data so intimately that it has effectively memorized the answers rather than learning patterns that generalize, so when given data it has never seen before, it fails.

We’ve configured at the top of the file a variable NUM_TRAIN_ELEMENTS, which is equal to 55000.

In our example, we reserve 10,000 of the 65,000 total examples for testing and train with the other 55,000.
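Because both data sets are flat arrays, splitting them comes down to slicing at the right offset: the first NUM_TRAIN_ELEMENTS records go to training and the rest to testing, with the offset scaled by the size of one record (784 values per image, 10 per label). A minimal sketch with a hypothetical splitRecords helper and a toy data set; it splits the records in order and does not shuffle them.

```javascript
// Split a flat array of fixed-size records into train and test portions.
function splitRecords(flat, recordSize, numTrain) {
  const cut = recordSize * numTrain;
  return {
    train: flat.slice(0, cut), // the first numTrain records
    test: flat.slice(cut),     // everything after them
  };
}

// Toy data: 6 records of 4 values each, numbered 0..23.
const toy = Float32Array.from({ length: 24 }, (_, k) => k);
const { train, test } = splitRecords(toy, 4, 4);
console.log(train.length / 4, test.length / 4); // 4 2

// With the real sizes, the image cut point falls at this many values:
const IMAGE_SIZE = 784;
const NUM_TRAIN_ELEMENTS = 55000;
console.log(IMAGE_SIZE * NUM_TRAIN_ELEMENTS); // 43120000
```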


We spent a long time discussing the source data for our example application, and we’re now starting to see Tensors used in more real-world scenarios: keeping the source data as a 1D array and applying shapes via Tensors to give the data more meaning.

We also talked about how these simple classification models output probability distributions. To train them, we use the categoricalCrossEntropy loss function, which requires the labels in the training data to be probability distributions too. That is why you’ll see one-hot encoding used in data sets like ours.

Finally, we spoke about the test/train data split: how we reserve some data purely for validating that our model works in the real world and has not been overfitted.

We will use several different model algorithms and architectures in our example application, but all the training data will remain the same.

This is what your journey into Machine Learning will look like: get a good source of data, clean it, and structure it thoroughly. Then, you can try out several different Machine Learning architectures and settings until you find something that works for you.
