In [ ]:
import tensorflow as tf
# Cluster layout shared by every cell below: a single job named "local" with
# two tasks, listed here as localhost:2222 and localhost:2223.
cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]})
In [ ]:
# Server for task 0 of the "local" job declared in `cluster`.
server0 = tf.train.Server(cluster, job_name="local", task_index=0)
# Show the server object to confirm the cell ran.
print(server0)
In [ ]:
# Server for task 1 of the "local" job declared in `cluster`.
server1 = tf.train.Server(cluster, job_name="local", task_index=1)
# Show the server object to confirm the cell ran.
print(server1)
In [ ]:
import tensorflow as tf
# Exponent passed to matpow() by the benchmark cells.
n = 2
# Initial bindings only: each benchmark cell rebinds c1 and c2 to matpow()
# results before they are used.
c1 = tf.Variable([])
c2 = tf.Variable([])
def matpow(M, n):
    """Return ``M`` raised to the ``n``-th matrix power via ``tf.matmul``.

    The graph is built recursively as ``M @ (M @ (... @ M))`` using
    ``n - 1`` matmul ops.

    Args:
        M: tensor to multiply with itself (must be accepted by
            ``tf.matmul(M, M)``).
        n: Python integer exponent. Any value <= 1 returns ``M`` unchanged.

    Returns:
        ``M`` itself when ``n <= 1``, otherwise the tensor for ``M ** n``.
    """
    # The base case is n <= 1 (not n < 1): stopping one step later would add
    # an extra multiplication and yield M ** (n + 1).
    if n <= 1:
        return M
    return tf.matmul(M, matpow(M, n - 1))
In [ ]:
import datetime
# Pin one matrix-power chain to the CPU of each task in the "local" job.
with tf.device("/job:local/task:0/cpu:0"):
    A = tf.random_normal(shape=[1000, 1000])
    c1 = matpow(A, n)
with tf.device("/job:local/task:1/cpu:0"):
    B = tf.random_normal(shape=[1000, 1000])
    c2 = matpow(B, n)

# Connect to the server listening on port 2222 and evaluate the combined graph.
with tf.Session("grpc://127.0.0.1:2222") as sess:
    # Named `total` rather than `sum` so the builtin sum() is not shadowed
    # for the rest of the session.
    total = c1 + c2
    start_time = datetime.datetime.now()
    result = sess.run(total)
    # Stop the clock before printing so console output is not part of the
    # measured time.
    elapsed = datetime.datetime.now() - start_time
    print(result)
    print("Execution time: " + str(elapsed))
In [ ]:
# Pin the first chain to GPU 0 of task 0 and the second to the CPU of task 1.
# NOTE(review): presumably requires a GPU visible to task 0 - confirm on the
# target machine.
with tf.device("/job:local/task:0/gpu:0"):
    A = tf.random_normal(shape=[1000, 1000])
    c1 = matpow(A, n)
with tf.device("/job:local/task:1/cpu:0"):
    B = tf.random_normal(shape=[1000, 1000])
    c2 = matpow(B, n)

# Connect to the server listening on port 2222 and evaluate the combined graph.
with tf.Session("grpc://127.0.0.1:2222") as sess:
    # Named `total` rather than `sum` so the builtin sum() is not shadowed
    # for the rest of the session.
    total = c1 + c2
    start_time = datetime.datetime.now()
    result = sess.run(total)
    # Stop the clock before printing so console output is not part of the
    # measured time.
    elapsed = datetime.datetime.now() - start_time
    print(result)
    print("Execution time: " + str(elapsed))
In [ ]:
# Pin one matrix-power chain to GPU 0 of each task in the "local" job.
# NOTE(review): presumably requires a GPU visible to both tasks - confirm on
# the target machine.
with tf.device("/job:local/task:0/gpu:0"):
    A = tf.random_normal(shape=[1000, 1000])
    c1 = matpow(A, n)
with tf.device("/job:local/task:1/gpu:0"):
    B = tf.random_normal(shape=[1000, 1000])
    c2 = matpow(B, n)

# Connect to the server listening on port 2222 and evaluate the combined graph.
with tf.Session("grpc://127.0.0.1:2222") as sess:
    # Named `total` rather than `sum` so the builtin sum() is not shadowed
    # for the rest of the session.
    total = c1 + c2
    start_time = datetime.datetime.now()
    result = sess.run(total)
    # Stop the clock before printing so console output is not part of the
    # measured time.
    elapsed = datetime.datetime.now() - start_time
    print(result)
    print("Execution time: " + str(elapsed))
In [ ]:
# Place each chain through replica_device_setter, one per task of the cluster.
# The worker job must be one that `cluster` actually declares; it only has a
# "local" job, so "/job:worker/..." would not match any task.
# NOTE(review): ps_device is pointed at "/job:local" because, per the TF 1.x
# device setter, a cluster without the parameter-server job yields no device
# function at all (nothing gets pinned) - confirm against the installed
# TensorFlow version.
with tf.device(tf.train.replica_device_setter(
        worker_device="/job:local/task:0",
        ps_device="/job:local",
        cluster=cluster)):
    A = tf.random_normal(shape=[1000, 1000])
    c1 = matpow(A, n)
with tf.device(tf.train.replica_device_setter(
        worker_device="/job:local/task:1",
        ps_device="/job:local",
        cluster=cluster)):
    B = tf.random_normal(shape=[1000, 1000])
    c2 = matpow(B, n)

# Connect to the server listening on port 2222 and evaluate the combined graph.
with tf.Session("grpc://127.0.0.1:2222") as sess:
    # Named `total` rather than `sum` so the builtin sum() is not shadowed
    # for the rest of the session.
    total = c1 + c2
    start_time = datetime.datetime.now()
    result = sess.run(total)
    # Stop the clock before printing so console output is not part of the
    # measured time.
    elapsed = datetime.datetime.now() - start_time
    print(result)
    print("Multi node computation time: " + str(elapsed))