
Training a Neural Network using the Layers API


In the previous lecture, we covered how to create a straightforward model for our neural network. In this lecture, we are going to pull that together with our training data, train the model, and do some machine learning.

Code

We’ve covered training concepts in previous lectures, so in this lecture we will show how to implement those concepts in code, again using the TensorFlow.js Layers API.

We’ll be fleshing out the trainModel() function in start.js.

The first thing we need to do at the top of the function is get a reference to the model we created in the last lecture, like so:

MODEL = createDenseModel();

We are storing this in a global variable called MODEL since we need to use that in other functions in the file.

Warning

=== Global variables are not a great example of how to architect an application. I am using them to keep the supporting code simple so we can focus on machine learning. ===

Compiling the model

Before we start training a model, we need to compile it. compile[1] does several things, but primarily this is where we let TensorFlow.js know which optimizer and which loss function we are going to use.

Next up in our file, add this code:

  MODEL.compile({
    optimizer: "rmsprop", (1)
    loss: "categoricalCrossentropy", (2)
    metrics: ["accuracy"] (3)
  });
1 Previously we used the sgd optimizer; for our MNIST example we are going to use the rmsprop optimizer, since it performs better with problems like these. For more details about how this optimizer works, read the paper linked from the rmsprop docs[2]. You can pass in a string such as "rmsprop" or pass in an instance created with tf.train.rmsprop(…​). You can fine-tune the optimizer by passing values to it; for our example, we are happy to use the defaults.
2 The loss function we are going to use is called categoricalCrossentropy. For the Layers API, there is a list of Training Losses[3] you can use; "categoricalCrossentropy" is the string version of softmaxCrossEntropy[4].
3 The Layers API gives us a way to interrogate the training and see useful metrics after each run; here we provide a list of the metrics we are interested in.
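As mentioned in note 1, instead of the "rmsprop" string we could pass an optimizer instance to fine-tune it. A sketch of what that alternative compile call might look like (the 0.001 learning rate here is an illustrative value, not one from the course code):

```javascript
MODEL.compile({
  // tf.train.rmsprop(learningRate) returns an optimizer instance;
  // 0.001 is an example learning rate, chosen just for illustration.
  optimizer: tf.train.rmsprop(0.001),
  loss: "categoricalCrossentropy",
  metrics: ["accuracy"]
});
```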

Training the model

Before we train the model, we need the data. To get a reference to the training data from our MnistData object, add this line of code:

const trainData = DATA.getTrainData();

trainData is an object with two properties: xs holds the features, the images we are going to train with, and labels holds the one-hot encoded versions of the digits in each corresponding image.
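To make labels concrete, here is a sketch of what one-hot encoding a digit looks like, in plain JavaScript (the oneHot helper is hypothetical, written just for illustration, and is not part of the course code):

```javascript
// One-hot encode a digit 0-9 into a length-10 array:
// a 1 in the position of the digit, 0 everywhere else.
function oneHot(digit, numClasses) {
  const encoded = new Array(numClasses).fill(0);
  encoded[digit] = 1;
  return encoded;
}

// The label for an image of the digit 3:
console.log(oneHot(3, 10)); // [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```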

Next up, in between the "Training Start" and "Training Complete" log lines add this code:

  await MODEL.fit(trainData.xs, trainData.labels, { (1)
    batchSize: BATCH_SIZE, (2)
    validationSplit: VALIDATION_SPLIT, (3)
    epochs: EPOCHS (4)
  });
1 The fit function is the primary function that performs the training of our model; we pass it the training data.
2 The amount of data in our training set, 55000, is substantial. If we tried to train in one attempt, we would run out of memory on our computer. Instead, we break up the data into smaller batches and run through one batch at a time. BATCH_SIZE in our application is set to 320, so it will train with 320 example images at a time.
3 It’s good practice to hold some data back as validation data, so we can check after each epoch how the model performs on examples it isn’t being trained on. It’s the same reason we keep back some of the data for testing once the model has completed training. VALIDATION_SPLIT is set to 0.15, which means 15% of the training data is held out and used only to measure validation accuracy, not to update the model.
4 epochs is the number of times we want to loop through all the training data, the number of iterations. EPOCHS is currently set in the file to 1, but feel free to experiment and set it to a higher amount.
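To see how these numbers fit together, we can work out how many batches one epoch takes with plain JavaScript arithmetic. The constant names mirror those in start.js, but the calculation itself is just for illustration:

```javascript
const TRAIN_SIZE = 55000;      // examples in the MNIST training set
const VALIDATION_SPLIT = 0.15; // fraction held back for validation
const BATCH_SIZE = 320;        // examples trained per batch

// fit() holds back the validation data first, then batches the rest.
const examplesPerEpoch = TRAIN_SIZE * (1 - VALIDATION_SPLIT); // 46750
const batchesPerEpoch = Math.ceil(examplesPerEpoch / BATCH_SIZE);

console.log(batchesPerEpoch); // 147 batches per epoch
```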

Testing the model

After we have completed the training, we want to perform a final test of the model with data it has never seen before. This gives us an indication of how the model will work in the real world with real data.

We add this code to the end of the trainModel() function:

const testData = DATA.getTestData(); (1)
const testResult = MODEL.evaluate(testData.xs, testData.labels); (2)
const testAccPercent = testResult[1].dataSync()[0] * 100; (3)
console.log(`Final test accuracy: ${testAccPercent.toFixed(1)}%`); (4)
1 We get the test data from our MnistData class; like the training data, it contains xs and labels.
2 The evaluate function runs those inputs through the model and the loss function but, importantly, doesn’t perform any training; we just want to figure out how accurate the model is.
3 testResult is an array of scalar tensors: the loss first, then one entry per metric we compiled with. testResult[1] is therefore the accuracy; dataSync() reads the raw number out of the tensor, and we multiply by 100 to turn it into a percentage.
4 We print this out to the console.
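The formatting in step 3 is ordinary JavaScript. Assuming the accuracy metric came back as 0.901 (an illustrative value, matching the run shown below), it works like this:

```javascript
// evaluate() reports accuracy as a fraction between 0 and 1;
// dataSync() would read this raw number out of the tensor.
const accuracy = 0.901; // assumed value, for illustration
const testAccPercent = accuracy * 100;
console.log(`Final test accuracy: ${testAccPercent.toFixed(1)}%`);
// Final test accuracy: 90.1%
```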

Running the application

If we open up the application (making sure to open the console as well) and press the train button, we should see some information printed out like so:

Training Model
🎉 Training Start
🍾 Training Complete
Final test accuracy: 90.1%

We now have a trained model with a test accuracy of 90.1%, which is pretty accurate for one epoch and a very naive densely connected neural network.

Interrogating the training

The UI for our application didn’t change at all. We pressed LOAD, and the only way to know what was happening during training was to wait for it to end and see what was logged to the console. We want to see what’s happening during training and perhaps update the UI accordingly; TensorFlow.js has something we can use for this called callbacks.

If we go back to our fit function, just after the epochs property, let’s add another property called callbacks, with code like so:

await MODEL.fit(trainData.xs, trainData.labels, {
    batchSize: BATCH_SIZE,
    validationSplit: VALIDATION_SPLIT,
    epochs: EPOCHS,
    callbacks: {
      onBatchEnd: async (batch, logs) => { (1)
        trainBatchCount++;
        let percentComplete = (
          (trainBatchCount / totalNumBatches) *
          100
        ).toFixed(1);
        PROGRESS_UI.setProgress(percentComplete);
        PROGRESS_UI.setStatus(`ACC ${logs.acc.toFixed(3)}`); (2)
        console.log(`Training... (${percentComplete}% complete)`);
        await tf.nextFrame(); (3)
      },
      onEpochEnd: async (epoch, logs) => { (4)
        valAcc = logs.val_acc;
        console.log(`Accuracy: ${valAcc}`);
        PROGRESS_UI.setStatus(`*ACC ${logs.val_acc.toFixed(3)}`);
        await tf.nextFrame();
      }
    }
});
1 The onBatchEnd callback gets called at the end of each smaller 320-example batch of training. It gets passed some useful parameters, like logs, which contains numbers such as the current training accuracy of the model.
2 We format the current accuracy and use some helper UI code so it’s shown on screen, at the top right of our application.
3 We use await tf.nextFrame() so the browser has a chance to pause and draw the progress on the page before continuing with the training.
4 onEpochEnd is very similar; it gets called at the end of an entire epoch, once all the training data has been looped through.
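The progress arithmetic inside onBatchEnd can be sketched in isolation. In start.js, trainBatchCount and totalNumBatches are assumed to be declared elsewhere in the file; here they are simulated with illustrative values in plain JavaScript:

```javascript
// Simulate the progress calculation from onBatchEnd.
const totalNumBatches = 147; // illustrative batch count for one epoch
let trainBatchCount = 0;

const progressAfterBatch = () => {
  trainBatchCount++;
  return ((trainBatchCount / totalNumBatches) * 100).toFixed(1);
};

console.log(progressAfterBatch()); // "0.7" after the first batch
```

Each call advances the counter by one batch, so the value climbs toward "100.0" as training completes the epoch.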

When you press the LOAD button in the application, you should see a progress bar appear to the right of the button at the top of the page and visually see the application’s progress as it goes through the training process.

Summary

We learned how to perform training using the layers API, how to use the compile function to prepare our model with the optimizer and loss function we are going to use, and how to use the fit function to perform the training. We later learned how to use callbacks to interrogate the model during the training process.

Next up, we are going to learn how to use the model in a real-world setting. We are going to draw a number and have the model try to predict what number we drew.


