Transfer Learning with MobileNet and KNN

In this lecture, we will use transfer learning to adapt the MobileNet model so that it recognizes hand signs (or any other images) instead of the classes it was originally trained on.

What I hope surprises everyone is just how easy this is going to be; the whole app is less than 100 lines of JavaScript. If I had asked you just to write those lines at the start, you would have had a practical application but not understood anything about how it works.

Now you’ll be able to recognize all the power at your fingertips in the APIs and libraries you are using, and you’ll also be able to tweak this example or build others just as powerful all by yourself.

Important

For this application, you will need a web camera.

Emoji Trainer

The application we are going to build is called “Emoji Trainer”. You can teach it to associate images with certain emoji. E.g., you can train it to show a thumbs up emoji when you put your thumb up to the camera, like so:

Figure 1. Using the Emoji Trainer with thumbs up

Or you can train it to recognize a one-handed heart symbol like so:

Figure 2. Using the Emoji Trainer with heart

Code

Open the emoji-trainer folder in the sources project, and let’s take a quick look at what you will find.

.
├── README.md
├── assets/
├── completed.js
├── index.html
└── start.js

assets contains some CSS files needed to style the application. completed.js is the final code; you can compare what you create with this file to see if you missed anything. start.js is the starting file; we will be fleshing it out in this lecture. index.html is where we load the dependent JavaScript files and packages.

index.html

If you open index.html you’ll find we import a number of dependencies, like so:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>

First we load the TensorFlow.js library, then the MobileNet model, which we’ll be using as a base, and finally the KNN classifier package, which we will train in place of MobileNet’s final layer.

The rest of the HTML file is UI controls for the application itself.

start.js

If you open start.js, you will find some boilerplate code, and we will be fleshing out all the sections marked // TODO.

At the top of our file you will see some global variables, like so:

// Global Variables
let KNN = null;
let MBNET = null;
let CAMERA = null;

// Config
const TOPK = 10;

KNN and MBNET are where we will store the models. CAMERA is a reference to the webcam we will be using in the app. TOPK is the K factor we discussed in the previous lecture. We’ve set it to 10, but you may want to tweak this to a number that works for you.
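
To get a feel for what tweaking TOPK changes, here is a minimal plain-JavaScript sketch of a K-nearest-neighbours vote. It is not how the knn-classifier package is implemented; knnVote and the tiny 2D data set are made up for illustration.

```javascript
// Hypothetical helper: majority vote over the k closest labelled points.
// `examples` is an array of { point: number[], label: string }.
function knnVote(examples, query, k) {
  const nearest = examples
    .map(({ point, label }) => ({
      label,
      // Squared Euclidean distance is enough for ranking neighbours.
      dist: point.reduce((sum, v, i) => sum + (v - query[i]) ** 2, 0),
    }))
    .sort((a, b) => a.dist - b.dist)
    .slice(0, k);

  // Count how many of the k neighbours carry each label.
  const counts = {};
  for (const { label } of nearest) {
    counts[label] = (counts[label] || 0) + 1;
  }
  // Return the label with the most votes.
  return Object.keys(counts).reduce((best, label) =>
    counts[label] > counts[best] ? label : best
  );
}

// Made-up training data: three thumbs-up points, two heart points.
const examples = [
  { point: [0, 0], label: "👍" },
  { point: [0, 1], label: "👍" },
  { point: [1, 0], label: "👍" },
  { point: [5, 5], label: "❤️" },
  { point: [5, 6], label: "❤️" },
];

const query = [4, 4]; // clearly closer to the hearts
console.log(knnVote(examples, query, 1)); // prints "❤️"
console.log(knnVote(examples, query, 5)); // prints "👍"
```

With K set to 5, every example gets a vote, so the larger class wins even though the query sits next to the hearts. This is one reason to keep K small relative to the number of snapshots you record per emoji.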

main

At the bottom of the file, you will see a main function, and we will begin there.

At the top of the main function we need to load both our MobileNet and KNN classifier so we can use them in our application like so:

// Setup Models
console.log("Models Loading...")
KNN = knnClassifier.create(); // (1)
MBNET = await mobilenet.load(); // (2)
console.log("Models Loaded")
1 We create an instance of the KNN classifier and store it in our KNN variable. knnClassifier isn’t a model; it’s more of a utility package that creates a classifier for you. There is no data for it to load over the network, which is why there is no await keyword here.
2 We load the MobileNet model. We use await because this makes a call over the internet for the dependent data, and we wait for that to load before continuing with the application.

Next, in our main function, we want to initialize our camera. We can do this using standard HTML browser APIs, but it’s a little long-winded so the TensorFlow team created a simple one liner in the TensorFlow.js library that handles all of this for us, like so:

// Setup WebCam
let videoElement = document.getElementById("webcam");
CAMERA = await tf.data.webcam(videoElement); // (1)
1 This is the line we need to add. It sets up the camera and makes it available to us via the CAMERA variable.

Training

To train our model, we make sure the camera is looking at the hand sign we want to associate with an emoji, and then we press the RECORD button next to the emoji, like so:

Figure 3. Pressing record grabs snapshots every 100ms to use as training data

While the RECORD button is held down, the app takes a snapshot of the camera every 100ms and uses it to train the model. The total number of snapshots is printed to the right of the RECORD button.

Note

Do some experimentation with the number of snapshots required to give good predictions. You’ll be surprised at how good a model you can create with a minimal training data set. This is one of the critical advantages of Transfer Learning.

The code which grabs the snapshot and trains our model can be found in the setupButton function in start.js, so let’s flesh that out.

The setupButton function comes already with a fair amount of code attached, like so:

function setupButton(id) {
    let timerId = null;
    let btn = document.getElementById(id + "-btn")
    let span = document.getElementById(id + "-text")
    let input = document.getElementById(id + "-input")

    btn.addEventListener("mousedown", () => {
        let text = input.value;
        let count = 0;
        timerId = setInterval(async () => {
            // TODO

            console.log(count)
            span.innerText = count;
            count++;
        }, 100)
    })
    btn.addEventListener("mouseup", () => {
        // Stop grabbing samples of images
        clearInterval(timerId);
    });
}

This code is a lot of boilerplate, which grabs specific values from the HTML like the text inside the input field next to the button; this is what the text variable contains.

Tip

The emoji you are training the model to recognize are just values of an input field. I’ve set my own emoji as defaults; you can change them to whatever emoji you want, or even to plain text.

It then listens for mouse down on the record button and starts an interval timer, which calls some code every 100ms. We will be fleshing out the // TODO in this function, but it also does some other useful things like incrementing the number displayed to the right of the record button.

The mouseup listener cancels the timer, so the inner training function doesn’t get called anymore.

Now let’s flesh out the // TODO. Replace it with the code below; because it uses await, the enclosing setInterval callback must be an async function:

// Start grabbing an image of the video
const image = await CAMERA.capture(); // (1)
// Pump it through mobilenet and get the logits
const logits = MBNET.infer(image, true); // (2)
// Add this as a bit of data for knn
KNN.addExample(logits, text); // (3)
1 We grab a snapshot from the webcam.
2 We first pump it through a decapitated MobileNet model; more on that in the next section. In this case infer returns a tensor of 1024 numbers, which we store in logits (we’ll explain logits soon).
3 This adds a single data point to our KNN classifier. Here logits holds 1024 numbers, which together define a point in a 1024-dimensional space. text is the label to associate with that point.

In those three lines of code, we have implemented transfer learning!

Before we leave this function let’s not forget to dispose of our Tensors so we don’t have a memory leak. After the last line in the snippet above add this:

// Delete memory
image.dispose();
if (logits != null) {
    logits.dispose();
}
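
If an error is thrown between creating the tensors and disposing of them, the dispose calls above are skipped. One way to guard against that is a try/finally block. The sketch below is only an illustration: withDisposal is a hypothetical helper, not part of TensorFlow.js, and fakeTensor stands in for a real tensor so the snippet runs on its own.

```javascript
// Hypothetical helper (not part of TensorFlow.js): run `fn`, then dispose
// every tensor-like object passed in, even if `fn` throws.
function withDisposal(tensors, fn) {
  try {
    return fn(...tensors);
  } finally {
    for (const t of tensors) {
      if (t != null) {
        t.dispose();
      }
    }
  }
}

// Stand-in for a tensor so the sketch runs without TensorFlow.js.
function fakeTensor(name) {
  return {
    name,
    disposed: false,
    dispose() {
      this.disposed = true;
    },
  };
}

const image = fakeTensor("image");
const logits = fakeTensor("logits");
try {
  withDisposal([image, logits], () => {
    throw new Error("addExample failed");
  });
} catch (err) {
  console.log(err.message); // prints "addExample failed"
}
console.log(image.disposed, logits.disposed); // prints "true true"
```

In the timer callback, the same idea could wrap the KNN.addExample call so that both tensors are released whether or not it succeeds.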

Decapitated Model

The critical line in the above piece of code is MBNET.infer(image, true), so let’s take a moment to unpack it.

MBNET takes as input an image and outputs an array of 1000 numbers, which indicate what the image might contain.

Each index in that array is associated with a “class”, a thing the image might be. The numbers in the array are related to the probability of that class being what’s in the image. If the number is high for that index, the model thinks that class is what is in the image.
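
To make “related to the probability” concrete: a classifier’s final layer produces one raw score per class, and a softmax function turns those scores into probabilities that sum to 1. The sketch below uses three made-up classes and made-up scores, not MobileNet’s real output.

```javascript
// Turn raw scores into probabilities that sum to 1.
function softmax(scores) {
  const max = Math.max(...scores); // subtract the max for numerical stability
  const exps = scores.map((s) => Math.exp(s - max));
  const total = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / total);
}

// Index of the largest value, i.e. the most likely class.
function argMax(values) {
  return values.reduce((best, v, i) => (v > values[best] ? i : best), 0);
}

// Made-up scores for three made-up classes (MobileNet has 1000).
const classNames = ["cat", "dog", "teapot"];
const scores = [1.0, 3.0, 2.0];
const probs = softmax(scores);

console.log(classNames[argMax(probs)]); // prints "dog"
```

The highest raw score always maps to the highest probability, so the winning index is the same before and after the softmax.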

If you used MBNET.infer(image) all by itself, it would return the values for that last layer: 1000 raw scores, one per class.

Passing true as the second parameter returns the output numbers for the layer just before the last layer.

Strictly speaking, logits are the raw prediction values that the last layer of a classification network produces, real numbers ranging from -infinity to +infinity, before they are squashed into probabilities; see more on Wikipedia about Logit[1]. The MobileNet package’s documentation calls the output we get here the embedding, but we keep the variable name logits in our code.

In our case, this layer outputs 1024 numbers.

Even though 1024 is much higher than the 2 or 4 point examples we used in the KNN lecture, the logic is still the same. Instead of 2D space, we now have a 1024D space. Each set of 1024 numbers is a point in this 1024D space, and you can find the distance between any two points in this space using the Euclidean distance, so KNN works!
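
The distance calculation really is the same in any number of dimensions. Here is a minimal sketch in plain JavaScript; euclideanDistance is an illustrative helper, and the knn-classifier package may measure closeness differently internally.

```javascript
// Euclidean distance between two points with any number of dimensions.
function euclideanDistance(a, b) {
  if (a.length !== b.length) {
    throw new Error("Points must have the same number of dimensions");
  }
  let sumOfSquares = 0;
  for (let i = 0; i < a.length; i++) {
    const diff = a[i] - b[i];
    sumOfSquares += diff * diff;
  }
  return Math.sqrt(sumOfSquares);
}

// 2D: the classic 3-4-5 triangle.
console.log(euclideanDistance([0, 0], [3, 4])); // prints 5

// 1024D: same function, just longer arrays.
const origin = new Array(1024).fill(0);
const ones = new Array(1024).fill(1);
console.log(euclideanDistance(origin, ones)); // prints 32
```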

We are stripping out the last layer of the MobileNet model (decapitating the model), taking the outputs of the layer that is now last, and using them as inputs to a new model. That could be another neural network, but we are using a KNN model.

We are not retraining the MobileNet model; it is read-only for us. We are just taking the outputs of the decapitated model and using them as inputs to a KNN classifier, which is the part we are training.

Using the new combined MobileNet/KNN model

So now we’ve trained our new model, let’s use it. In our application, we click the RUN button when we are ready for the application to start making predictions based on the camera input. When the RUN button is pressed in the application, the run function is called in start.js which should currently look something like so:

async function run() {

    let output = document.getElementById("output-result")

    setInterval(async function predict() {
        // TODO
    }, 100)
}

Every 100ms it calls the function predict above, which we will flesh out to do some prediction. Paste the below code into the body of the predict function:

const numClasses = KNN.getNumClasses(); // (1)
if (numClasses > 0) {
    const image = await CAMERA.capture(); // (2)
    let logits = MBNET.infer(image, true); // (3)
    const res = await KNN.predictClass(logits, TOPK); // (4)
    console.log(res)
    output.innerText = res.label; // (5)
    // Delete memory (6)
    image.dispose();
    if (logits != null) {
        logits.dispose();
    }
}
1 getNumClasses returns the number of labels the KNN classifier has been given. This should return 3 for our application, or 0 if we have not started the training process.
2 We grab a snapshot from the camera.
3 We get the outputs from the decapitated MobileNet model.
4 predictClass predicts the class of this new image given the examples we have already trained the KNN classifier with, i.e., it tells us which emoji it thinks the image is associated with.
5 This prints the new label (the emoji) to the screen on the application.
6 We remember to free up the memory.
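
The classifier always answers with the closest label, even when the camera shows something it was never trained on. One possible refinement is to show a label only when the classifier is reasonably sure. The sketch below assumes the result of predictClass carries a confidences map keyed by label, as the knn-classifier documentation describes; labelOrFallback, the 0.6 threshold and the "?" fallback are made up for illustration.

```javascript
// Hypothetical helper: only trust a prediction whose confidence for the
// winning label reaches `minConfidence`; otherwise return `fallback`.
function labelOrFallback(res, minConfidence = 0.6, fallback = "?") {
  const confidences = res.confidences || {};
  const confidence = confidences[res.label] || 0;
  return confidence >= minConfidence ? res.label : fallback;
}

// Made-up results in the assumed shape.
const sure = { label: "👍", confidences: { "👍": 0.9, "❤️": 0.1 } };
const unsure = { label: "👍", confidences: { "👍": 0.5, "❤️": 0.5 } };

console.log(labelOrFallback(sure)); // prints "👍"
console.log(labelOrFallback(unsure)); // prints "?"
```

In the predict function, output.innerText could then be set to labelOrFallback(res) instead of res.label.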

Summary

We’ve covered a lot in this lecture, and the most crucial part was understanding what a decapitated model is and how to use it in transfer learning.

Try it out, and you’ll see some incredibly powerful application features with a relatively small amount of code.

Machine learning typically takes considerable computation, so at first glance, browser-based or even JavaScript-based machine learning seems questionable. With this transfer learning example, I hope you’ve found a use case where machine learning in the browser is clearly feasible. Let your mind go wild and explore: take any other machine learning model and use transfer learning to retrain it to do something related. It’s possible to build some incredibly useful applications that way.


