Super simple distributed hyperparameter tuning with Keras and Mongo

Datetime: 2017-04-19 05:37:44   Topic: MongoDB

One of the challenges of hyperparameter tuning a deep neural network is the time it takes to train and evaluate each set of parameters. If you’re anything like me, you often have four or five networks in mind that you want to try: different depth, different units per layer, etc.

Training and evaluating these networks in series is fine if the dataset or parameter space is small. But what if you're evolving a neural network with a genetic algorithm, like we did in our previous post? Or what if one train/eval step takes a day to complete? We need parallelism!

Here's a super simple way to achieve distributed hyperparameter tuning using MongoDB as a quasi-pub/sub system, with a controller defining jobs to process and N workers doing the training and evaluation.

It looks like this: [architecture diagram not preserved]

All the code is available on GitHub here: Super Simple Distributed Hyperparameter Tuning

The Controller

The system starts with a controller that creates the jobs. How you choose to create jobs is up to you. In our simple example, we just create a two-layer MLP with a random number of units per layer. It goes something like this:

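import random
import time

# Database, gen_network and add_job are defined in the project's own code.


def main():
    """Toy example of adding jobs to be processed."""
    db = Database('blog-test')
    while True:
        print("Creating job.")
        network = gen_network(784, 10)  # MNIST: 784 inputs, 10 classes
        add_job('mnist', network, db)
        # Space out job creation by a random 1-2 minutes.
        sleep_time = random.randint(60, 120)
        print("Waiting %d seconds." % sleep_time)
        time.sleep(sleep_time)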
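The controller leans on a `gen_network` helper that isn't shown here. As a rough illustration only, here is a hypothetical version that builds a two-layer MLP spec with a random number of units per hidden layer. The post's jobs store a Keras-generated JSON definition; this sketch returns a plain dict instead so it runs without Keras, and the unit range (16 to 512) and ReLU/softmax activations are assumptions, not the author's choices.

```python
import random


def gen_network(input_dim, output_dim, max_units=512):
    """Hypothetical stand-in for the project's gen_network.

    Returns a plain-dict spec for a two-layer MLP whose hidden layers each
    get a random number of units. (Assumed format; the real code stores a
    Keras JSON definition in the job.)
    """
    # Two hidden layers, each with a random width in [16, max_units].
    hidden = [random.randint(16, max_units) for _ in range(2)]
    layers = [{"type": "dense", "units": u, "activation": "relu"} for u in hidden]
    # Output layer sized to the number of classes (10 for MNIST).
    layers.append({"type": "dense", "units": output_dim, "activation": "softmax"})
    return {"input_dim": input_dim, "layers": layers}
```

Because the spec is a plain dict, it can be stored directly as a field in a MongoDB document.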

The Database

Jobs are stored in MongoDB. We create a Database class with the following methods:

  • insert_job — Add the job to the DB. The job document includes the dataset we want to train on (in our case, just mnist or cifar10), the Keras-generated JSON network definition, processing and processed flags, and an empty dictionary to hold the metrics
  • find_job — Returns the first available job that isn’t currently being processed and sets a flag that indicates the job is being worked on
  • score_job — Update the job to indicate it has been processed and add the resulting metrics
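To make the semantics of these three methods concrete, here is a minimal in-memory sketch that mirrors them without needing a MongoDB server. It is not the author's Database class: the class name `InMemoryJobStore` is hypothetical, and using the list index as `_id` is an assumption (MongoDB would assign an ObjectId). The important detail is that `find_job` checks and sets the "processing" flag in one atomic step, so two workers can never claim the same job.

```python
import threading


class InMemoryJobStore:
    """Hypothetical in-memory stand-in for the post's Database class.

    Mirrors insert_job / find_job / score_job so the claim logic can be
    exercised without a MongoDB server.
    """

    def __init__(self):
        self._jobs = []
        self._lock = threading.Lock()  # guards every read-modify-write

    def insert_job(self, dataset, network):
        """Add a job document with the fields described in the post."""
        with self._lock:
            job = {
                "_id": len(self._jobs),  # list index (assumption); MongoDB would use an ObjectId
                "dataset": dataset,      # e.g. 'mnist' or 'cifar10'
                "network": network,      # the network definition
                "processing": False,
                "processed": False,
                "metrics": {},
            }
            self._jobs.append(job)
            return job["_id"]

    def find_job(self):
        """Claim the first unclaimed job, or return None if there is none.

        Checking and setting the flag under one lock means two workers can
        never claim the same job. With pymongo, a single
        find_one_and_update({"processing": False, "processed": False},
        {"$set": {"processing": True}}) call gives the same guarantee.
        """
        with self._lock:
            for job in self._jobs:
                if not job["processing"] and not job["processed"]:
                    job["processing"] = True
                    return dict(job)  # copy so callers can't mutate the store
        return None

    def score_job(self, job_id, metrics):
        """Mark a job as processed and attach its metrics."""
        with self._lock:
            job = self._jobs[job_id]
            job["processed"] = True
            job["metrics"] = dict(metrics)
```

With a real MongoDB collection, doing the find and the flag update as two separate calls would leave a window where two workers read the same job before either marks it; a single atomic find-and-modify closes that window.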

The DB needs to be reachable by all the workers if you're running this truly distributed. In its current state, the code only accepts a host argument, since I run this on AWS instances that share a security group. If you want to connect to a remote DB that requires authentication, you'll need to implement username/password (etc.) arguments as well. (Pull requests welcome!)
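One way to add credentials is to build a MongoDB connection string and hand it to `pymongo.MongoClient`. The helper below is hypothetical (the original code only takes a host), and the defaults of port 27017 and `authSource=admin` are assumptions you may need to change. Usernames and passwords must be percent-escaped inside the URI, which `urllib.parse.quote_plus` handles.

```python
from urllib.parse import quote_plus


def build_mongo_uri(host, port=27017, username=None, password=None,
                    auth_source="admin"):
    """Build a MongoDB connection string, optionally with credentials.

    Hypothetical helper: the post's code only accepts a host argument.
    """
    if username is None:
        # No auth: same behavior as connecting by host alone.
        return "mongodb://%s:%d/" % (host, port)
    # Escape reserved characters such as '@' or ':' in the credentials.
    creds = "%s:%s@" % (quote_plus(username), quote_plus(password or ""))
    return "mongodb://%s%s:%d/?authSource=%s" % (
        creds, host, port, quote_plus(auth_source))
```

A worker could then connect with something like `MongoClient(build_mongo_uri("10.0.0.5", username="tuner", password="s3cret"))`, where the host and credentials are placeholders.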

The Workers

As in life, the workers in this setup do all the hard work. You can run as many workers as you'd like, though you'll generally want to limit it to one worker per instance or GPU. When I use this to train my own models on my data, I use a cluster of eight AWS P2 instances, with one of the instances also running the controller and database.
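If a single machine has several GPUs, one common way to get "one worker per GPU" is to start one worker process per GPU and restrict each to a single device with the `CUDA_VISIBLE_DEVICES` environment variable. This is a sketch of that approach, not something the post describes; `worker_envs` and the `worker.py` script in the usage note are hypothetical names.

```python
import os


def worker_envs(num_gpus, base_env=None):
    """Build one environment dict per GPU for launching worker processes.

    Setting CUDA_VISIBLE_DEVICES to a single index means the CUDA runtime
    in that process (and so the deep learning backend) sees only that GPU.
    """
    base = dict(os.environ if base_env is None else base_env)
    envs = []
    for gpu in range(num_gpus):
        env = dict(base)                       # copy so workers don't share state
        env["CUDA_VISIBLE_DEVICES"] = str(gpu)
        envs.append(env)
    return envs
```

For example, `for env in worker_envs(8): subprocess.Popen(["python", "worker.py"], env=env)` would start eight workers, each pinned to its own GPU.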

The workflow is as follows:

  1. Check the jobs collection to see if there’s anything to do
  2. Once we have a job, get the dataset, compile the model, fit and score
  3. Update the job in the collection with the new score
  4. Repeat from step 1
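The four steps above can be sketched as a polling loop. This is a minimal illustration, not the code from the repo: `find_job`, `train_and_score` and `score_job` are passed in as hypothetical hooks (the second would load the dataset, compile the Keras model, fit and evaluate it), and the 30-second poll interval is an arbitrary assumption. `max_jobs` and the injectable `sleep` exist only to make the loop easy to test.

```python
import time


def run_worker(find_job, train_and_score, score_job,
               poll_seconds=30, max_jobs=None, sleep=time.sleep):
    """Poll for jobs, train and score each one, and write results back.

    find_job() -> job dict or None; train_and_score(job) -> metrics dict;
    score_job(job_id, metrics) stores the result. All three are
    hypothetical hooks standing in for the project's worker code.
    """
    done = 0
    while max_jobs is None or done < max_jobs:
        job = find_job()                  # 1. anything to do?
        if job is None:
            sleep(poll_seconds)           # nothing yet; wait, then check again
            continue
        metrics = train_and_score(job)    # 2. get data, compile, fit, score
        score_job(job["_id"], metrics)    # 3. write the score back
        done += 1                         # 4. repeat from step 1
    return done
```

Sleeping when the queue is empty keeps idle workers from hammering the database with queries.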

Super simple, right? 👍

Here's the code.
