Writing a Custom Training Loop#
Overview#
The typical workflow uses a training script provided in the Cerebras Model Zoo. If that training loop does not meet your model's needs, you can write your own training loop using the Cerebras PyTorch API.
The following steps show how to write a custom training loop for a simple, fully connected model trained on the MNIST dataset.
Prerequisites#
You have installed the cerebras_pytorch package in your environment.
Validate the package installation#
To check whether the cerebras_pytorch package is installed correctly, run the following import in a Python session:
import cerebras_pytorch as cstorch
Note
From here on, we will use cstorch as the alias for cerebras_pytorch.
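If you would rather test for the package without triggering an ImportError, a small standard-library check along the following lines may help. This is only a sketch: the helper name is_installed is hypothetical and is not part of the Cerebras API.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if a top-level package can be located on the import path."""
    return importlib.util.find_spec(package_name) is not None


# Report the result instead of raising when the package is missing.
if is_installed("cerebras_pytorch"):
    print("cerebras_pytorch is available")
else:
    print("cerebras_pytorch was not found in this environment")
```

Note that locating a package does not prove it imports cleanly, so the plain import shown above remains the more thorough check.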
Define your model#
When using the Cerebras PyTorch API, you can define your model in the same way you would in a vanilla PyTorch workflow:
import torch
import torch.nn.functional as F


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Two fully connected layers: 784 input pixels -> 256 hidden -> 10 classes
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 10)

    def forward(self, inputs):
        # Flatten each image to a 784-element vector
        inputs = torch.flatten(inputs, 1)
        outputs = F.relu(self.fc1(inputs))
        return F.relu(self.fc2(outputs))


model = Model()
Note
Weight initialization for extremely large models can cause out-of-memory errors. See the page on Efficient Weight Initialization for how to work around this issue.
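Because the model is plain PyTorch at this point, it can be checked on CPU before any Cerebras-specific step. The sketch below repeats the model definition so that it is self-contained, then pushes a dummy batch through it. The variable names dummy_batch, logits, and num_params are illustrative, and the batch shape (64, 1, 28, 28) assumes MNIST images as produced by torchvision's ToTensor.

```python
import torch
import torch.nn.functional as F


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 10)

    def forward(self, inputs):
        inputs = torch.flatten(inputs, 1)
        outputs = F.relu(self.fc1(inputs))
        return F.relu(self.fc2(outputs))


model = Model()

# A zero-filled stand-in for one batch of 64 single-channel 28x28 images.
dummy_batch = torch.zeros(64, 1, 28, 28)
logits = model(dummy_batch)

# Count trainable parameters: (784*256 + 256) + (256*10 + 10).
num_params = sum(p.numel() for p in model.parameters())

print(tuple(logits.shape))  # one row of 10 class scores per image
print(num_params)
```

A shape mismatch here is much cheaper to find than one discovered after the model has been sent for compilation.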
Compile your model#
Once the model has been instantiated, compile it by calling cerebras_pytorch.compile, e.g.
compiled_model = cstorch.compile(model, backend="CSX")
You must pass in the backend you wish to compile the model with. To use all default arguments, simply pass in the backend type. To customize the backend, instantiate a backend object using cerebras_pytorch.backend, e.g.
backend = cstorch.backend("CSX", compile_dir="/path/to/compile")
compiled_model = cstorch.compile(model, backend)
Note
The call to cstorch.compile doesn't actually compile the model. Similar to torch.compile, it only prepares the model for compilation. Compilation happens only after the first iteration, once the input shapes are known.
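The deferred behavior described in this note can be pictured with a toy wrapper written in plain Python. This is not how cstorch.compile is implemented. The class LazyCompiled and its compiled_for attribute are hypothetical, and "compilation" here is just recording the input length on first use.

```python
class LazyCompiled:
    """Toy wrapper: defers 'compilation' until the first call reveals the input size."""

    def __init__(self, fn):
        self.fn = fn
        self.compiled_for = None  # input length seen on the first call

    def __call__(self, batch):
        if self.compiled_for is None:
            # Stand-in for real compilation work, triggered on first use only.
            self.compiled_for = len(batch)
        return self.fn(batch)


# Wrapping does no work yet; nothing is known about the inputs.
double = LazyCompiled(lambda batch: [2 * x for x in batch])
before = double.compiled_for

# The first call supplies a concrete input, so the deferred step can run.
result = double([1, 2, 3])
after = double.compiled_for
```

The practical consequence is the same as in the note: errors tied to input shapes surface on the first iteration, not at the call that wraps the model.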
Optimize Model Parameters#
To optimize model parameters using the Cerebras Wafer-Scale cluster, you must use a Cerebras-compliant optimizer. There are exact drop-in replacements for all commonly used optimizers available in cerebras_pytorch.optim, e.g.
optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
Note
For convenience, we also include a configuration helper function, configure_optimizer.
If you are interested in writing your own custom Cerebras-compliant optimizer, see the page on Writing Custom Optimizers.
DataLoaders#
To send data to the Wafer-Scale cluster, you must wrap your PyTorch dataloader with cerebras_pytorch.utils.data.DataLoader, e.g.
def get_torch_dataloader(batch_size):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "/path/to/data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
                # Cast images to float16
                transforms.Lambda(
                    lambda x: torch.as_tensor(x, dtype=torch.float16)
                ),
            ]
        ),
        # Cast labels to int32
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        ),
    )

    return torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


# Pass the function itself, plus its arguments, not an already-built dataloader
cerebras_loader = cstorch.utils.data.DataLoader(get_torch_dataloader, batch_size=64)
The Cerebras PyTorch dataloader takes a callable that returns a PyTorch dataloader, along with the arguments to call it with. It must be done this way so that every worker can create its own PyTorch dataloader instance, which maximizes distributed parallelism.
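The reason for passing a callable instead of a ready-made dataloader can be illustrated without any Cerebras code. In the sketch below, FactoryLoader, build, and get_batches are hypothetical names. The class only mimics the pattern of storing a factory and its keyword arguments so that each worker can construct an independent instance.

```python
from typing import Any, Callable, Iterable, List


class FactoryLoader:
    """Hypothetical stand-in: stores a dataloader factory and its arguments."""

    def __init__(self, factory: Callable[..., Iterable], **kwargs: Any):
        self.factory = factory
        self.kwargs = kwargs

    def build(self) -> Iterable:
        # Each worker would call this to obtain its own, independent instance.
        return self.factory(**self.kwargs)


def get_batches(batch_size: int) -> List[List[int]]:
    """Split a small toy dataset into consecutive batches."""
    data = list(range(10))
    return [data[i : i + batch_size] for i in range(0, len(data), batch_size)]


loader = FactoryLoader(get_batches, batch_size=4)

# Two simulated workers each build their own copy from the same recipe.
worker_a = loader.build()
worker_b = loader.build()
```

Only the function and its arguments need to reach each worker. A live dataloader object, with open file handles and iterator state, would be much harder to share.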
Define the Training Step#
To run a single training iteration on the Cerebras Wafer-Scale cluster, you must first capture everything that is intended to run on the cluster. To do this, define a function that contains everything that happens in a single training iteration, and decorate it with cerebras_pytorch.trace.
For example:
loss_fn = torch.nn.CrossEntropyLoss()


@cstorch.trace
def training_step(inputs, targets):
    # Forward pass, loss, backward pass, and optimizer update for one iteration
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss
This function gets traced and sent to the cluster for compilation and execution.
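Because the body of the training step is ordinary PyTorch, it can be useful to try the same sequence of calls on CPU before involving the cluster. The sketch below is a stand-in under stated assumptions: it uses torch.optim.SGD and a single linear layer in place of the compiled model, random tensors in place of MNIST, and it omits the cstorch.trace decorator.

```python
import torch

torch.manual_seed(0)

# Stand-ins for the compiled model and the Cerebras optimizer (assumptions).
model = torch.nn.Linear(784, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
loss_fn = torch.nn.CrossEntropyLoss()


def training_step(inputs, targets):
    # Same order of operations as the traced step: forward, loss,
    # backward, parameter update, then clearing the gradients.
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss


# Random data shaped like a batch of 64 flattened MNIST images and labels.
inputs = torch.randn(64, 784)
targets = torch.randint(0, 10, (64,))

loss = training_step(inputs, targets)
print(loss.item())
```

On CPU the loss can be read directly with loss.item(). On the cluster, the value has to be retrieved through a step closure, as described in the training section.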
Define an Execution#
To program an execution run on the Cerebras Wafer-Scale cluster, you must create an instance of cerebras_pytorch.utils.data.DataExecutor, e.g.
executor = cstorch.utils.data.DataExecutor(
    cerebras_loader, num_steps=1000, checkpoint_steps=100
)
It takes the Cerebras PyTorch dataloader to use during the run, the total number of steps to run, and the interval at which checkpoints are taken.
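To make the relationship between the two numbers concrete, the helper below lists the steps at which a checkpoint would fall. It is a sketch that assumes steps are counted from 1 and that checkpoints land on exact multiples of the interval. The function checkpoint_schedule is hypothetical and is not part of the Cerebras API.

```python
from typing import List


def checkpoint_schedule(num_steps: int, checkpoint_steps: int) -> List[int]:
    """List the 1-based steps that are exact multiples of the interval."""
    if checkpoint_steps <= 0:
        raise ValueError("checkpoint_steps must be a positive integer")
    return list(range(checkpoint_steps, num_steps + 1, checkpoint_steps))


# Values from the example above: 1000 steps, one checkpoint every 100 steps.
schedule = checkpoint_schedule(1000, 100)
print(len(schedule), schedule[0], schedule[-1])
```

Under these assumptions, a run whose length is not a multiple of the interval, such as 250 steps with an interval of 100, would end with steps that are not covered by a periodic checkpoint.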
Train your model#
Once the above is defined, you can iterate through the executor to train your model.
@cstorch.step_closure
def print_loss(loss: torch.Tensor):
    print(f"Loss: {loss.item()}")


for inputs, targets in executor:
    loss = training_step(inputs, targets)
    print_loss(loss)
Note
Notice how the loss was passed into a function decorated by step_closure. This is required to retrieve the loss value from the Cerebras Wafer-Scale cluster before it can be used. Please see the page on step closures for more details.
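The idea of deferring a function call until its inputs are available can be modeled with a small queue in plain Python. This toy is not the cstorch.step_closure implementation. The names deferred, run_pending, pending, and record_loss are all hypothetical, and a float stands in for the loss tensor.

```python
pending = []  # calls that have been requested but not yet executed


def deferred(fn):
    """Toy decorator: queue the call instead of running it immediately."""

    def wrapper(*args):
        pending.append((fn, args))

    return wrapper


def run_pending():
    """Run queued calls in order, as if the step's values had just arrived."""
    while pending:
        fn, args = pending.pop(0)
        fn(*args)


logged = []


@deferred
def record_loss(loss):
    logged.append(loss)


record_loss(0.5)
seen_before_flush = list(logged)  # nothing has run yet

run_pending()
```

The decorated function looks like a normal call at the call site, but its body runs later. This is why code that needs the loss value belongs inside the closure and not directly in the loop body.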
Putting it all together#
Combining all of the above steps, we can create a minimal training script for a simple, fully connected model trained on the MNIST dataset:
import cerebras_pytorch as cstorch
import torch
import torch.nn.functional as F


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 10)

    def forward(self, inputs):
        # Flatten each image to a 784-element vector, as in the model defined earlier
        inputs = torch.flatten(inputs, 1)
        outputs = F.relu(self.fc1(inputs))
        return F.relu(self.fc2(outputs))


model = Model()
compiled_model = cstorch.compile(model, backend="CSX")
optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
def get_torch_dataloader(batch_size):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "/path/to/data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
                transforms.Lambda(
                    lambda x: torch.as_tensor(x, dtype=torch.float16)
                ),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        ),
    )

    return torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


cerebras_loader = cstorch.utils.data.DataLoader(get_torch_dataloader, batch_size=64)
loss_fn = torch.nn.CrossEntropyLoss()


@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss
executor = cstorch.utils.data.DataExecutor(
    cerebras_loader, num_steps=1000, checkpoint_steps=100
)

for inputs, targets in executor:
    loss = training_step(inputs, targets)
Note
For a full-fledged training script example, see End-to-end Examples.