In [ ]:
# Each snippet below describes the same cluster -- one job named "local"
# made of two tasks -- and starts the server for a single task index.
# Presumably each snippet is meant to run in its own process; confirm
# against the surrounding tutorial text.

# In task 0:
cluster = tf.train.ClusterSpec({
    "local": [
        "localhost:2222",
        "localhost:2223",
    ],
})
server = tf.train.Server(
    cluster,
    job_name="local",
    task_index=0,
)

# In task 1:
cluster = tf.train.ClusterSpec({
    "local": [
        "localhost:2222",
        "localhost:2223",
    ],
})
server = tf.train.Server(
    cluster,
    job_name="local",
    task_index=1,
)
# Illustrative pseudo-code: every `...` is a placeholder to be filled in.

# First layer's parameters are pinned to task 0 of the "ps" job.
with tf.device("/job:ps/task:0"):
    weights_1 = tf.Variable(...)
    biases_1 = tf.Variable(...)

# Second layer's parameters are pinned to task 1 of the "ps" job.
with tf.device("/job:ps/task:1"):
    weights_2 = tf.Variable(...)
    biases_2 = tf.Variable(...)

# The computation itself is pinned to task 7 of the "worker" job.
with tf.device("/job:worker/task:7"):
    input, labels = ...
    layer_1 = tf.nn.relu(
        tf.matmul(input, weights_1) + biases_1
    )
    # NOTE(review): a ReLU is applied to the value named `logits` as well --
    # confirm that is intended before reusing this snippet.
    logits = tf.nn.relu(
        tf.matmul(layer_1, weights_2) + biases_2
    )
    # ...
    train_op = ...

# Connect to the given gRPC target and run the training op 10000 times.
with tf.Session("grpc://worker7.example.com:2222") as session:
    for _step in range(10000):
        session.run(train_op)
def device_and_target():
    """Start this task's server and return its device setter and target.

    Reads ``FLAGS.ps_hosts`` and ``FLAGS.worker_hosts`` (comma-separated
    host lists) to describe the cluster, then starts a ``tf.train.Server``
    for the task identified by ``FLAGS.job_name`` and ``FLAGS.task_index``.

    Returns:
        A 2-tuple ``(device, target)`` where ``device`` is the result of
        ``tf.train.replica_device_setter`` for this worker task and
        ``target`` is ``server.target``.
    """
    cluster_spec = tf.train.ClusterSpec({
        "ps": FLAGS.ps_hosts.split(","),
        "worker": FLAGS.worker_hosts.split(","),
    })
    server = tf.train.Server(
        cluster_spec,
        job_name=FLAGS.job_name,
        task_index=FLAGS.task_index)

    # Parameter-server tasks just serve. Per the tf.train.Server docs,
    # join() blocks until the server shuts down, so in practice only
    # worker tasks continue past this point.
    if FLAGS.job_name == "ps":
        server.join()

    worker_device = "/job:worker/task:{}".format(FLAGS.task_index)
    # The device setter will automatically place Variables ops on separate
    # parameter servers (ps). The non-Variable ops will be placed on the
    # workers.
    return (
        tf.train.replica_device_setter(
            worker_device=worker_device,
            cluster=cluster_spec),
        server.target,
    )
def main(unused_argv):
    """Build the training ops for this task and run them until told to stop.

    Args:
        unused_argv: Not read by this function.
    """
    placement, master = device_and_target()

    # Create the input pipeline and the model/loss/train ops under the
    # device returned for this task.
    with tf.device(placement):
        images, labels = inputs(FLAGS.batch_size)
        logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
        loss = mnist.loss(logits, labels)
        train_op = mnist.training(loss, FLAGS.learning_rate)

    # The task with index 0 is the one flagged as chief.
    session_kwargs = dict(
        master=master,
        is_chief=(FLAGS.task_index == 0),
        checkpoint_dir=FLAGS.train_dir,
    )
    with tf.train.MonitoredTrainingSession(**session_kwargs) as session:
        # Keep stepping until the monitored session reports it should stop.
        while not session.should_stop():
            session.run(train_op)
# Hand control to tf.app.run() only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    tf.app.run()