How to Use TensorFlow.js in the Browser with a Web Worker

Recognizing the objects in an image is a common task in computer vision. Thanks to advances in machine learning, it’s now possible to do this in real time in the browser using TensorFlow.js and the pre-trained MobileNet image classification model. However, as the size of the model and the complexity of the task increase, inference can take a noticeable amount of time. Because that work normally runs on the same main thread that handles rendering and user input, it can make the user interface stutter or freeze.

One way to solve this problem is to use a web worker, which runs JavaScript on a separate background thread. When the model and its predictions are moved into a worker, the main thread stays free, so the user interface remains responsive while the image is being classified.

In this blog post, we will cover how to classify images with the MobileNet model in TensorFlow.js and offload the processing to a web worker.

Getting Started

Before we get started, we need to set up a few things. In the body of our HTML file, we add a canvas element to display the image, a div element to show the predictions, and a script tag that loads our main script, main.js:

<canvas id="canvas"></canvas>
<div id="predictions"></div>
<script src="main.js"></script>

Browsers such as Chrome refuse to start a worker from a page opened directly from disk (a file:// URL), so serve these files from a local web server.

Loading the Model

Now that we have set up our HTML file, we can start loading the MobileNet model. Because the model will run inside the web worker, this code goes in a separate JavaScript file called worker.js, which the browser executes on the worker's thread.

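// Load TensorFlow.js and the MobileNet package; they define the globals `tf` and `mobilenet`.
importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.9.0/dist/tf.min.js');
importScripts('https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@2.0.4/dist/mobilenet.min.js');

let model;

self.onmessage = async (event) => {
  // Load the model on the first message only; later messages reuse it.
  if (!model) {
    model = await mobilenet.load();
  }
  // event.data is the ImageData object sent by the main thread.
  const image = tf.browser.fromPixels(event.data);
  try {
    // Resolves to the top predictions: [{ className, probability }, ...]
    const predictions = await model.classify(image);
    self.postMessage(predictions);
  } finally {
    // Release the tensor's memory even if classification throws.
    image.dispose();
  }
};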

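In worker.js, we first load the TensorFlow.js library and the MobileNet package from the jsDelivr CDN with importScripts, which makes the global tf and mobilenet objects available inside the worker. We then declare a variable called model to store the loaded model.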

Next, we listen for messages from the main thread with an onmessage handler. If the model has not been loaded yet, we load MobileNet with the mobilenet.load method and store it in the model variable, so the model is downloaded only once.
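With this check, a second message that arrives while mobilenet.load() is still pending finds model still undefined and starts a second download. That cannot happen when only one image is sent, but it can when frames are streamed to the worker. One option is to cache the loading promise instead of the loaded model. The following is a minimal sketch; lazyOnce is a hypothetical helper, not part of TensorFlow.js or the MobileNet package.

```javascript
// Hypothetical helper: wraps an async loader so it runs at most once.
// Concurrent callers share the same pending promise instead of each
// starting their own download.
function lazyOnce(loader) {
  let promise = null;
  return () => {
    if (promise === null) {
      promise = loader().catch((err) => {
        // Forget a failed attempt so the next call can retry.
        promise = null;
        throw err;
      });
    }
    return promise;
  };
}

// Usage inside worker.js (assumes importScripts has defined `mobilenet`):
// const getModel = lazyOnce(() => mobilenet.load());
// self.onmessage = async (event) => {
//   const model = await getModel();
//   // ...fromPixels, classify, postMessage as before
// };
```

Clearing the cached promise on failure is a design choice: without it, a single network error would make every later message fail with the same rejected promise.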

Once the model is loaded, we convert the image data into a TensorFlow.js tensor with the tf.browser.fromPixels method. We then pass this tensor to the model’s classify method, which resolves to an array of the most likely classes (three by default), each with a className and a probability. Finally, we dispose of the image tensor to free its memory and post the predictions back to the main thread.
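ImageData stores its pixels as one flat array of RGBA bytes, four per pixel. By default (numChannels = 3), tf.browser.fromPixels returns an int32 tensor of shape [height, width, 3], so the alpha channel is dropped. The plain-JavaScript function below is a hypothetical illustration of that layout change, not TensorFlow.js's implementation; in the worker you would still call fromPixels.

```javascript
// Illustration only (not TensorFlow.js code): rearrange a flat RGBA buffer,
// laid out the way ImageData.data stores it, into the [height][width][3]
// RGB layout that tf.browser.fromPixels produces by default.
function rgbaToRgbArray(data, width, height) {
  const rows = [];
  for (let y = 0; y < height; y++) {
    const row = [];
    for (let x = 0; x < width; x++) {
      const i = (y * width + x) * 4; // 4 bytes per pixel: R, G, B, A
      row.push([data[i], data[i + 1], data[i + 2]]); // alpha is dropped
    }
    rows.push(row);
  }
  return rows;
}
```

In the worker, this means tf.browser.fromPixels(event.data).shape is [canvas.height, canvas.width, 3].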

Classifying the Image

Now that the web worker is ready, we can send it an image to classify. We do this in the main JavaScript file, main.js, which runs on the page's main thread.

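const canvas = document.getElementById('canvas');
const ctx = canvas.getContext('2d');
const predictionsEl = document.getElementById('predictions');

let worker;

// Display the top prediction; classify() returns the most likely class first.
function showPredictions(predictions) {
  if (predictions.length) {
    const { className, probability } = predictions[0];
    predictionsEl.innerHTML = `<strong>Label:</strong> ${className} <br> <strong>Confidence:</strong> ${(probability * 100).toFixed(1)}%`;
  }
}

function init() {
  worker = new Worker('worker.js');

  // Called each time the worker posts its predictions back.
  worker.onmessage = (event) => {
    showPredictions(event.data);
  };

  // Report errors such as importScripts failing to fetch the libraries.
  worker.onerror = (event) => {
    predictionsEl.textContent = `Worker error: ${event.message}`;
  };

  const image = new Image();
  image.onload = () => {
    // Draw the image at full size, then read back its RGBA pixels.
    canvas.width = image.width;
    canvas.height = image.height;
    ctx.drawImage(image, 0, 0);
    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
    predictionsEl.textContent = 'Loading...';
    // The ImageData object is copied to the worker by structured cloning.
    worker.postMessage(imageData);
  };
  image.src = './example.png';
}

init();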

In the main.js file, we first get a reference to the canvas element and its context. We also create a variable called worker, which we will use to communicate with the web worker.

Next, we define a function called showPredictions, which takes the array of predictions returned by the worker and displays the label and confidence score of the first (most likely) entry.
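showPredictions displays only the first prediction. Because classify resolves to several predictions ordered from most to least likely, you may prefer to list them all. Here is a minimal sketch; formatPredictions and its maxItems parameter are hypothetical names chosen for this example.

```javascript
// Hypothetical helper: turn the array resolved by model.classify()
// (each item is { className, probability }) into one line per prediction,
// with the probability shown as a percentage.
function formatPredictions(predictions, maxItems = 3) {
  return predictions
    .slice(0, maxItems)
    .map(({ className, probability }) => `${className}: ${(probability * 100).toFixed(1)}%`)
    .join('\n');
}
```

To show the result, assign it to the element's innerText, which renders each newline as a line break and, unlike innerHTML, treats the class names as plain text.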

Finally, we define a function called init, which creates the web worker, loads an example image, draws it onto the canvas, and sends its pixel data (an ImageData object obtained with getImageData) to the worker for classification. When the worker posts its predictions back, the worker’s onmessage handler passes them to showPredictions to display the result. The last line calls init() to start everything.
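main.js sends a single image. With a live source such as a webcam, new frames can arrive faster than the worker can classify them, so work piles up in the worker and the displayed result falls further and further behind the feed. One common approach is to keep at most one frame in flight and drop the rest. The sketch below assumes a hypothetical createFrameGate helper; it is not part of TensorFlow.js or the Web Workers API.

```javascript
// Hypothetical helper for live input (e.g. webcam frames): forwards a frame
// to `send` only when no earlier frame is still being classified.
// Frames offered while the worker is busy are dropped rather than queued.
function createFrameGate(send) {
  let busy = false;
  return {
    // Call for every new frame; returns true if the frame was sent.
    offer(frame) {
      if (busy) return false;
      busy = true;
      send(frame);
      return true;
    },
    // Call when the worker's reply for the in-flight frame arrives.
    done() {
      busy = false;
    },
  };
}

// Usage sketch in main.js (assumes the `worker` created in init()):
// const gate = createFrameGate((imageData) => worker.postMessage(imageData));
// worker.onmessage = (event) => { gate.done(); showPredictions(event.data); };
// ...then call gate.offer(ctx.getImageData(0, 0, canvas.width, canvas.height))
// once per animation frame.
```

If the worker can fail on a frame, call gate.done() from that error path as well; otherwise the gate stays closed and no further frames are sent.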