# In [ ]:
# In task 0:
# Describe a two-address cluster under the single job name "local" and start
# the server for task index 0.
# NOTE(review): the "In task N" labels suggest each snippet is meant to run in
# its own process -- confirm before executing both in one interpreter.
cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]})
server = tf.train.Server(cluster, job_name="local", task_index=0)


# In task 1:
# Same cluster description as for task 0; this one starts the server for task
# index 1. If both snippets run in the same module, `cluster` and `server` are
# rebound here.
cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]})
server = tf.train.Server(cluster, job_name="local", task_index=1)



# Skeleton of a graph whose variables are pinned to two "ps" tasks and whose
# compute ops are pinned to one "worker" task. The `...` placeholders must be
# replaced with real initial values / ops before this can run.
# NOTE(review): the device strings assume a cluster that defines "ps" and
# "worker" jobs (including a worker task 7); the "local"-only cluster spec
# earlier in this file does not -- confirm which cluster this targets.

# First layer's parameters are created under the ps task 0 device scope.
with tf.device("/job:ps/task:0"):
    weights_1 = tf.Variable(...)
    biases_1 = tf.Variable(...)

# Second layer's parameters are created under the ps task 1 device scope.
with tf.device("/job:ps/task:1"):
    weights_2 = tf.Variable(...)
    biases_2 = tf.Variable(...)

# Forward pass and training op are built under the worker task 7 device scope.
with tf.device("/job:worker/task:7"):
    # NOTE(review): `input` shadows the Python builtin of the same name.
    input, labels = ...
    layer_1 = tf.nn.relu(tf.matmul(input, weights_1) + biases_1)
    logits = tf.nn.relu(tf.matmul(layer_1, weights_2) + biases_2)
    # ...
    train_op = ...

# Open a session against the given grpc:// address and run `train_op` 10000
# times; the session is closed when the `with` block exits.
with tf.Session("grpc://worker7.example.com:2222") as sess:
    for _ in range(10000):
        sess.run(train_op)

    
    
    
    
    
    
    
    
    
    
    
    

def device_and_target():
    """Start this task's server and return its device setter and target.

    The cluster layout is read from ``FLAGS``: ``ps_hosts`` and
    ``worker_hosts`` are comma-separated address lists, and ``job_name`` /
    ``task_index`` identify the task this process plays in that cluster.

    Returns:
        A 2-tuple ``(device, target)`` where ``device`` is the result of
        ``tf.train.replica_device_setter`` for this worker task and
        ``target`` is ``server.target`` of the server started here.
    """
    cluster_spec = tf.train.ClusterSpec({
        "ps": FLAGS.ps_hosts.split(","),
        "worker": FLAGS.worker_hosts.split(","),
    })
    server = tf.train.Server(
        cluster_spec, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        # Parameter-server tasks only serve requests. Per the TensorFlow
        # docs join() blocks until the server shuts down, so in practice
        # only worker tasks are expected to get past this point.
        server.join()

    worker_device = "/job:worker/task:{}".format(FLAGS.task_index)
    # The device setter will automatically place Variables ops on separate
    # parameter servers (ps). The non-Variable ops will be placed on the
    # workers.
    return (
        tf.train.replica_device_setter(
            worker_device=worker_device,
            cluster=cluster_spec),
        server.target,
    )


def main(unused_argv):
    """Build the MNIST training graph and run it until the session stops.

    Args:
        unused_argv: Ignored; present only so the function matches the
            signature the application runner calls it with.
    """
    device, target = device_and_target()

    # Build the whole model under the device (function) chosen for this task.
    with tf.device(device):
        images, labels = inputs(FLAGS.batch_size)
        logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
        loss = mnist.loss(logits, labels)
        train_op = mnist.training(loss, FLAGS.learning_rate)

    # Task index 0 is treated as the chief; checkpoints go to FLAGS.train_dir.
    with tf.train.MonitoredTrainingSession(
            master=target,
            is_chief=(FLAGS.task_index == 0),
            checkpoint_dir=FLAGS.train_dir) as sess:
        while not sess.should_stop():
            sess.run(train_op)


# Script entry point: runs only when this file is executed directly, not when
# it is imported as a module.
if __name__ == "__main__":
    # NOTE(review): presumably tf.app.run() parses command-line flags and then
    # invokes this module's main() -- confirm against the TensorFlow docs.
    tf.app.run()