Dask and TensorFlow


In [2]:
%matplotlib inline

In [1]:
import dask.array as da
from dask import delayed
from dask_tensorflow import start_tensorflow
from distributed import Client, progress
import dask.dataframe as dd
import matplotlib.pyplot as plt

import dask.array as da
from dask import delayed

In [3]:
# Create the distributed client that every later cell uses to persist data
# and submit tasks. No scheduler address is given here.
client = Client()

In [4]:
def get_mnist():
    """Read MNIST from ``/tmp/mnist-data`` and return its training split.

    TensorFlow's MNIST helper is imported inside the function body, so the
    import happens wherever the function is executed.

    Returns
    -------
    tuple
        ``(images, labels)`` of the training split, labels one-hot encoded.
    """
    from tensorflow.examples.tutorials.mnist import input_data

    train_split = input_data.read_data_sets('/tmp/mnist-data', one_hot=True).train
    return train_split.images, train_split.labels

# 20 lazy calls to get_mnist: the same dataset repeated to make a larger one.
datasets = [delayed(get_mnist)() for _ in range(20)]

# Wrap each delayed (images, labels) pair as dask arrays of known shape and
# dtype, then stack the 20 copies along the sample axis.
images = da.concatenate(
    [da.from_delayed(d[0], shape=(55000, 784), dtype='float32')
     for d in datasets],
    axis=0)
labels = da.concatenate(
    [da.from_delayed(d[1], shape=(55000, 10), dtype='float32')
     for d in datasets],
    axis=0)

images


Out[4]:
dask.array<concatenate, shape=(1100000, 784), dtype=float32, chunksize=(55000, 784)>

In [5]:
# Hand both arrays to the cluster via client.persist and rebind the names to
# the returned collections, which the following cells operate on.
images, labels = client.persist([images, labels])

In [7]:
# Bring the row at index 1 back to the notebook and view its 784 values as a
# 28x28 grayscale image.
im = images[1].compute().reshape((28, 28))
plt.imshow(im, cmap='gray')


Out[7]:
<matplotlib.image.AxesImage at 0x117ce19e8>

In [8]:
# Average over the sample axis (axis 0) on the cluster, then view the
# resulting 784 per-pixel means as a 28x28 grayscale image.
im = images.mean(axis=0).compute().reshape((28, 28))
plt.imshow(im, cmap='gray')


Out[8]:
<matplotlib.image.AxesImage at 0x117d2d668>

In [9]:
# Cut both arrays into blocks of 10000 rows and turn every block into a
# delayed object. The intermediates get their own names so that `images` and
# `labels` remain dask arrays; rebinding them to lists made this cell fail
# when it was run a second time.
image_blocks = images.rechunk((10000, 784)).to_delayed().flatten().tolist()
label_blocks = labels.rechunk((10000, 10)).to_delayed().flatten().tolist()

# Pair each image block with the label block at the same position and start
# computing the pairs on the cluster; `batches` ends up as the list returned
# by client.compute.
batches = client.compute(
    [delayed([im, la]) for im, la in zip(image_blocks, label_blocks)])

In [11]:
# Start TensorFlow on the Dask cluster with 1 "ps", 4 "worker" and 1 "scorer"
# task. `tf_spec` is later passed as the TensorFlow `cluster=` argument and
# `dask_spec` is indexed by job name ('ps', 'worker', 'scorer') to obtain the
# Dask workers on which to pin tasks.
from dask_tensorflow import start_tensorflow
tf_spec, dask_spec = start_tensorflow(client, ps=1, worker=4, scorer=1)

In [13]:
import math
import tempfile
import time
from queue import Empty

# Settings read by model() when it builds the graph.
IMAGE_PIXELS = 28        # image side length; inputs have IMAGE_PIXELS**2 features
hidden_units = 100       # width of the single hidden layer
learning_rate = 0.01     # passed to the Adam optimizer
sync_replicas = False    # True wraps the optimizer in SyncReplicasOptimizer
# Number of entries under the 'worker' job returned by start_tensorflow.
replicas_to_aggregate = len(dask_spec['worker'])

In [14]:
def model(server):
    """Build the MNIST network for one TensorFlow task and open a session.

    The graph is one hidden ReLU layer of ``hidden_units`` units followed by
    a 10-way softmax, trained with Adam on a summed cross-entropy loss.
    Devices are assigned through ``tf.train.replica_device_setter`` using the
    ``tf_spec`` cluster, with ``/job:ps/cpu:0`` as the parameter-server device.

    Reads the notebook-level names ``IMAGE_PIXELS``, ``hidden_units``,
    ``learning_rate``, ``sync_replicas``, ``replicas_to_aggregate``,
    ``tf_spec`` and (only when ``sync_replicas`` is true) ``dask_spec``.

    Parameters
    ----------
    server
        TensorFlow server of the calling task. ``server.server_def.job_name``
        and ``server.server_def.task_index`` select the device, and
        ``server.target`` is used as the session target.

    Returns
    -------
    tuple
        ``(sess, x, y_, train_step, global_step, cross_entropy)`` where ``x``
        and ``y_`` are the input and label placeholders.
    """
    # The notebook never binds `tf` at module level; import it here so the
    # function is self-contained wherever it is executed.
    import tensorflow as tf

    task_index = server.server_def.task_index
    worker_device = "/job:%s/task:%d" % (server.server_def.job_name,
                                         task_index)
    is_chief = task_index == 0

    with tf.device(tf.train.replica_device_setter(
            worker_device=worker_device,
            ps_device="/job:ps/cpu:0",
            cluster=tf_spec)):

        global_step = tf.Variable(0, name="global_step", trainable=False)

        # Variables of the hidden layer
        hid_w = tf.Variable(
            tf.truncated_normal(
                [IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
                stddev=1.0 / IMAGE_PIXELS),
            name="hid_w")
        hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")

        # Variables of the softmax layer
        sm_w = tf.Variable(
            tf.truncated_normal(
                [hidden_units, 10],
                stddev=1.0 / math.sqrt(hidden_units)),
            name="sm_w")
        sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

        # Ops: located on the worker specified with task_index
        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
        y_ = tf.placeholder(tf.float32, [None, 10])

        hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
        hid = tf.nn.relu(hid_lin)

        y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
        # Probabilities are clipped away from zero before taking the log.
        cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

        opt = tf.train.AdamOptimizer(learning_rate)

        if sync_replicas:
            # Use local names: assigning to `replicas_to_aggregate` here would
            # make it a function local and the read below would raise
            # UnboundLocalError. `num_workers` was previously undefined in
            # this scope.
            num_workers = len(dask_spec['worker'])
            if replicas_to_aggregate is None:
                n_aggregate = num_workers
            else:
                n_aggregate = replicas_to_aggregate

            opt = tf.train.SyncReplicasOptimizer(
                opt,
                replicas_to_aggregate=n_aggregate,
                total_num_replicas=num_workers,
                name="mnist_sync_replicas")

        train_step = opt.minimize(cross_entropy, global_step=global_step)

        if sync_replicas:
            local_init_op = opt.local_step_init_op
            if is_chief:
                local_init_op = opt.chief_init_op

            ready_for_local_init_op = opt.ready_for_local_init_op

            # Initial token and chief queue runners required by the sync_replicas mode
            chief_queue_runner = opt.get_chief_queue_runner()
            sync_init_op = opt.get_init_tokens_op()

        init_op = tf.global_variables_initializer()
        train_dir = tempfile.mkdtemp()

        if sync_replicas:
            sv = tf.train.Supervisor(
                is_chief=is_chief,
                logdir=train_dir,
                init_op=init_op,
                local_init_op=local_init_op,
                ready_for_local_init_op=ready_for_local_init_op,
                recovery_wait_secs=1,
                global_step=global_step)
        else:
            sv = tf.train.Supervisor(
                is_chief=is_chief,
                logdir=train_dir,
                init_op=init_op,
                recovery_wait_secs=1,
                global_step=global_step)

        # NOTE(review): the second filter hard-codes job "worker" even when
        # this server belongs to another job (scoring_task also calls
        # model()) -- confirm this is intended.
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False,
            device_filters=["/job:ps", "/job:worker/task:%d" % task_index])

        # The chief worker (task_index==0) session will prepare the session,
        # while the remaining workers will wait for the preparation to complete.
        if is_chief:
            print("Worker %d: Initializing session..." % task_index)
        else:
            print("Worker %d: Waiting for session to be initialized..." %
                  task_index)

        sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)

        if sync_replicas and is_chief:
            # Chief worker will start the chief queue runner and call the init op.
            sess.run(sync_init_op)
            sv.start_queue_runners(sess, [chief_queue_runner])

        return sess, x, y_, train_step, global_step, cross_entropy

In [17]:
def ps_task():
    """Block on this worker's TensorFlow server until it stops.

    Opens a worker-side client, looks up ``tensorflow_server`` on the worker
    object and calls ``join()`` on it.
    """
    # `local_client` is never imported at notebook level; pull it from the
    # same module the notebook already uses for `get_worker`.
    from distributed.worker_client import local_client

    with local_client() as c:
        c.worker.tensorflow_server.join()

In [18]:
def scoring_task():
    """Repeatedly score the model on MNIST validation data.

    Builds the model on this worker's TensorFlow server, loads the MNIST
    validation split from ``/tmp/mnist-data`` and then loops forever: each
    iteration evaluates the cross-entropy on that split, appends it as a
    float to the ``'scores'`` channel (created with ``maxlen=10``) and sleeps
    for one second. The function never returns on its own.
    """
    # `local_client` is never imported at notebook level; pull it from the
    # same module the notebook already uses for `get_worker`.
    from distributed.worker_client import local_client

    with local_client() as c:
        # Scores Channel
        scores = c.channel('scores', maxlen=10)

        # Make Model. Keep the placeholders: they are the feed_dict keys
        # below (they used to be discarded, leaving x / y_ undefined).
        server = c.worker.tensorflow_server
        sess, x, y_, _, _, cross_entropy = model(server)

        # Testing Data
        from tensorflow.examples.tutorials.mnist import input_data
        mnist = input_data.read_data_sets('/tmp/mnist-data', one_hot=True)
        test_data = {x: mnist.validation.images,
                     y_: mnist.validation.labels}

        # Main Loop
        while True:
            score = sess.run(cross_entropy, feed_dict=test_data)
            scores.append(float(score))

            time.sleep(1)

In [19]:
def worker_task():
    """Train the model on batches taken from this worker's queue.

    Builds the model on this worker's TensorFlow server, then keeps pulling
    ``(images, labels)`` batches from ``worker.tensorflow_queue`` and running
    one training step per batch. The loop continues while the ``'scores'``
    channel is empty or its latest value is above 1000; an empty queue is
    polled again after a 0.5 second timeout.
    """
    # `local_client` is never imported at notebook level; pull it from the
    # same module the notebook already uses for `get_worker`.
    from distributed.worker_client import local_client

    with local_client() as c:
        scores = c.channel('scores')

        server = c.worker.tensorflow_server
        queue = c.worker.tensorflow_queue

        # Make model
        sess, x, y_, train_step, global_step, _ = model(server)

        # Main loop
        while not scores or scores.data[-1] > 1000:
            try:
                batch = queue.get(timeout=0.5)
            except Empty:
                # Nothing queued yet: go round again and re-check the score.
                continue

            train_data = {x: batch[0],
                          y_: batch[1]}

            sess.run([train_step, global_step], feed_dict=train_data)

In [21]:
# Pin one long-running task to each Dask worker that hosts a TensorFlow
# server. The same function is submitted repeatedly with no arguments, so
# every submission is marked pure=False to get its own task; the ps
# submissions previously omitted it and would have shared one key if more
# than one parameter server were requested.
ps_tasks = [client.submit(ps_task, workers=worker, pure=False)
            for worker in dask_spec['ps']]

worker_tasks = [client.submit(worker_task, workers=addr, pure=False)
                for addr in dask_spec['worker']]

scorer_task = client.submit(scoring_task, workers=dask_spec['scorer'][0],
                            pure=False)

In [23]:
from distributed.worker_client import get_worker

def transfer_dask_to_tensorflow(batch):
    """Put ``batch`` on the ``tensorflow_queue`` of the worker running this task.

    Parameters
    ----------
    batch
        Object to enqueue; it is passed to ``put`` unchanged.
    """
    get_worker().tensorflow_queue.put(batch)

# Call transfer_dask_to_tensorflow once per batch, restricted to the Dask
# workers listed under the 'worker' job, so each batch lands on a queue that
# a worker_task reads from. The calls are submitted with pure=False because
# they are run for their side effect.
dump = client.map(transfer_dask_to_tensorflow, batches,
                  workers=dask_spec['worker'], pure=False)