Writing a custom training loop#
Overview#
Our typical workflow involves using a training script provided in the Cerebras Model Zoo. However, if that training loop is insufficient for your model's needs, you may write your own training loop using the Cerebras PyTorch API.
Proceed with the following steps to learn how to write a custom training loop for a simple, fully connected model trained on the MNIST dataset.
Note that the following steps cover only the minimum code required to run a simple, small model on the Cerebras Wafer-Scale Cluster. To extend the script with features such as learning rate scheduling or gradient scaling, see the Further reading section to learn more about these topics.
Prerequisites#
You have installed the cerebras.pytorch package in your environment.
Validate the package installation#
To check whether the cerebras.pytorch package is installed correctly, run the following in a Python interpreter:
import cerebras.pytorch as cstorch
Note
From here on, we will be using cstorch as the alias for cerebras.pytorch.
Configure the Cerebras Wafer Scale Cluster#
To configure the Cerebras Wafer-Scale Cluster, construct a ClusterConfig object and use it to construct a cerebras.pytorch.backend object:
cluster_config = cstorch.distributed.ClusterConfig(
mgmt_address=mgmt_address,
max_wgt_servers=1,
max_act_per_csx=1,
num_workers_per_csx=1,
)
backend = cstorch.backend(
"CSX",
cluster_config=cluster_config,
)
See the class documentation for ClusterConfig for all configurable options.
Note
Most options have reasonable defaults and do not need to be changed.
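For instance, if the defaults suit your setup, a minimal configuration might only need the management address (a hypothetical sketch; all other options are left at their defaults):
cluster_config = cstorch.distributed.ClusterConfig(mgmt_address=mgmt_address)
backend = cstorch.backend("CSX", cluster_config=cluster_config)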
Define your model#
When using the Cerebras PyTorch API, you can define your model in the same way you would in a vanilla PyTorch workflow:
import torch
import torch.nn.functional as F
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(784, 256)
self.fc2 = torch.nn.Linear(256, 10)
def forward(self, inputs):
inputs = torch.flatten(inputs, 1)
outputs = F.relu(self.fc1(inputs))
return F.relu(self.fc2(outputs))
model = Model()
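As a quick sanity check (this runs eagerly on CPU and is not part of the cluster run), you can verify the model's shapes with a random MNIST-sized batch:
# Hypothetical shape check: a batch of 8 MNIST images (1x28x28) flattens to 784 features
dummy_inputs = torch.randn(8, 1, 28, 28)
assert model(dummy_inputs).shape == (8, 10)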
Note
Weight initialization for large models can cause out-of-memory errors. In addition, eagerly initializing extremely large models can be very slow. See the page on Efficient weight initialization to learn how to work around this issue.
Compile your model#
Once the model has been instantiated, compile it by calling cerebras.pytorch.compile. You must pass in the backend you wish to compile the model with. You can pass in just the backend type if you wish to use all default arguments, or you can instantiate a backend object using cerebras.pytorch.backend, as done above, to customize it:
compiled_model = cstorch.compile(model, backend)
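If you do not need to customize the backend, passing the backend type directly should be equivalent (a minimal sketch that assumes the default cluster configuration is sufficient for your setup):
compiled_model = cstorch.compile(model, backend="CSX")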
Note
The call to cstorch.compile doesn't actually compile the model. Similar to torch.compile, it only prepares the model for compilation. Compilation only happens after the first iteration, once the input shapes are known.
Optimize Model Parameters#
To optimize model parameters using the Cerebras Wafer-Scale cluster, you must use a Cerebras-compliant optimizer. There are exact drop-in replacements for all commonly used optimizers available in cerebras.pytorch.optim, e.g.
optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
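Other common optimizers follow the same pattern; for example, an AdamW variant would presumably look like this (a sketch assuming the same constructor arguments as its torch.optim counterpart):
optimizer = cstorch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)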
Note
If you are interested in writing your own custom Cerebras-compliant optimizer, see the page on Writing custom optimizers.
DataLoaders#
To send data to the Wafer-Scale Cluster, you must wrap your PyTorch dataloader with cerebras.pytorch.utils.data.DataLoader, e.g.
def get_torch_dataloader(batch_size, train):
from torchvision import datasets, transforms
train_dataset = datasets.MNIST(
"./data",
train=train,
download=True,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
target_transform=transforms.Lambda(
lambda x: torch.as_tensor(x, dtype=torch.int32)
),
)
return torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
training_dataloader = cstorch.utils.data.DataLoader(
get_torch_dataloader, batch_size=64, train=True
)
The Cerebras PyTorch dataloader takes a callable that returns a PyTorch dataloader. This indirection is required so that each worker can create its own PyTorch dataloader instance, maximizing distributed parallelism.
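Because the extra arguments are forwarded to the callable, the same factory can be reused for other splits. For example, a hypothetical evaluation dataloader:
eval_dataloader = cstorch.utils.data.DataLoader(
    get_torch_dataloader, batch_size=64, train=False
)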
Define the Training Step#
To run a single training iteration on the Cerebras Wafer-Scale Cluster, we must first capture everything that is intended to run on the cluster. To do this, define a function that contains everything that happens in a single training iteration and decorate it with cerebras.pytorch.trace. For example:
loss_fn = torch.nn.CrossEntropyLoss()
@cstorch.trace
def training_step(inputs, targets):
outputs = compiled_model(inputs)
loss = loss_fn(outputs, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()
return loss
This function gets traced and sent to the cluster for compilation and execution.
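An evaluation step would be traced the same way, but without the backward pass and optimizer update; a hypothetical sketch:
@cstorch.trace
def eval_step(inputs, targets):
    outputs = compiled_model(inputs)
    return loss_fn(outputs, targets)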
Define an Execution#
To set up an execution run on the Cerebras Wafer-Scale Cluster, you must define an instance of cerebras.pytorch.utils.data.DataExecutor, e.g.
train_executor = cstorch.utils.data.DataExecutor(
training_dataloader,
num_steps=100,
checkpoint_steps=50,
)
It takes in the Cerebras PyTorch dataloader that will be used during the run, the total number of steps to run, and the interval at which checkpoints are taken.
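An evaluation run would use a separate executor constructed the same way, e.g. reusing the hypothetical eval_dataloader from above (this sketch assumes checkpoint_steps may be omitted when no checkpoints are needed):
eval_executor = cstorch.utils.data.DataExecutor(eval_dataloader, num_steps=10)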
Train your model#
Once the above is defined, you can iterate through the executor to train your model.
@cstorch.step_closure
def print_loss(mode, loss: torch.Tensor, step: int):
print(f"{mode} Loss {step}: {loss.item()}")
@cstorch.checkpoint_closure
def save_checkpoint(step):
cstorch.save(
{
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
},
f"checkpoint_{step}.mdl",
)
global_step = 0
for inputs, targets in train_executor:
loss = training_step(inputs, targets)
print_loss("Training", loss, global_step)
global_step += 1
save_checkpoint(global_step)
Note
Notice how the loss was passed into a function decorated by step_closure. This is required to retrieve the loss value from the Cerebras Wafer-Scale Cluster before it can be used. Please see the page on step closures for more details.
Note
Also, notice how checkpoints are saved inside a function decorated by checkpoint_closure. This is required to retrieve the model weights and optimizer state from the Cerebras Wafer-Scale Cluster before they can be saved. Please see the page on saving checkpoints for more details.
Putting it all together#
Combining all of the above steps, we arrive at a minimal training script for a simple, fully connected model trained on the MNIST dataset:
# Import the Cerebras PyTorch module
import cerebras.pytorch as cstorch
# Define a model
import torch
import torch.nn.functional as F
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(784, 256)
self.fc2 = torch.nn.Linear(256, 10)
def forward(self, inputs):
inputs = torch.flatten(inputs, 1)
outputs = F.relu(self.fc1(inputs))
return F.relu(self.fc2(outputs))
backend = cstorch.backend(
"CSX",
cluster_config=cstorch.distributed.ClusterConfig(
mgmt_address=mgmt_address,
max_wgt_servers=1,
max_act_per_csx=1,
num_workers_per_csx=1,
),
)
model = Model()
# Compile the model
compiled_model = cstorch.compile(model, backend=backend)
# Define an optimizer
optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Define a data loader
def get_torch_dataloader(batch_size, train):
from torchvision import datasets, transforms
train_dataset = datasets.MNIST(
"./data",
train=train,
download=True,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
target_transform=transforms.Lambda(
lambda x: torch.as_tensor(x, dtype=torch.int32)
),
)
return torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
training_dataloader = cstorch.utils.data.DataLoader(
get_torch_dataloader, batch_size=64, train=True
)
# Define the training step
loss_fn = torch.nn.CrossEntropyLoss()
@cstorch.trace
def training_step(inputs, targets):
outputs = compiled_model(inputs)
loss = loss_fn(outputs, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()
return loss
@cstorch.step_closure
def print_loss(loss: torch.Tensor, step: int):
print(f"Train Loss {step}: {loss.item()}")
@cstorch.checkpoint_closure
def save_checkpoint(step):
cstorch.save(
{
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
},
f"checkpoint_{step}.mdl",
)
global_step = 0
train_executor = cstorch.utils.data.DataExecutor(
training_dataloader,
num_steps=100,
checkpoint_steps=50,
)
model.train()
for inputs, targets in train_executor:
loss = training_step(inputs, targets)
print_loss(loss, global_step)
global_step += 1
save_checkpoint(global_step)
Further reading#
End-to-end Examples