Weights & Biases (with Dask Cluster)
Overview
This example shows how to use Weights & Biases to monitor the progress of model training on a Saturn Cloud resource with a Dask cluster. It is an extension of the single-machine Weights & Biases example, which does not use a Dask cluster. The example uses PyTorch and a Dask cluster of workers for image classification: it uses the Stanford Dogs dataset and, starting with a pretrained version of ResNet50, applies transfer learning to make the model perform better at dog image identification.
Example code
Imports
import os
import math
import torch
import re
import s3fs
from torch import nn, optim
from torchvision import transforms, models
from torch.utils.data.sampler import RandomSampler
from dask_pytorch_ddp import data, dispatch
from torch.nn.parallel import DistributedDataParallel as DDP
from dask_saturn import SaturnCluster
from dask.distributed import Client
import torch.distributed as dist
import wandb
Set up Weights & Biases
Import the Weights & Biases library and confirm that you are logged in.

The Start Script in this example uses your Weights & Biases token to log in. The resource will try to read it from an environment variable named WANDB_LOGIN, which you can set up in the Credentials section of Saturn Cloud. This is important because all the workers in your cluster need access to this token, so the credential must be set up before the cluster is started. If running the wandb.login() command asks you to provide your Weights & Biases API key, the credential was not set up correctly in Saturn Cloud. Once you add the token to the Credentials page of Saturn Cloud, you'll need to restart the resource.
wandb.login()
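If you want to double-check that the token is visible on the resource before starting the cluster, a short sketch like the one below can help. This check is not part of the original example; it only assumes the credential is exposed under the environment variable name WANDB_LOGIN described above.

import os

# Hypothetical sanity check: confirm the WANDB_LOGIN credential is visible on
# this resource. If it is missing, wandb.login() will prompt for an API key
# instead of using the stored token.
if "WANDB_LOGIN" not in os.environ:
    raise RuntimeError(
        "WANDB_LOGIN is not set - add it in the Credentials section of "
        "Saturn Cloud and restart the resource."
    )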
Dask Cluster Specific Elements
Because this task uses a Dask cluster, we need to load a few extra libraries, and ensure our cluster is running.
cluster = SaturnCluster()
client = Client(cluster)
client.wait_for_workers(2)
client
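Since every worker also needs the Weights & Biases token, it can be worth confirming that the environment variable is present on the workers too. A minimal sketch, assuming the credential is exposed to the workers as WANDB_LOGIN:

import os

def has_wandb_token():
    # Runs on each Dask worker and reports whether the token is visible there.
    return "WANDB_LOGIN" in os.environ

# client.run executes the function on every worker and returns a dict keyed by worker address.
print(client.run(has_wandb_token))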
Label Formatting
These utilities ensure the training data labels correspond to the pretrained model’s label expectations.
# Load label dataset
s3 = s3fs.S3FileSystem(anon=True)
with s3.open("s3://saturn-public-data/dogs/imagenet1000_clsidx_to_labels.txt") as f:
    imagenetclasses = [line.strip() for line in f.readlines()]
# Format labels to match pretrained Resnet
def replace_label(dataset_label, model_labels):
    label_string = re.search("n[0-9]+-([^/]+)", dataset_label).group(1)

    for i in model_labels:
        i = str(i).replace("{", "").replace("}", "")
        model_label_str = re.search("""b["'][0-9]+: ["']([^\\/]+)["'],["']""", str(i))
        model_label_idx = re.search("""b["']([0-9]+):""", str(i)).group(1)

        if re.search(str(label_string).replace("_", " "), str(model_label_str).replace("_", " ")):
            return i, model_label_idx
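As a quick illustration, the helper maps a Stanford Dogs class folder name to the matching pretrained-model label and its index. The folder name below is one example from the dataset's directory layout:

# Example usage: "n02085620-Chihuahua" is one of the Stanford Dogs class folders.
label, idx = replace_label("n02085620-Chihuahua", imagenetclasses)
print(label, idx)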
Set Model Specifications
Here you can assign your model hyperparameters and identify where the training data is housed on S3. All of these parameters, along with some extra elements like Notes and Tags, are tracked for you by Weights & Biases.
model_params = {
    "n_epochs": 6,
    "batch_size": 64,
    "base_lr": 0.0003,
    "downsample_to": 0.5,  # Fraction of the training data to use
    "bucket": "saturn-public-data",
    "prefix": "dogs/Images",
    "pretrained_classes": imagenetclasses,
}

wbargs = {
    **model_params,
    "classes": 120,
    "Notes": "baseline",
    "Tags": ["downsample", "cluster", "gpu", "6wk", "subsample"],
    "Group": "DDP",
    "dataset": "StanfordDogs",
    "architecture": "ResNet",
}
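The wbargs dictionary is passed to wandb.init(config=...) inside the training function below, so these values appear in the run's Config panel. As a small sketch of how the tracked values can be read back (using mode="disabled" so no real run is created), assuming nothing beyond the wandb API itself:

# Sketch only: mode="disabled" lets you inspect the config locally without logging a run.
run = wandb.init(config=wbargs, project="wandb_saturncloud_demo", mode="disabled")
print(run.config["batch_size"], run.config["base_lr"])
run.finish()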
Training Function
This function encompasses the training task.
- Load model and wrap it in PyTorch’s Distributed Data Parallel function
- Initialize Weights & Biases run
- Set up DataLoader to iterate over training data
- Perform training tasks
- Write model performance data to Weights & Biases
def simple_train_cluster(
    bucket, prefix, batch_size, downsample_to, n_epochs, base_lr, pretrained_classes
):
    # os.environ["DASK_DISTRIBUTED__WORKER__DAEMON"] = "False"
    os.environ["WANDB_START_METHOD"] = "thread"
    worker_rank = int(dist.get_rank())

    # --------- Format params --------- #
    device = torch.device("cuda")
    net = models.resnet50(pretrained=True)  # True means we start with the imagenet version
    model = net.to(device)
    model = DDP(model)

    # --------- Start wandb --------- #
    if worker_rank == 0:
        wandb.init(config=wbargs, project="wandb_saturncloud_demo")
        wandb.watch(model)

    # --------- Set up eval --------- #
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.AdamW(model.parameters(), lr=base_lr, eps=1e-06)

    # --------- Retrieve data for training --------- #
    transform = transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(250), transforms.ToTensor()]
    )

    # Because we want to load our images directly and lazily from S3,
    # we use a custom Dataset class called S3ImageFolder.
    whole_dataset = data.S3ImageFolder(bucket, prefix, transform=transform, anon=True)

    # Format target labels
    new_class_to_idx = {
        x: int(replace_label(x, pretrained_classes)[1]) for x in whole_dataset.classes
    }
    whole_dataset.class_to_idx = new_class_to_idx

    # ------ Create dataloader ------- #
    train_loader = torch.utils.data.DataLoader(
        whole_dataset,
        sampler=RandomSampler(
            whole_dataset,
            replacement=True,
            num_samples=math.floor(len(whole_dataset) * downsample_to),
        ),
        batch_size=batch_size,
        num_workers=0,
    )

    # Using the OneCycleLR learning rate schedule
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=base_lr, steps_per_epoch=len(train_loader), epochs=n_epochs
    )

    # ------ Prepare wandb Table for predictions ------- #
    if worker_rank == 0:
        columns = ["image", "label", "prediction", "score"]
        preds_table = wandb.Table(columns=columns)

    # --------- Start Training ------- #
    for epoch in range(n_epochs):
        count = 0
        model.train()

        for inputs, labels in train_loader:
            # zero the parameter gradients
            optimizer.zero_grad()
            inputs, labels = inputs.to(device), labels.to(device)

            # Run model iteration
            outputs = model(inputs)

            # Format results
            pred_idx, preds = torch.max(outputs, 1)
            perct = [
                torch.nn.functional.softmax(el, dim=0)[i].item() for i, el in zip(preds, outputs)
            ]
            loss = criterion(outputs, labels)
            correct = (preds == labels).sum().item()

            loss.backward()
            optimizer.step()
            scheduler.step()

            # Log your metrics to wandb
            if worker_rank == 0:
                logs = {
                    "train/train_loss": loss.item(),
                    "train/learning_rate": scheduler.get_last_lr()[0],
                    "train/correct": correct,
                    "train/epoch": epoch + count / len(train_loader),
                    "train/count": count,
                }

                # Occasionally log some images to ensure the image data looks correct
                if count % 25 == 0:
                    logs["examples/example_images"] = wandb.Image(
                        inputs[:5], caption=f"Step: {count}"
                    )

                # Log some predictions to wandb during final epoch for analysis
                if epoch == max(range(n_epochs)) and count % 4 == 0:
                    for i in range(len(labels)):
                        preds_table.add_data(wandb.Image(inputs[i]), labels[i], preds[i], perct[i])

                # Log metrics to wandb
                wandb.log(logs)

            count += 1

    # Upload your predictions table for analysis
    if worker_rank == 0:
        predictions_artifact = wandb.Artifact(
            "train_predictions_" + str(wandb.run.id), type="train_predictions"
        )
        predictions_artifact.add(preds_table, "train_predictions")
        wandb.run.log_artifact(predictions_artifact)

        # Close your wandb run
        wandb.run.finish()
Run Model
To run the model, we use the dask-pytorch-ddp function dispatch.run(). This takes our client, our training function, and our dictionary of model parameters. You can monitor the model run on all workers using the Dask dashboard, or monitor the performance of Worker 0 on Weights & Biases.
client.restart()  # Clears memory on the cluster - optional but recommended.
%%time
futures = dispatch.run(client, simple_train_cluster, **model_params)
futures
# If one or more worker jobs errored, this will describe the issue
futures[0].result()
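If you want to block until every worker's job has finished and surface an exception from any worker, not just the first, a short sketch like this works with the list of futures that dispatch.run() returns:

from dask.distributed import wait

# Block until all worker jobs complete, then re-raise any worker exception.
wait(futures)
for fut in futures:
    fut.result()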
At this point, you can view the Weights & Biases dashboard to see the performance of the model and system resource utilization in real time!
Conclusion
From this example we were able to see that by using Weights & Biases you can monitor the performance of each worker in a Dask cluster on Saturn Cloud. Adding Weights & Biases to a Dask cluster is just as easy as adding it to a single machine, so it can be a great tool for monitoring your models.