Parallel training with Thinc and Ray

This notebook is based on one of Ray's tutorials and shows how to use Thinc and Ray to implement parallel training. It includes implementations of both synchronous and asynchronous parameter server training.


In [ ]:
# To let ray install its own version in Colab
!pip uninstall -y pyarrow
# You might need to restart the Colab runtime

In [ ]:
!pip install --upgrade "thinc>=8.0.0a0" ml_datasets ray psutil setproctitle

Let's start with a simple model and config. You can edit the CONFIG string below, or copy it out to a separate file and use Config.from_disk to load it from a path. The [ray] section contains the settings to use for Ray. (We're using a config for convenience, but you don't have to – you can also just hard-code the values.)
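For example, if you saved the config to its own file, you could load it like this (the filename here is just a placeholder):

In [ ]:
from thinc.api import Config

config = Config().from_disk("./config.cfg")  # hypothetical path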


In [ ]:
import thinc
from thinc.api import chain, Relu, Softmax

@thinc.registry.layers("relu_relu_softmax.v1")
def make_relu_relu_softmax(hidden_width: int, dropout: float):
    return chain(
        Relu(hidden_width, dropout=dropout),
        Relu(hidden_width, dropout=dropout),
        Softmax(),
    )

CONFIG = """
[training]
iterations = 200
batch_size = 128

[evaluation]
batch_size = 256
frequency = 10

[model]
@layers = "relu_relu_softmax.v1"
hidden_width = 128
dropout = 0.2

[optimizer]
@optimizers = "Adam.v1"

[ray]
num_workers = 2
object_store_memory = 3000000000
num_cpus = 2
"""

Just like in the original Ray tutorial, we're using the MNIST data (via our ml-datasets package) and are setting up two helper functions:

  1. get_data_loader: Returns shuffled train and dev batches of a given batch size.
  2. evaluate: Evaluates the model on the development data and returns the accuracy.

In [ ]:
import ml_datasets

MNIST = ml_datasets.mnist()

def get_data_loader(model, batch_size):
    (train_X, train_Y), (dev_X, dev_Y) = MNIST
    train_batches = model.ops.multibatch(batch_size, train_X, train_Y, shuffle=True)
    dev_batches = model.ops.multibatch(batch_size, dev_X, dev_Y, shuffle=True)
    return train_batches, dev_batches

def evaluate(model, batch_size):
    dev_X, dev_Y = MNIST[1]
    correct = 0
    total = 0
    for X, Y in model.ops.multibatch(batch_size, dev_X, dev_Y):
        Yh = model.predict(X)
        correct += (Yh.argmax(axis=1) == Y.argmax(axis=1)).sum()
        total += Yh.shape[0]
    return correct / total

Setting up Ray

Getters and setters for gradients and weights

Using Thinc's Model.walk method, we can implement the following helper functions to get and set the weights and gradients for each node in a model's tree. These functions can later be used by the parameter server and the workers.


In [ ]:
from collections import defaultdict

def get_model_weights(model):
    params = defaultdict(dict)
    for node in model.walk():
        for name in node.param_names:
            if node.has_param(name):
                params[node.id][name] = node.get_param(name)
    return params

def set_model_weights(model, params):
    for node in model.walk():
        for name, param in params[node.id].items():
            node.set_param(name, param)

def get_model_grads(model):
    grads = defaultdict(dict)
    for node in model.walk():
        for name in node.grad_names:
            grads[node.id][name] = node.get_grad(name)
    return grads

def set_model_grads(model, grads):
    for node in model.walk():
        for name, grad in grads[node.id].items():
            node.set_grad(name, grad)
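
As a quick sanity check, an initialized model's weights should survive a get/set round trip. Here's a minimal sketch using a small standalone layer (the MNIST model is only created further down):

In [ ]:
import numpy
from thinc.api import Relu

# Throwaway layer, initialized with dummy data so it allocates its parameters
tmp_model = Relu(4)
tmp_model.initialize(X=numpy.zeros((2, 3), dtype="f"), Y=numpy.zeros((2, 4), dtype="f"))
weights = get_model_weights(tmp_model)
set_model_weights(tmp_model, weights)  # round trip should leave the model unchanged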

Defining the Parameter Server

The parameter server will hold a copy of the model. During training, it will:

  1. Receive gradients and apply them to its model.
  2. Send the updated model back to the workers.

The @ray.remote decorator defines a remote process. It wraps the ParameterServer class and allows users to instantiate it as a remote actor. (Source)

Here, the ParameterServer is initialized with a model and optimizer, and has a method to apply gradients received from the workers and a method to get the weights of the current model, using the helper functions defined above.


In [ ]:
import ray

@ray.remote
class ParameterServer:
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def apply_gradients(self, *worker_grads):
        summed_gradients = defaultdict(dict)
        for grads in worker_grads:
            for node_id, node_grads in grads.items():
                for name, grad in node_grads.items():
                    if name in summed_gradients[node_id]:
                        summed_gradients[node_id][name] += grad
                    else:
                        summed_gradients[node_id][name] = grad.copy()
        set_model_grads(self.model, summed_gradients)
        self.model.finish_update(self.optimizer)
        return get_model_weights(self.model)

    def get_weights(self):
        return get_model_weights(self.model)

Defining the Worker

The worker will also hold a copy of the model. During training, it will continuously compute gradients on batches of data and send them to the parameter server. The worker synchronizes its model with the parameter server's weights. (Source)

To compute the gradients during training, we can call the model on a batch of data (and set is_train=True). This returns the predictions and a backprop callback that computes the gradients.


In [ ]:
from thinc.api import fix_random_seed

@ray.remote
class DataWorker:
    def __init__(self, model, batch_size=128, seed=0):
        self.model = model
        fix_random_seed(seed)
        self.data_iterator = iter(get_data_loader(model, batch_size)[0])
        self.batch_size = batch_size

    def compute_gradients(self, weights):
        set_model_weights(self.model, weights)
        try:
            data, target = next(self.data_iterator)
        except StopIteration:  # When the epoch ends, start a new epoch.
            self.data_iterator = iter(get_data_loader(self.model, self.batch_size)[0])
            data, target = next(self.data_iterator)
        guesses, backprop = self.model(data, is_train=True)
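        # (guesses - target) / n is the usual combined gradient of
        # softmax + categorical cross-entropy, averaged over the batch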
        backprop((guesses - target) / target.shape[0])
        return get_model_grads(self.model)

Setting up the model

Using the CONFIG defined above, we can load the settings and set up the model and optimizer. Thinc's registry.make_from_config will parse the config, resolve all references to registered functions and return a dict.


In [ ]:
from thinc.api import registry, Config
C = registry.make_from_config(Config().from_str(CONFIG))
C
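
The resolved output behaves like a regular nested dict, so settings can be looked up by section and key:

In [ ]:
C["training"]["batch_size"]  # parsed as an int – 128, given the config above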

We didn't specify all the dimensions in the model, so we need to pass in a batch of data to finish initialization. This lets Thinc infer the missing shapes.


In [ ]:
optimizer = C["optimizer"]
model = C["model"]

(train_X, train_Y), (dev_X, dev_Y) = MNIST
model.initialize(X=train_X[:5], Y=train_Y[:5])

Training

Synchronous Parameter Server training

We can now create a synchronous parameter server training scheme:

  1. Call ray.init with the settings defined in the config.
  2. Instantiate a process for the ParameterServer.
  3. Create multiple workers (num_workers, as defined in the config).

Though this isn't specifically mentioned in the Ray tutorial, we're setting a different random seed for each worker here. Otherwise, the workers would iterate over the batches in the same order.


In [ ]:
ray.init(
    ignore_reinit_error=True,
    object_store_memory=C["ray"]["object_store_memory"],
    num_cpus=C["ray"]["num_cpus"],
)
ps = ParameterServer.remote(model, optimizer)
workers = []
for i in range(C["ray"]["num_workers"]):
    worker = DataWorker.remote(model, batch_size=C["training"]["batch_size"], seed=i)
    workers.append(worker)

On each iteration, we now compute the gradients on each worker. Once all gradients are available, ParameterServer.apply_gradients is called to calculate the update. The frequency setting in the evaluation config specifies how often to evaluate – for instance, a frequency of 10 means we're only evaluating every 10th iteration.


In [ ]:
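# get_weights returns a Ray object reference; we can pass it straight to
# other remote calls without fetching the value first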
current_weights = ps.get_weights.remote()
for i in range(C["training"]["iterations"]):
    gradients = [worker.compute_gradients.remote(current_weights) for worker in workers]
    current_weights = ps.apply_gradients.remote(*gradients)
    if i % C["evaluation"]["frequency"] == 0:
        set_model_weights(model, ray.get(current_weights))
        accuracy = evaluate(model, C["evaluation"]["batch_size"])
        print(f"{i} \taccuracy: {accuracy:.3f}")
print(f"Final \taccuracy: {accuracy:.3f}")
ray.shutdown()

Asynchronous Parameter Server Training

Here, each worker asynchronously computes the gradients given its current weights and sends them to the parameter server as soon as they are ready. When the parameter server finishes applying the new gradients, it sends a copy of the current weights back to the worker. The worker then updates its weights and repeats. (Source)

The setup looks the same and we can reuse the config. Make sure to call ray.shutdown() to clean up resources and processes before calling ray.init again.


In [ ]:
ray.init(
    ignore_reinit_error=True,
    object_store_memory=C["ray"]["object_store_memory"],
    num_cpus=C["ray"]["num_cpus"],
)
ps = ParameterServer.remote(model, optimizer)
workers = []
for i in range(C["ray"]["num_workers"]):
    worker = DataWorker.remote(model, batch_size=C["training"]["batch_size"], seed=i)
    workers.append(worker)

In [ ]:
current_weights = ps.get_weights.remote()
gradients = {}
for worker in workers:
    gradients[worker.compute_gradients.remote(current_weights)] = worker

for i in range(C["training"]["iterations"] * C["ray"]["num_workers"]):
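    # ray.wait blocks until at least one object ref in the list is ready
    # and returns (ready_refs, remaining_refs)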
    ready_gradient_list, _ = ray.wait(list(gradients))
    ready_gradient_id = ready_gradient_list[0]
    worker = gradients.pop(ready_gradient_id)
    current_weights = ps.apply_gradients.remote(ready_gradient_id)
    gradients[worker.compute_gradients.remote(current_weights)] = worker
    if i % C["evaluation"]["frequency"] == 0:
        set_model_weights(model, ray.get(current_weights))
        accuracy = evaluate(model, C["evaluation"]["batch_size"])
        print(f"{i} \taccuracy: {accuracy:.3f}")
print(f"Final \taccuracy: {accuracy:.3f}")
ray.shutdown()