This notebook is adapted from:

https://github.com/ray-project/tutorial/tree/master/examples/sharded_parameter_server.ipynb

Sharded Parameter Servers

GOAL: The goal of this exercise is to use actor handles to implement a sharded parameter server example for distributed asynchronous stochastic gradient descent.

Before doing this exercise, make sure you understand the concepts from the exercise on Actor Handles.

Parameter Servers

A parameter server is simply an object that stores the parameters (or "weights") of a machine learning model (this could be a neural network, a linear model, or something else). It exposes two methods: one for getting the parameters and one for updating the parameters.
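For instance, a plain single-process version of this interface (before any Ray is involved) might look like the following sketch; the class name here is just for illustration.

import numpy as np

# A minimal, single-process sketch of the parameter server interface
# described above (not yet a Ray actor).
class SimpleParameterServer(object):
    def __init__(self, dim):
        # Store the model parameters as a flat NumPy array.
        self.parameters = np.zeros(dim)

    def get_parameters(self):
        # Return the current parameters.
        return self.parameters

    def update_parameters(self, update):
        # Apply an additive update (for example, a scaled gradient).
        self.parameters += update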

In a typical machine learning training application, worker processes will run in an infinite loop that does the following:

  1. Get the latest parameters from the parameter server.
  2. Compute an update to the parameters (using the current parameters and some data).
  3. Send the update to the parameter server.

The workers can operate synchronously (that is, in lock step), in which case distributed training with multiple workers is algorithmically equivalent to serial training with a larger batch of data. Alternatively, workers can operate independently and apply their updates asynchronously. The main benefit of asynchronous training is that a single slow worker will not slow down the other workers. The benefit of synchronous training is that the algorithm behavior is more predictable and reproducible.
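As a minimal, non-distributed sketch of the loop above (using the SimpleParameterServer sketch from earlier and a hypothetical compute_update helper standing in for the real gradient computation):

# A single asynchronous worker's loop, in plain Python.
def run_worker(ps, num_iters, compute_update):
    for _ in range(num_iters):
        # 1. Get the latest parameters from the parameter server.
        parameters = ps.get_parameters()
        # 2. Compute an update using the current parameters (and some data).
        update = compute_update(parameters)
        # 3. Send the update to the parameter server.
        ps.update_parameters(update)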


In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import ray
import time

Init SparkContext


In [3]:
from zoo.common.nncontext import init_spark_on_local, init_spark_on_yarn
import numpy as np
import os
hadoop_conf_dir = os.environ.get('HADOOP_CONF_DIR')

if hadoop_conf_dir:
    sc = init_spark_on_yarn(
        hadoop_conf=hadoop_conf_dir,
        conda_name=os.environ.get("ZOO_CONDA_NAME", "zoo"),  # The name of the created conda-env
        num_executor=2,
        executor_cores=4,
        executor_memory="2g",
        driver_memory="2g",
        driver_cores=1,
        extra_executor_memory_for_ray="3g")
else:
    sc = init_spark_on_local(cores=8, conf={"spark.driver.memory": "2g"})


Current pyspark location is : /root/anaconda2/envs/ray_train/lib/python3.6/site-packages/pyspark/__init__.py
Start to pack current python env
Collecting packages...
Packing environment at '/root/anaconda2/envs/ray_train' to '/tmp/tmp7qvxc3o2/python_env.tar.gz'
[########################################] | 100% Completed | 34.4s
Packing has been completed: /tmp/tmp7qvxc3o2/python_env.tar.gz
pyspark_submit_args is:  --master yarn --deploy-mode client --archives /tmp/tmp7qvxc3o2/python_env.tar.gz#python_env --num-executors 2  --executor-cores 4 --executor-memory 2g pyspark-shell 

In [4]:
# It may take a while to distribute the local environment, including Python and Java, to the cluster.
import ray
from zoo.ray import RayContext
ray_ctx = RayContext(sc=sc, object_store_memory="4g")
ray_ctx.init()
#ray.init(num_cpus=30, include_webui=False, ignore_reinit_error=True)


Start to launch the JVM guarding process
JVM guarding process has been successfully launched
Start to launch ray on cluster
Start to launch ray on local
Executing command: ray start --redis-address 172.16.0.158:34046 --redis-password  123456 --num-cpus 0 --object-store-memory 400000000
2019-07-18 07:09:19,971	WARNING worker.py:1341 -- WARNING: Not updating worker name since `setproctitle` is not installed. Install this with `pip install setproctitle` (or ray[debug]) to enable monitoring of worker processes.
2019-07-18 07:09:19,855	INFO services.py:409 -- Waiting for redis server at 172.16.0.158:34046 to respond...
2019-07-18 07:09:19,858	INFO scripts.py:363 -- Using IP address 172.16.0.102 for this node.
2019-07-18 07:09:19,861	INFO node.py:511 -- Process STDOUT and STDERR is being redirected to /tmp/ray/session_2019-07-18_15-09-10_137772_188428/logs.
2019-07-18 07:09:19,862	INFO services.py:1441 -- Starting the Plasma object store with 0.4 GB memory using /dev/shm.
2019-07-18 07:09:19,887	INFO scripts.py:371 -- 
Started Ray on this node. If you wish to terminate the processes that have been started, run

    ray stop


A simple parameter server can be implemented as a Python class in a few lines of code.

EXERCISE: Make the ParameterServer class an actor.


In [5]:
dim = 10
@ray.remote
class ParameterServer(object):
    def __init__(self, dim):
        self.parameters = np.zeros(dim)
    
    def get_parameters(self):
        return self.parameters
    
    def update_parameters(self, update):
        self.parameters += update


ps = ParameterServer.remote(dim)

A worker can be implemented as a simple Python function that repeatedly gets the latest parameters, computes an update to the parameters, and sends the update to the parameter server.


In [6]:
@ray.remote
def worker(ps, dim, num_iters):
    for _ in range(num_iters):
        # Get the latest parameters.
        parameters = ray.get(ps.get_parameters.remote())
        # Compute an update.
        update = 1e-3 * parameters + np.ones(dim)
        # Update the parameters.
        ps.update_parameters.remote(update)
        # Sleep a little to simulate a real workload.
        time.sleep(0.5)

# Test that worker is implemented correctly. You do not need to change this line.
ray.get(worker.remote(ps, dim, 1))

In [7]:
# Start two workers.
worker_results = [worker.remote(ps, dim, 100) for _ in range(2)]

As the worker tasks are executing, you can query the parameter server from the driver and see the parameters changing in the background.


In [10]:
print(ray.get(ps.get_parameters.remote()))


[19.16281869 19.16281869 19.16281869 19.16281869 19.16281869 19.16281869
 19.16281869 19.16281869 19.16281869 19.16281869]

Sharding a Parameter Server

As the number of workers increases, the volume of updates being sent to the parameter server will increase. At some point, the network bandwidth into the parameter server machine or the computation done by the parameter server may become a bottleneck.

Suppose you have $N$ workers and $1$ parameter server, and suppose each of these is an actor that lives on its own machine. Furthermore, suppose the model size is $M$ bytes. Then sending all of the parameters from the workers to the parameter server will mean that $N * M$ bytes in total are sent to the parameter server. If $N = 100$ and $M = 10^8$, then the parameter server must receive ten gigabytes, which, assuming a network bandwidth of 10 gigabits per second, would take 8 seconds. This would be prohibitive.

On the other hand, if the parameters are sharded (that is, split) across $K$ parameter servers with $K = 100$, and each parameter server lives on a separate machine, then each parameter server needs to receive only $N * M / K = 100$ megabytes, which can be done in 80 milliseconds. This is much better.
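As a quick sanity check of the arithmetic above (these are the figures from the text, not measurements):

# Back-of-the-envelope check of the transfer times quoted above.
N = 100                 # number of workers
M = 10 ** 8             # model size in bytes (100 MB)
K = 100                 # number of parameter server shards
bandwidth = 10e9 / 8    # 10 gigabits per second, in bytes per second

unsharded_bytes = N * M        # bytes received by a single parameter server
sharded_bytes = (N * M) // K   # bytes received by each of the K shards

print(unsharded_bytes / bandwidth)  # 8.0 seconds
print(sharded_bytes / bandwidth)    # 0.08 seconds, i.e. 80 milliseconds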

EXERCISE: The code below defines a parameter server shard class. Modify this class to make ParameterServerShard an actor. We will need to revisit this code soon and increase num_shards.


In [11]:
@ray.remote
class ParameterServerShard(object):
    def __init__(self, sharded_dim):
        self.parameters = np.zeros(sharded_dim)
    
    def get_parameters(self):
        return self.parameters
    
    def update_parameters(self, update):
        self.parameters += update


total_dim = (10 ** 8) // 8  # This works out to 100MB (we have 12.5 million
                            # float64 values, which are each 8 bytes).
num_shards = 2  # The number of parameter server shards.

assert total_dim % num_shards == 0, ('In this exercise, the number of shards must '
                                     'perfectly divide the total dimension.')

# Start some parameter servers.
ps_shards = [ParameterServerShard.remote(total_dim // num_shards) for _ in range(num_shards)]

assert hasattr(ParameterServerShard, 'remote'), ('You need to turn ParameterServerShard into an '
                                                 'actor (by using the ray.remote keyword).')

The code below implements a worker that does the following.

  1. Gets the latest parameters from all of the parameter server shards.
  2. Concatenates the parameters together to form the full parameter vector.
  3. Computes an update to the parameters.
  4. Partitions the update into one piece for each parameter server.
  5. Applies the right update to each parameter server shard.

In [14]:
@ray.remote
def worker_task(total_dim, num_iters, *ps_shards):
    # Note that ps_shards are passed in using Python's variable number
    # of arguments feature. We do this because currently actor handles
    # cannot be passed to tasks inside of lists or other objects.
    for _ in range(num_iters):
        # Get the current parameters from each parameter server.
        parameter_shards = [ray.get(ps.get_parameters.remote()) for ps in ps_shards]
        assert all([isinstance(shard, np.ndarray) for shard in parameter_shards]), (
               'The parameter shards must be numpy arrays. Did you forget to call ray.get?')
        # Concatenate them to form the full parameter vector.
        parameters = np.concatenate(parameter_shards)
        assert parameters.shape == (total_dim,)

        # Compute an update.
        update = np.ones(total_dim)
        # Shard the update.
        update_shards = np.split(update, len(ps_shards))
        
        # Apply the updates to the relevant parameter server shards.
        for ps, update_shard in zip(ps_shards, update_shards):
            ps.update_parameters.remote(update_shard)


# Test that worker_task is implemented correctly. You do not need to change this line.
ray.get(worker_task.remote(total_dim, 1, *ps_shards))

EXERCISE: Experiment by changing the number of parameter server shards, the number of workers, and the size of the data.

NOTE: Because these processes are all running on the same machine, network bandwidth will not be a limitation and sharding the parameter server will not help. To see the difference, you would need to run the application on multiple machines. There are still regimes where sharding a parameter server can help speed up computation on the same machine (by parallelizing the computation that the parameter server processes have to do). If you want to see this effect, you should implement a synchronous training application. In the asynchronous setting, the computation is staggered and so speeding up the parameter server usually does not matter.
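For reference, a synchronous step on top of the same ps_shards might look roughly like the sketch below. Here compute_update is a hypothetical stand-in for the real per-worker gradient computation (it simply returns ones, like the asynchronous worker above); the driver launches one task per worker, waits for all of them with ray.get, and only then applies the combined update to each shard.

# A rough sketch of one synchronous (lockstep) training step.
@ray.remote
def compute_update(parameters):
    # Placeholder for a real gradient computation on a batch of data.
    return np.ones(parameters.shape[0])

def synchronous_step(ps_shards, num_workers, total_dim):
    # Gather the current parameters from all shards.
    parameters = np.concatenate(
        ray.get([ps.get_parameters.remote() for ps in ps_shards]))
    assert parameters.shape == (total_dim,)
    # Launch one update task per worker and wait for all of them.
    updates = ray.get([compute_update.remote(parameters)
                       for _ in range(num_workers)])
    combined_update = np.sum(updates, axis=0)
    # Apply the corresponding slice of the combined update to each shard.
    for ps, update_shard in zip(ps_shards,
                                np.split(combined_update, len(ps_shards))):
        ps.update_parameters.remote(update_shard)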


In [16]:
num_workers = 4

# Start some workers. Try changing various quantities and see how the
# duration changes.
start = time.time()
ray.get([worker_task.remote(total_dim, 5, *ps_shards) for _ in range(num_workers)])
print('This took {} seconds.'.format(time.time() - start))


This took 4.21185827255249 seconds.

In [ ]: