Training a PyTorch Model across a Dask Cluster

Have many GPUs training the same model together with Dask

Overview

Training a PyTorch model can potentially be sped up dramatically by having the training computations done on multiple GPUs across multiple workers. This relies on PyTorch's DistributedDataParallel (DDP) module, which splits the computation for each batch across multiple machines/processes. Each worker computes part of the batch, the results are combined to determine the loss, and then the model parameters are optimized. If you kept a network training setup exactly the same but tripled the number of GPUs with DDP, in practice you would be using a batch size 3x bigger than the original one. Be aware that not all networks benefit from larger batch sizes, and using PyTorch across multiple workers adds the time it takes to pass updated values between the workers.
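
As a quick illustration of that effect, using the numbers this example uses later (a per-worker batch size of 16,384 and three workers):

# with DDP, each worker processes its own batch each step, so an optimizer step
# effectively sees the per-worker batch size times the number of workers
per_worker_batch_size = 16384
n_workers = 3
effective_batch_size = per_worker_batch_size * n_workers  # 49,152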

This example builds on the introduction to PyTorch with GPU on Saturn Cloud example that trains a neural network to generate pet names. The model uses LSTM layers, which are especially good at discovering patterns in sequences like text. The model takes a partially complete name and determines the probability of each possible next character in the name. Characters are randomly sampled from this distribution and added to the partial name until a stop character is generated and a full name has been created. For more detail about the network design and use case, see our Saturn Cloud blog post, which uses the same network architecture.

Alternatively, rather than having a Dask cluster be used to train a single PyTorch model very quickly you could have the Dask cluster train many models in parallel. We have a separate example for that situation.

Model Training

Imports

This code uses PyTorch and Dask together, so both libraries have to be imported. In addition, the dask_saturn package provides methods for working with a Saturn Cloud Dask cluster, and dask_pytorch_ddp provides helpers for training a PyTorch model on Dask.

import uuid
import datetime
import pickle
import json
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import urllib.request
from torch.utils.data import Dataset, DataLoader

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from dask_pytorch_ddp import dispatch, results
from dask_saturn import SaturnCluster
from dask.distributed import Client
from distributed.worker import logger

Preparing data

This code gets the data into the proper format and wraps it in an easy-to-use class.

First, download the data and create the list of characters:

with urllib.request.urlopen(
    "https://saturn-public-data.s3.us-east-2.amazonaws.com/examples/pytorch/seattle_pet_licenses_cleaned.json"
) as f:
    pet_names = json.loads(f.read().decode("utf-8"))

# Our list of characters, where * represents blank and + represents stop
characters = list("*+abcdefghijklmnopqrstuvwxyz-. ")
str_len = 8
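
A quick, optional check of the encoding: * and + occupy the first two positions of the character list, and every other character maps to its index in that list.

# optional sanity check of the character encoding
print(len(pet_names))  # number of names in the training data
print(characters.index("*"), characters.index("+"), characters.index("a"))  # 0 1 2
print(len(characters))  # 31 characters, including the blank and stop markers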

Next, create a function that will take the pet names and turn them into the formatted tensors. The Saturn Cloud blog post goes into more detail on the logic behind how to format the data.

def format_training_data(pet_names, device=None):
    def get_substrings(in_str):
        # add the stop character to the end of the name, then generate all the partial names
        in_str = in_str + "+"
        res = [in_str[0:j] for j in range(1, len(in_str) + 1)]
        return res

    pet_names_expanded = [get_substrings(name) for name in pet_names]
    pet_names_expanded = [item for sublist in pet_names_expanded for item in sublist]
    pet_names_characters = [list(name) for name in pet_names_expanded]
    pet_names_padded = [name[-(str_len + 1) :] for name in pet_names_characters]
    pet_names_padded = [
        list((str_len + 1 - len(characters)) * "*") + characters for characters in pet_names_padded
    ]
    pet_names_numeric = [[characters.index(char) for char in name] for name in pet_names_padded]

    # the final x and y data to use for training the model. Note that the x data needs to be one-hot encoded
    if device is None:
        y = torch.tensor([name[1:] for name in pet_names_numeric])
        x = torch.tensor([name[:-1] for name in pet_names_numeric])
    else:
        y = torch.tensor([name[1:] for name in pet_names_numeric], device=device)
        x = torch.tensor([name[:-1] for name in pet_names_numeric], device=device)
    x = torch.nn.functional.one_hot(x, num_classes=len(characters)).float()
    return x, y
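
As a quick, optional check of what the function produces (run on the CPU here, since no device is passed), the shapes follow directly from str_len and the character list:

# optional shape check: one row per partial name
x, y = format_training_data(pet_names)
print(x.shape)  # (number of partial names, str_len, len(characters)), i.e. (N, 8, 31)
print(y.shape)  # (number of partial names, str_len), i.e. (N, 8)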

Finally, create a PyTorch Dataset class to manage the data:

class OurDataset(Dataset):
    def __init__(self, pet_names, device=None):
        self.x, self.y = format_training_data(pet_names, device)
        self.permute()

    def __getitem__(self, idx):
        idx = self.permutation[idx]
        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

    def permute(self):
        self.permutation = torch.randperm(len(self.x))
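
A quick, optional way to confirm the class behaves as expected is to load it with a plain (non-distributed) DataLoader:

# optional local check of the Dataset class (no GPU or Dask needed)
dataset = OurDataset(pet_names)
loader = DataLoader(dataset, batch_size=4)
batch_x, batch_y = next(iter(loader))
print(batch_x.shape, batch_y.shape)  # torch.Size([4, 8, 31]) torch.Size([4, 8])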

Define the model architecture

This class defines the LSTM structure that the neural network will use:

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.lstm = nn.LSTM(
            input_size=len(characters),
            hidden_size=self.lstm_size,
            num_layers=4,
            batch_first=True,
            dropout=0.1,
        )
        self.fc = nn.Linear(self.lstm_size, len(characters))

    def forward(self, x):
        output, state = self.lstm(x)
        logits = self.fc(output)
        return logits
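
A quick way to check the architecture is to run a random dummy batch through it; the model returns one set of logits over the character list for every position in the input:

# optional check: a random batch of 4 partial names, each str_len characters long
dummy_x = torch.nn.functional.one_hot(
    torch.randint(0, len(characters), (4, str_len)), num_classes=len(characters)
).float()
print(Model()(dummy_x).shape)  # torch.Size([4, 8, 31])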

Train the model with Dask and Saturn Cloud

Next we train the model in parallel over multiple workers using Dask and Saturn Cloud. We define the train() function that will be run on each of the workers. This has much of the same training code you would see in any PyTorch training loop, with a few key differences. The data is distributed with the DistributedSampler, so each worker only has a fraction of the data and together the workers see each data point exactly once per epoch. The model is wrapped in a DDP() call so that the workers can communicate with each other during training. The logger is used to show intermediate results in the Dask logs of each worker, and the results handler rh is used to write intermediate values back to the Jupyter server.

def train():
    num_epochs = 25
    batch_size = 16384

    worker_rank = int(dist.get_rank())
    device = torch.device(0)

    logger.info(f"Worker {worker_rank} - beginning")

    dataset = OurDataset(pet_names, device=device)
    # the distributed sampler makes it so the samples are distributed across the different workers
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    # the model has to both be passed to the GPU device, then has to be wrapped in DDP so it can communicate with the other workers
    model = Model()
    model = model.to(device)
    model = DDP(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        # the logger here logs to the Dask log of each worker, for easy debugging
        logger.info(
            f"Worker {worker_rank} - {datetime.datetime.now().isoformat()} - Beginning epoch {epoch}"
        )

        # this ensures the data is reshuffled each epoch
        sampler.set_epoch(epoch)
        dataset.permute()

        # the code for each batch is no different than base PyTorch
        for i, (batch_x, batch_y) in enumerate(loader):
            optimizer.zero_grad()
            batch_y_pred = model(batch_x)

            loss = criterion(batch_y_pred.transpose(1, 2), batch_y)
            loss.backward()
            optimizer.step()

            logger.info(
                f"Worker {worker_rank} - {datetime.datetime.now().isoformat()} - epoch {epoch} - batch {i} - batch complete - loss {loss.item()}"
            )

        # the first rh call saves a json file with the loss from the worker at the end of the epoch
        rh.submit_result(
            f"logs/data_{worker_rank}_{epoch}.json",
            json.dumps(
                {
                    "loss": loss.item(),
                    "time": datetime.datetime.now().isoformat(),
                    "epoch": epoch,
                    "worker": worker_rank,
                }
            ),
        )
        # this saves the model. We only need to do it for one worker (so we picked worker 0)
        if worker_rank == 0:
            rh.submit_result("model.pkl", pickle.dumps(model.state_dict()))

To actually run the training job, first we spin up a Dask cluster and create a results handler object to manage the PyTorch results.

n_workers = 3
cluster = SaturnCluster(n_workers=n_workers)
client = Client(cluster)
client.wait_for_workers(n_workers)

key = uuid.uuid4().hex
rh = results.DaskResultsHandler(key)

The next block of code starts the training job on all the workers, then uses the results handler to listen for results. The process_results function will block the Jupyter notebook until the training job is done.

futures = dispatch.run(client, train)
rh.process_results("/home/jovyan/project/training/", futures, raise_errors=False)
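
Once training completes, the per-epoch loss records written by each worker end up under the directory passed to process_results (the same directory the model.pkl file is read from below). A minimal sketch for reading them back, assuming the paths used above:

import glob

# read back the per-epoch loss records each worker wrote via rh.submit_result
loss_records = []
for path in glob.glob("/home/jovyan/project/training/logs/data_*.json"):
    with open(path) as f:
        loss_records.append(json.load(f))
loss_records.sort(key=lambda r: (r["epoch"], r["worker"]))
print(loss_records[-1])  # the record from the last epoch that was written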

Lastly, we close the Dask client:

client.close()

Generating Names

To generate names, we have a function that takes the model and runs it over and over on a partially built string, generating each new character until a stop character is produced (or the name reaches 30 characters).

def generate_name(model, characters, str_len):
    in_progress_name = []
    next_letter = ""
    while not next_letter == "+" and len(in_progress_name) < 30:
        # prep the data to run in the model again
        in_progress_name_padded = in_progress_name[-str_len:]
        in_progress_name_padded = (
            list((str_len - len(in_progress_name_padded)) * "*") + in_progress_name_padded
        )
        in_progress_name_numeric = [characters.index(char) for char in in_progress_name_padded]
        in_progress_name_tensor = torch.tensor(in_progress_name_numeric)
        in_progress_name_tensor = torch.nn.functional.one_hot(
            in_progress_name_tensor, num_classes=len(characters)
        ).float()
        in_progress_name_tensor = torch.unsqueeze(in_progress_name_tensor, 0)

        # get the probabilities of each possible next character by running the model
        with torch.no_grad():
            next_letter_probabilities = model(in_progress_name_tensor)

        next_letter_probabilities = next_letter_probabilities[0, -1, :]
        next_letter_probabilities = (
            torch.nn.functional.softmax(next_letter_probabilities, dim=0).detach().cpu().numpy()
        )
        next_letter_probabilities = next_letter_probabilities[1:]
        next_letter_probabilities = [
            p / sum(next_letter_probabilities) for p in next_letter_probabilities
        ]

        # determine what the actual letter is
        next_letter = characters[
            np.random.choice(len(characters) - 1, p=next_letter_probabilities) + 1
        ]
        if next_letter != "+":
            # if the next character isn't stop add the latest generated character to the name and continue
            in_progress_name.append(next_letter)
    # turn the list of characters into a single string
    pet_name = "".join(in_progress_name).title()
    return pet_name

To use the function, we first need to load the saved model state from the training folder. That state will be loaded into a DataParallel model on the GPU.

# load the model and the trained parameters
model_state = pickle.load(open("/home/jovyan/project/training/model.pkl", "rb"))
model = torch.nn.DataParallel(Model()).cuda()
model.load_state_dict(model_state)
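
The DataParallel wrapper is used here because the state dict was saved from the DDP-wrapped model, so its parameter keys carry a module. prefix that DataParallel also expects. If you would rather run inference with a plain Model, one option (a minimal sketch, not part of the original training flow) is to strip that prefix before loading:

# alternative (optional): load the DDP-saved weights into an unwrapped CPU model
plain_state = {key.replace("module.", "", 1): value for key, value in model_state.items()}
cpu_model = Model()
cpu_model.load_state_dict(plain_state)
cpu_model.eval()  # generate_name builds CPU tensors, so this model works with it as-is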

Finally, let's generate 50 names! Let's also remove any names that show up in the training data, since those are less fun.

# Generate 50 names then filter out existing ones
generated_names = [generate_name(model, characters, str_len) for i in range(0, 50)]
generated_names = [name for name in generated_names if name not in pet_names]
print(generated_names)

After running the code above, you should see a list of names like:

['Moicu', 'Caspa', 'Penke', 'Lare', 'Otlnys', 'Zexto', 'Toba', 'Siralto',
'Luny', 'Lit', 'Bonhe', 'Mashs', 'Riys Wargen', 'Roli', 'Sape', 'Anhyyhe',
'Lorla', 'Boupir', 'Zicka', 'Muktse', 'Musko', 'Mosdin', 'Yapfe', 'Snevi',
'Zedy', 'Cedi', 'Wivagok Rayten', 'Luzia', 'Teclyn', 'Pibty', 'Cheynet', 
'Lazyh', 'Ragopes', 'Bitt', 'Bemmen', 'Duuxy', 'Graggie', 'Rari', 'Kisi',
'Lvanxoeber', 'Bonu','Masnen', 'Isphofke', 'Myai', 'Shur', 'Lani', 'Ructli',
'Folsy', 'Icthobewlels', 'Kuet Roter']

Conclusion

We’ve now successfully trained a PyTorch neural network on a distributed set of computers with Dask, and then used it to do NLP inference! Note that depending on the size of your data, your network architecture, and other parameters particular to your situation, training over a distributed set of machines may provide different amounts of speed benefit. For an analysis of how much this can help, see our blog post on training a neural network with multiple GPUs and Dask.
