https://github.com/ray-project/tutorial/tree/master/examples/sharded_parameter_server.ipynb
GOAL: The goal of this exercise is to use actor handles to implement a sharded parameter server example for distributed asynchronous stochastic gradient descent.
Before doing this exercise, make sure you understand the concepts from the exercise on Actor Handles.
A parameter server is simply an object that stores the parameters (or "weights") of a machine learning model (this could be a neural network, a linear model, or something else). It exposes two methods: one for getting the parameters and one for updating the parameters.
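As a minimal sketch of that interface, here is a plain Python class (no Ray yet; the name SimpleParameterServer is just illustrative, and the actor version used in this notebook appears in the cells below):
import numpy as np

# A bare-bones parameter server: it stores a weight vector and exposes
# one method for reading the weights and one for applying an update.
class SimpleParameterServer(object):
    def __init__(self, dim):
        self.parameters = np.zeros(dim)

    def get_parameters(self):
        return self.parameters

    def update_parameters(self, update):
        # An "update" is simply added to the weights, e.g. a scaled gradient.
        self.parameters += update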
In a typical machine learning training application, worker processes will run in an infinite loop that does the following:
1. Get the latest parameters from the parameter server.
2. Compute an update to the parameters.
3. Send the update to the parameter server.
The workers can operate synchronously (that is, in lock step), in which case distributed training with multiple workers is algorithmically equivalent to serial training with a larger batch of data. Alternatively, workers can operate independently and apply their updates asynchronously. The main benefit of asynchronous training is that a single slow worker will not slow down the other workers. The benefit of synchronous training is that the algorithm behavior is more predictable and reproducible.
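For intuition, a synchronous step can be pictured as the driver gathering one update from every worker before applying anything. The sketch below is only conceptual: ps is a ParameterServer actor like the one defined later in this notebook, and compute_update is a hypothetical remote task standing in for each worker's gradient computation.
# Hedged sketch of synchronous (lockstep) training, driven from the driver.
num_workers = 4
num_steps = 10

@ray.remote
def compute_update(parameters):
    # Placeholder for a real gradient computation.
    return 1e-3 * parameters + np.ones(parameters.shape)

for step in range(num_steps):
    # Every worker starts from the same snapshot of the parameters ...
    parameters = ray.get(ps.get_parameters.remote())
    # ... computes its update from that snapshot ...
    updates = ray.get([compute_update.remote(parameters) for _ in range(num_workers)])
    # ... and the summed update is applied once, like one larger batch.
    ps.update_parameters.remote(sum(updates))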
In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ray
import time
In [3]:
from zoo.common.nncontext import init_spark_on_local, init_spark_on_yarn
import numpy as np
import os

hadoop_conf_dir = os.environ.get('HADOOP_CONF_DIR')
if hadoop_conf_dir:
    sc = init_spark_on_yarn(
        hadoop_conf=hadoop_conf_dir,
        conda_name=os.environ.get("ZOO_CONDA_NAME", "zoo"),  # The name of the created conda-env
        num_executor=2,
        executor_cores=4,
        executor_memory="2g",
        driver_memory="2g",
        driver_cores=1,
        extra_executor_memory_for_ray="3g")
else:
    sc = init_spark_on_local(cores=8, conf={"spark.driver.memory": "2g"})
In [4]:
# It may take a while to distribute the local environment (including Python and Java) to the cluster.
import ray
from zoo.ray import RayContext
ray_ctx = RayContext(sc=sc, object_store_memory="4g")
ray_ctx.init()
#ray.init(num_cpus=30, include_webui=False, ignore_reinit_error=True)
A simple parameter server can be implemented as a Python class in a few lines of code.
EXERCISE: Make the ParameterServer class an actor.
In [5]:
dim = 10

@ray.remote
class ParameterServer(object):
    def __init__(self, dim):
        self.parameters = np.zeros(dim)

    def get_parameters(self):
        return self.parameters

    def update_parameters(self, update):
        self.parameters += update

ps = ParameterServer.remote(dim)
A worker can be implemented as a simple Python function that repeatedly gets the latest parameters, computes an update to the parameters, and sends the update to the parameter server.
In [6]:
@ray.remote
def worker(ps, dim, num_iters):
    for _ in range(num_iters):
        # Get the latest parameters.
        parameters = ray.get(ps.get_parameters.remote())
        # Compute an update.
        update = 1e-3 * parameters + np.ones(dim)
        # Update the parameters.
        ps.update_parameters.remote(update)
        # Sleep a little to simulate a real workload.
        time.sleep(0.5)

# Test that worker is implemented correctly. You do not need to change this line.
ray.get(worker.remote(ps, dim, 1))
In [7]:
# Start two workers.
worker_results = [worker.remote(ps, dim, 100) for _ in range(2)]
As the worker tasks are executing, you can query the parameter server from the driver and see the parameters changing in the background.
In [10]:
print(ray.get(ps.get_parameters.remote()))
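Since the workers started in the previous cell are still running in the background, repeating this query shows the values growing; for example, a quick polling loop:
# Poll the parameter server a few times while the workers keep pushing updates.
for _ in range(5):
    print(ray.get(ps.get_parameters.remote()))
    time.sleep(1)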
As the number of workers increases, the volume of updates being sent to the parameter server will increase. At some point, the network bandwidth into the parameter server machine or the computation done by the parameter server may become a bottleneck.
Suppose you have $N$ workers and $1$ parameter server, and suppose each of these is an actor that lives on its own machine. Furthermore, suppose the model size is $M$ bytes. Then sending all of the parameters from the workers to the parameter server will mean that $N * M$ bytes in total are sent to the parameter server. If $N = 100$ and $M = 10^8$, then the parameter server must receive ten gigabytes, which, assuming a network bandwidth of 10 gigabits per second, would take 8 seconds. This would be prohibitive.
On the other hand, if the parameters are sharded (that is, split) across $K$ parameter servers, $K$ is 100, and each parameter server lives on a separate machine, then each parameter server needs to receive only 100 megabytes, which can be done in 80 milliseconds. This is much better.
EXERCISE: The code below defines a parameter server shard class. Modify this class to make ParameterServerShard an actor. We will need to revisit this code soon and increase num_shards.
In [11]:
@ray.remote
class ParameterServerShard(object):
    def __init__(self, sharded_dim):
        self.parameters = np.zeros(sharded_dim)

    def get_parameters(self):
        return self.parameters

    def update_parameters(self, update):
        self.parameters += update

total_dim = (10 ** 8) // 8  # This works out to 100MB (we have 12.5 million
                            # float64 values, which are each 8 bytes).

num_shards = 2  # The number of parameter server shards.

assert total_dim % num_shards == 0, ('In this exercise, the number of shards must '
                                     'perfectly divide the total dimension.')

# Start some parameter servers.
ps_shards = [ParameterServerShard.remote(total_dim // num_shards) for _ in range(num_shards)]

assert hasattr(ParameterServerShard, 'remote'), ('You need to turn ParameterServerShard into an '
                                                 'actor (by using the ray.remote keyword).')
The code below implements a worker that does the following:
1. Gets the latest parameters from all of the parameter server shards.
2. Concatenates the parameter shards together to form the full parameter vector.
3. Computes an update to the parameters.
4. Partitions the update into one piece per parameter server shard.
5. Sends each piece to the corresponding parameter server shard.
In [14]:
@ray.remote
def worker_task(total_dim, num_iters, *ps_shards):
    # Note that ps_shards are passed in using Python's variable number
    # of arguments feature. We do this because currently actor handles
    # cannot be passed to tasks inside of lists or other objects.
    for _ in range(num_iters):
        # Get the current parameters from each parameter server.
        parameter_shards = [ray.get(ps.get_parameters.remote()) for ps in ps_shards]
        assert all([isinstance(shard, np.ndarray) for shard in parameter_shards]), (
            'The parameter shards must be numpy arrays. Did you forget to call ray.get?')
        # Concatenate them to form the full parameter vector.
        parameters = np.concatenate(parameter_shards)
        assert parameters.shape == (total_dim,)

        # Compute an update.
        update = np.ones(total_dim)
        # Shard the update.
        update_shards = np.split(update, len(ps_shards))
        # Apply the updates to the relevant parameter server shards.
        for ps, update_shard in zip(ps_shards, update_shards):
            ps.update_parameters.remote(update_shard)

# Test that worker_task is implemented correctly. You do not need to change this line.
ray.get(worker_task.remote(total_dim, 1, *ps_shards))
EXERCISE: Experiment by changing the number of parameter server shards, the number of workers, and the size of the data.
NOTE: Because these processes are all running on the same machine, network bandwidth will not be a limitation and sharding the parameter server will not help. To see the difference, you would need to run the application on multiple machines. There are still regimes where sharding a parameter server can help speed up computation on the same machine (by parallelizing the computation that the parameter server processes have to do). If you want to see this effect, you should implement a synchronous training application. In the asynchronous setting, the computation is staggered and so speeding up the parameter server usually does not matter.
In [16]:
num_workers = 4
# Start some workers. Try changing various quantities and see how the
# duration changes.
start = time.time()
ray.get([worker_task.remote(total_dim, 5, *ps_shards) for _ in range(num_workers)])
print('This took {} seconds.'.format(time.time() - start))
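One way to run the experiment suggested above is to sweep over a few worker counts and compare timings. Here is a minimal sketch that reuses the ps_shards and worker_task defined earlier; the particular worker counts are just an example:
# Sketch: time the same workload for a few different numbers of workers.
for n in [1, 2, 4, 8]:
    start = time.time()
    ray.get([worker_task.remote(total_dim, 5, *ps_shards) for _ in range(n)])
    print('{} workers took {:.3f} seconds.'.format(n, time.time() - start))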
In [ ]: