deeplearn.js

deeplearn.js has become TensorFlow.js

We want to point you to the new website for the successor to deeplearn.js. TensorFlow.js picks up where deeplearn.js v0.5 leaves off and adds a whole host of functionality we think you will find really useful.

This site will no longer be updated and will eventually redirect to TensorFlow.js. We will be keeping this site around for a little while to make the transition to TensorFlow.js smoother. Thanks and we hope you will give TensorFlow.js a try!

Port TensorFlow models

This tutorial demonstrates how to train a TensorFlow model and port it to deeplearn.js. The code and all the necessary resources used in this tutorial are stored in demos/mnist.

We will use a fully connected neural network that predicts handwritten digits from the MNIST dataset. The code is forked from the official TensorFlow MNIST tutorial.

NOTE: We will refer to the base directory of the deeplearn.js repo as $BASE.

First, we clone the deeplearn.js repository and make sure we have TensorFlow installed. We cd into $BASE and train the model by running:

python demos/mnist/fully_connected_feed.py

The training should take ~1 minute and will store a model checkpoint in /tmp/tensorflow/mnist/logs/fully_connected_feed/.

Next, we need to port the weights from the TensorFlow checkpoint to deeplearn.js. We provide a script that does this. We run it from the $BASE directory:

python scripts/dump_checkpoints/dump_checkpoint_vars.py --model_type=tensorflow --output_dir=demos/mnist/ --checkpoint_file=/tmp/tensorflow/mnist/logs/fully_connected_feed/model.ckpt-1999

The script will save a set of files (one file per variable, and a manifest.json) in the demos/mnist directory. The manifest.json is a simple dictionary that maps variable names to files and their shapes:

{
  ...,
  "hidden1/weights": {
    "filename": "hidden1_weights",
    "shape": [784, 128]
  },
  ...
}

One last thing before we start coding - we need to run a static HTTP server from the $BASE directory:

yarn prep
./node_modules/.bin/http-server
>> Starting up http-server, serving ./
>> Available on:
>>   http://127.0.0.1:8080
>> Hit CTRL-C to stop the server

Make sure you can access manifest.json via HTTP by visiting http://localhost:8080/demos/mnist/manifest.json in the browser.
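
If you prefer to check this from code, a quick fetch from the browser console (or any page served by the same server) does the same job. This is just an illustrative sanity check, not part of the demo:

// Run in the browser console on http://localhost:8080. Logs the
// variable-name -> {filename, shape} map if the manifest is being served.
fetch('/demos/mnist/manifest.json')
    .then(response => response.json())
    .then(manifest => console.log(manifest))
    .catch(err => console.error('manifest.json is not reachable:', err));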

We are ready to write some deeplearn.js code!

NOTE: If you choose to write in TypeScript, make sure you compile the code to JavaScript and serve it via the static HTTP server.

To read the weights, we need to create a CheckpointLoader and point it to the manifest file. We then call loader.getAllVariables() which returns a dictionary that maps variable names to Tensors. At that point, we are ready to write our model. Here is a snippet demonstrating the use of CheckpointLoader:

import {CheckpointLoader, ENV, Tensor1D, Tensor2D} from 'deeplearn';
// manifest.json is in the same dir as index.html.
const varLoader = new CheckpointLoader('.');
varLoader.getAllVariables().then(async vars => {
  const math = ENV.math;

  // Get the variables as Tensors cast to their expected dimensions.
  const hidden1W = vars['hidden1/weights'] as Tensor2D;
  const hidden1B = vars['hidden1/biases'] as Tensor1D;
  // ... (the hidden2 and softmax variables are loaded the same way)

  // Write your model here.
  const hidden1 =
      math.relu(math.add(math.vectorTimesMatrix(..., hidden1W), hidden1B)) as
      Tensor1D;
  const hidden2 =
      math.relu(math.add(
          math.vectorTimesMatrix(hidden1, hidden2W), hidden2B)) as Tensor1D;

  const logits = math.add(math.vectorTimesMatrix(hidden2, softmaxW), softmaxB);

  const label = math.argMax(logits);

  console.log('Predicted label: ', await label.data());
});
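
In the snippet above, the ... passed to the first vectorTimesMatrix call stands for the input image: a Tensor1D holding the 784 flattened, normalized pixel values of a 28x28 MNIST digit. As a minimal sketch (assuming the dl.tensor1d factory from deeplearn.js 0.5; in the demo the pixel values come from demos/mnist/sample_data.json), the input could be built like this:

import * as dl from 'deeplearn';

// Hypothetical input: a 28x28 MNIST image flattened to 784 values in [0, 1].
const pixels = new Float32Array(784);  // fill with normalized pixel values
const input = dl.tensor1d(pixels);     // Tensor1D of shape [784]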

For the full model code, see demos/mnist/mnist.ts. The demo implements the same MNIST model using three different APIs:

  • buildModelMathAPI() uses the Math API. This is the API in deeplearn.js that gives the user the most control. Math commands execute immediately, like NumPy (see the sketch below).
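
To give a feel for this eager style, here is a minimal sketch (not taken from mnist.ts, and assuming the dl.tensor1d factory from deeplearn.js 0.5) of Math API calls whose results can be read back immediately:

import * as dl from 'deeplearn';

const math = dl.ENV.math;

// Ops execute as soon as they are called; no graph or session is needed,
// and the values can be read back right away with data().
const a = dl.tensor1d([1, 2, 3]);
const b = dl.tensor1d([10, 20, 30]);
const sum = math.add(a, b);
sum.data().then(values => console.log(values));  // Float32Array [11, 22, 33]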

To run the mnist demo, we provide a watch-demo script that watches and recompiles the TypeScript code when it changes. In addition, the script runs a simple HTTP server on port 8080 that serves the static HTML/JS files. Before you run watch-demo, make sure you kill the HTTP server we started earlier in the tutorial to free up port 8080. Then run watch-demo from $BASE, pointing it at the demo directory, demos/mnist:

./scripts/watch-demo demos/mnist
>> Starting up http-server, serving ./
>> Available on:
>>   http://127.0.0.1:8080
>>   http://192.168.1.5:8080
>> Hit CTRL-C to stop the server
>> 1410084 bytes written to demos/mnist/bundle.js (0.91 seconds) at 5:17:45 PM

Visit http://localhost:8080/demos/mnist/ and you should see a simple page showing a test accuracy of ~90%, measured on a test set of 50 MNIST images stored in demos/mnist/sample_data.json. Feel free to play with the demo (e.g. make it interactive) and send us a pull request!