In [2]:
# Show matplotlib figures directly beneath the cell that draws them.
%matplotlib inline
In [1]:
import dask.array as da
import dask.dataframe as dd
import matplotlib.pyplot as plt
import tensorflow as tf
from dask import delayed
from dask_tensorflow import start_tensorflow
from distributed import Client, progress
from distributed.worker_client import local_client
In [3]:
# Create the Dask client used by every later cell. No scheduler address is
# passed, so this relies on distributed's default behaviour for Client().
client = Client()
In [4]:
def get_mnist(data_dir='/tmp/mnist-data'):
    """Load the MNIST training split with one-hot encoded labels.

    The TensorFlow import is done inside the function so it is resolved in
    whichever process actually executes the call (the notebook wraps this
    function in ``delayed``).

    Parameters
    ----------
    data_dir : str, optional
        Directory handed to ``input_data.read_data_sets``. Defaults to
        ``'/tmp/mnist-data'``, the location the notebook has always used.

    Returns
    -------
    tuple
        ``(mnist.train.images, mnist.train.labels)``.
    """
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets(data_dir, one_hot=True)
    return mnist.train.images, mnist.train.labels
# 20 versions of same dataset: each entry is a lazy call to get_mnist().
datasets = [delayed(get_mnist)() for _ in range(20)]

# Wrap element 0 (images) and element 1 (labels) of every lazy result as a
# dask array with the declared shape/dtype, then stack the copies row-wise.
images = da.concatenate(
    [da.from_delayed(ds[0], shape=(55000, 784), dtype='float32')
     for ds in datasets],
    axis=0)
labels = da.concatenate(
    [da.from_delayed(ds[1], shape=(55000, 10), dtype='float32')
     for ds in datasets],
    axis=0)
images
Out[4]:
In [5]:
# Hand both lazy arrays to the cluster via client.persist and rebind the
# names to what it returns, so later cells work from the persisted versions.
images, labels = client.persist([images, labels])
In [7]:
# Pull row 1 of the stacked image array to the notebook and view its 784
# values as a 28x28 grid so it can be drawn as a grayscale picture.
im = images[1].compute().reshape((28, 28))
plt.imshow(im, cmap='gray')
Out[7]:
In [8]:
# Average over axis 0 (all rows) on the cluster, then view the resulting
# 784-value vector as a 28x28 grid and draw it as a grayscale picture.
im = images.mean(axis=0).compute().reshape((28, 28))
plt.imshow(im, cmap='gray')
Out[8]:
In [9]:
# Re-split both arrays into aligned 10000-row chunks and flatten each grid
# of chunks into a plain list of Delayed objects (one per chunk).
# NOTE: this rebinds `images` / `labels` from dask arrays to lists.
images = images.rechunk((10000, 784)).to_delayed().flatten().tolist()
labels = labels.rechunk((10000, 10)).to_delayed().flatten().tolist()

# Pair image and label chunks positionally into [images, labels] batches and
# submit them all to the cluster in one call.
batches = client.compute(
    [delayed([image_chunk, label_chunk])
     for image_chunk, label_chunk in zip(images, labels)])
In [11]:
from dask_tensorflow import start_tensorflow
# Start TensorFlow servers through dask_tensorflow, requesting 1 "ps",
# 4 "worker" and 1 "scorer" task. Of the two returned specs, `tf_spec` is
# later passed as the `cluster` of replica_device_setter, and `dask_spec` is
# indexed by job name ('ps' / 'worker' / 'scorer') to choose which Dask
# workers the tasks are submitted to.
tf_spec, dask_spec = start_tensorflow(client, ps=1, worker=4, scorer=1)
In [13]:
import math
import tempfile
import time
from queue import Empty
# Side length of one image; model() uses IMAGE_PIXELS * IMAGE_PIXELS as the
# width of the flattened input.
IMAGE_PIXELS = 28
# Number of units in the hidden layer built by model().
hidden_units = 100
# Learning rate passed to tf.train.AdamOptimizer in model().
learning_rate = 0.01
# When True, model() wraps the optimizer in tf.train.SyncReplicasOptimizer;
# when False the plain Adam optimizer is used.
sync_replicas = False
# Replica count for SyncReplicasOptimizer (only read when sync_replicas is
# True); set to the number of entries under dask_spec['worker'].
replicas_to_aggregate = len(dask_spec['worker'])
In [14]:
def model(server):
    """Build the one-hidden-layer MNIST graph for ``server`` and open a session.

    Reads the notebook-level settings ``tf_spec``, ``IMAGE_PIXELS``,
    ``hidden_units``, ``learning_rate``, ``sync_replicas`` and, in
    sync-replica mode only, ``replicas_to_aggregate`` and ``dask_spec``.

    Parameters
    ----------
    server
        TensorFlow server object; its ``server_def.job_name`` and
        ``server_def.task_index`` select the device, and ``server.target`` is
        the session target. Task index 0 is treated as the chief.

    Returns
    -------
    tuple
        ``(sess, x, y_, train_step, global_step, cross_entropy)`` where ``x``
        and ``y_`` are the input/label placeholders to use as ``feed_dict``
        keys.

    Notes
    -----
    NOTE(review): the notebook export lost its indentation; the extent of the
    ``with tf.device(...)`` block below is a reconstruction (everything,
    including session creation, is kept inside it) -- confirm against the
    original notebook.
    """
    task_index = server.server_def.task_index
    worker_device = "/job:%s/task:%d" % (server.server_def.job_name,
                                         task_index)
    is_chief = task_index == 0

    with tf.device(tf.train.replica_device_setter(
            worker_device=worker_device,
            ps_device="/job:ps/cpu:0",
            cluster=tf_spec)):
        global_step = tf.Variable(0, name="global_step", trainable=False)

        # Variables of the hidden layer
        hid_w = tf.Variable(
            tf.truncated_normal(
                [IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
                stddev=1.0 / IMAGE_PIXELS),
            name="hid_w")
        hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")

        # Variables of the softmax layer
        sm_w = tf.Variable(
            tf.truncated_normal(
                [hidden_units, 10],
                stddev=1.0 / math.sqrt(hidden_units)),
            name="sm_w")
        sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

        # Ops: located on the worker specified with task_index
        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
        y_ = tf.placeholder(tf.float32, [None, 10])

        hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
        hid = tf.nn.relu(hid_lin)

        y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
        # Clip before the log so a zero probability cannot produce -inf.
        cross_entropy = -tf.reduce_sum(
            y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

        opt = tf.train.AdamOptimizer(learning_rate)

        if sync_replicas:
            # Use fresh local names here. Assigning to `replicas_to_aggregate`
            # inside this function would turn it into a local variable and
            # make the `is None` check raise UnboundLocalError; `num_workers`
            # is derived from dask_spec because no such global exists.
            num_workers = len(dask_spec['worker'])
            if replicas_to_aggregate is None:
                num_to_aggregate = num_workers
            else:
                num_to_aggregate = replicas_to_aggregate

            opt = tf.train.SyncReplicasOptimizer(
                opt,
                replicas_to_aggregate=num_to_aggregate,
                total_num_replicas=num_workers,
                name="mnist_sync_replicas")

        train_step = opt.minimize(cross_entropy, global_step=global_step)

        if sync_replicas:
            local_init_op = opt.local_step_init_op
            if is_chief:
                local_init_op = opt.chief_init_op

            ready_for_local_init_op = opt.ready_for_local_init_op

            # Initial token and chief queue runners required by the
            # sync_replicas mode
            chief_queue_runner = opt.get_chief_queue_runner()
            sync_init_op = opt.get_init_tokens_op()

        init_op = tf.global_variables_initializer()
        train_dir = tempfile.mkdtemp()

        if sync_replicas:
            sv = tf.train.Supervisor(
                is_chief=is_chief,
                logdir=train_dir,
                init_op=init_op,
                local_init_op=local_init_op,
                ready_for_local_init_op=ready_for_local_init_op,
                recovery_wait_secs=1,
                global_step=global_step)
        else:
            sv = tf.train.Supervisor(
                is_chief=is_chief,
                logdir=train_dir,
                init_op=init_op,
                recovery_wait_secs=1,
                global_step=global_step)

        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False,
            device_filters=["/job:ps", "/job:worker/task:%d" % task_index])

        # The chief worker (task_index==0) session will prepare the session,
        # while the remaining workers will wait for the preparation to
        # complete.
        if is_chief:
            print("Worker %d: Initializing session..." % task_index)
        else:
            print("Worker %d: Waiting for session to be initialized..." %
                  task_index)

        sess = sv.prepare_or_wait_for_session(server.target,
                                              config=sess_config)

        if sync_replicas and is_chief:
            # Chief worker will start the chief queue runner and call the
            # init op.
            sess.run(sync_init_op)
            sv.start_queue_runners(sess, [chief_queue_runner])

        return sess, x, y_, train_step, global_step, cross_entropy
In [17]:
def ps_task():
    """Parameter-server role: block on this worker's TensorFlow server.

    Opens a worker-side client with ``local_client()`` and calls ``join()``
    on the ``tensorflow_server`` attached to that worker.
    """
    with local_client() as worker_side_client:
        tf_server = worker_side_client.worker.tensorflow_server
        tf_server.join()
In [18]:
def scoring_task():
    """Scorer role: repeatedly evaluate the model on MNIST validation data.

    Builds the model on this worker's TensorFlow server, then loops forever:
    evaluates ``cross_entropy`` on the validation split, appends the value
    (as a float) to the ``'scores'`` channel, and sleeps for one second.
    The function never returns on its own.
    """
    with local_client() as c:
        # Scores Channel
        scores = c.channel('scores', maxlen=10)

        # Make Model
        server = c.worker.tensorflow_server
        # Keep the two placeholders: they are the feed_dict keys below.
        # Discarding them left `x` and `y_` undefined in this function.
        sess, x, y_, _, _, cross_entropy = model(server)

        # Testing Data
        from tensorflow.examples.tutorials.mnist import input_data
        mnist = input_data.read_data_sets('/tmp/mnist-data', one_hot=True)
        test_data = {x: mnist.validation.images,
                     y_: mnist.validation.labels}

        # Main Loop
        while True:
            score = sess.run(cross_entropy, feed_dict=test_data)
            scores.append(float(score))
            time.sleep(1)
In [19]:
def worker_task():
    """Worker role: train on batches taken from this worker's queue.

    Builds the model on this worker's TensorFlow server, then keeps pulling
    ``[images, labels]`` batches from ``tensorflow_queue`` and running one
    training step per batch. The loop continues while the ``'scores'``
    channel is empty or its most recent value is above 1000.
    """
    with local_client() as c:
        scores = c.channel('scores')
        queue = c.worker.tensorflow_queue

        # Make model
        sess, x, y_, train_step, global_step, _ = model(
            c.worker.tensorflow_server)

        # Main loop
        while not scores or scores.data[-1] > 1000:
            try:
                batch = queue.get(timeout=0.5)
            except Empty:
                # Nothing arrived within the timeout; loop round so the stop
                # condition is re-checked instead of blocking indefinitely.
                continue

            train_data = {x: batch[0],
                          y_: batch[1]}

            sess.run([train_step, global_step], feed_dict=train_data)
In [21]:
# Launch one long-running task per TensorFlow role, each pinned to the Dask
# worker address listed for it in dask_spec.
# pure=False is passed on every submission. These functions take no
# arguments, so with the default pure=True identical submissions share one
# key and are collapsed into a single task, and re-running this cell would
# hand back the previous futures instead of starting new tasks.
ps_tasks = [client.submit(ps_task, workers=addr, pure=False)
            for addr in dask_spec['ps']]
worker_tasks = [client.submit(worker_task, workers=addr, pure=False)
                for addr in dask_spec['worker']]
scorer_task = client.submit(scoring_task, workers=dask_spec['scorer'][0],
                            pure=False)
In [23]:
from distributed.worker_client import get_worker
def transfer_dask_to_tensorflow(batch):
    """Put ``batch`` on the ``tensorflow_queue`` of the worker running this call.

    The worker is looked up with ``get_worker()``; nothing is returned.
    """
    get_worker().tensorflow_queue.put(batch)
# Run the transfer function once per batch, restricted to the Dask workers
# listed under dask_spec['worker'], so each batch is put on one of those
# workers' TensorFlow queues. pure=False is passed because the function is
# called for its side effect (the queue put), not its return value.
dump = client.map(transfer_dask_to_tensorflow, batches,
workers=dask_spec['worker'], pure=False)