Writing a custom training loop#

Overview#

The typical workflow uses a training script provided in the Cerebras Model Zoo. If that training loop does not meet your model's needs, you can write your own using the Cerebras PyTorch API.

The following steps show how to write a custom training loop for a simple, fully connected model trained on the MNIST dataset.

Note that these steps cover only the minimum code required to run a small, simple model on the Cerebras Wafer-Scale Cluster. To extend the script with features such as learning rate scheduling or gradient scaling, see the Further reading section.

Prerequisites#

You have installed the Cerebras PyTorch package (imported as cerebras.pytorch) in your environment.

Validate the package installation#

To check whether the package is installed correctly, run the following import in a Python interpreter. It should complete without raising an ImportError:

import cerebras.pytorch as cstorch

Note

From here on, we will be using cstorch as the alias for cerebras.pytorch.

Configure the Cerebras Wafer Scale Cluster#

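To configure the Cerebras Wafer-Scale Cluster, construct a ClusterConfig object and pass it to cerebras.pytorch.backend to construct a backend object. In the snippet below, mgmt_address is a placeholder variable that you must set to your cluster's management address beforehand: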

cluster_config = cstorch.distributed.ClusterConfig(
    mgmt_address=mgmt_address,
    max_wgt_servers=1,
    max_act_per_csx=1,
    num_workers_per_csx=1,
)

backend = cstorch.backend(
    "CSX",
    cluster_config=cluster_config,
)

See the class documentation for ClusterConfig for all configurable options.

Note

Most options have reasonable defaults and do not need to be changed.

Define your model#

When using the Cerebras PyTorch API, you define your model the same way you would in a vanilla PyTorch workflow:

import torch
import torch.nn.functional as F


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 10)

    def forward(self, inputs):
        inputs = torch.flatten(inputs, 1)
        outputs = F.relu(self.fc1(inputs))
        return F.relu(self.fc2(outputs))


model = Model()
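
As a quick sanity check on the layer sizes, the sketch below counts the trainable parameters of the two Linear layers by hand. It assumes 28 x 28 MNIST images (hence 784 features after flattening) and that each layer has a bias, which is the torch.nn.Linear default. The helper linear_param_count is a hypothetical name used only for illustration.

```python
def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of trainable parameters in a fully connected layer."""
    weights = in_features * out_features
    biases = out_features if bias else 0
    return weights + biases


# MNIST images are 28 x 28 pixels, so flattening yields 784 features.
flattened_features = 28 * 28

fc1_params = linear_param_count(flattened_features, 256)
fc2_params = linear_param_count(256, 10)
total_params = fc1_params + fc2_params

print(f"fc1: {fc1_params}, fc2: {fc2_params}, total: {total_params}")
```

At roughly 200 thousand parameters, this model is far below the size at which the weight initialization concerns discussed on this page apply.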

Note

Eagerly initializing the weights of very large models can be slow and can cause out-of-memory errors. See the page on Efficient weight initialization for how to work around this issue.

Compile your model#

Once the model has been instantiated, compile it by calling cerebras.pytorch.compile.

You must pass in the backend you wish to compile the model with. You can simply pass in the type of backend if you wish to use all default arguments, or you can instantiate a backend object using cerebras.pytorch.backend, as done above, to customize it:

compiled_model = cstorch.compile(model, backend)

Note

The call to cstorch.compile does not actually compile the model. Like torch.compile, it only prepares the model for compilation. Compilation itself happens only after the first iteration, once the input shapes are known.
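
The deferred behaviour described in the note above can be pictured with a small pure-Python analogy. This is not how cstorch.compile is implemented: LazyCompiled and its attributes are hypothetical names, and the "compile" step is only a counter. The sketch illustrates one point, namely that the expensive work is postponed until the first call, when the input shape becomes known.

```python
from typing import Callable, Optional, Sequence, Tuple


class LazyCompiled:
    """Toy wrapper that defers 'compilation' until the first call."""

    def __init__(self, fn: Callable[[Sequence[float]], float]):
        self.fn = fn
        # Input shape recorded at the first call; None means "not compiled yet".
        self.compiled_for: Optional[Tuple[int, ...]] = None
        self.compile_count = 0

    def __call__(self, batch: Sequence[float]) -> float:
        if self.compiled_for is None:
            # First call: the input shape is finally known, so "compile" now.
            self.compiled_for = (len(batch),)
            self.compile_count += 1
        return self.fn(batch)


lazy_sum = LazyCompiled(sum)         # wrapping alone compiles nothing
count_after_wrap = lazy_sum.compile_count
first = lazy_sum([1.0, 2.0, 3.0])    # triggers the one-time "compile"
second = lazy_sum([4.0, 5.0, 6.0])   # reuses the earlier result
print(count_after_wrap, first, second, lazy_sum.compile_count)
```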

Optimize Model Parameters#

To optimize model parameters on the Cerebras Wafer-Scale Cluster, you must use a Cerebras-compliant optimizer. cerebras.pytorch.optim provides drop-in replacements for all commonly used optimizers, for example:

optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
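
To make the two hyperparameters concrete, the sketch below applies the momentum update rule documented for torch.optim.SGD (with zero dampening, no weight decay, and no Nesterov momentum) to a single scalar parameter in plain Python. It assumes the Cerebras optimizer follows the same rule, as the drop-in replacement claim suggests. The helper sgd_momentum_step is hypothetical and not part of either library.

```python
from typing import Optional, Tuple


def sgd_momentum_step(
    param: float,
    grad: float,
    velocity: Optional[float],
    lr: float = 0.001,
    momentum: float = 0.9,
) -> Tuple[float, float]:
    """One SGD-with-momentum update for a scalar parameter."""
    # On the first step the velocity buffer is initialised to the gradient.
    velocity = grad if velocity is None else momentum * velocity + grad
    return param - lr * velocity, velocity


param, velocity = 1.0, None
for step in range(1, 4):
    grad = 2.0 * param  # gradient of the toy objective f(p) = p**2
    param, velocity = sgd_momentum_step(param, grad, velocity)
    print(f"step {step}: param={param:.6f} velocity={velocity:.6f}")
```

The velocity grows over the first few steps because each update adds the new gradient to 0.9 times the previous velocity, so the effective step size is larger than lr alone would give.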

Note

If you are interested in writing your own custom Cerebras-compliant optimizer, see the page on Writing custom optimizers.

DataLoaders#

To send data to the Wafer-Scale cluster, you must wrap your PyTorch dataloader with cerebras.pytorch.utils.data.DataLoader, e.g.

def get_torch_dataloader(batch_size, train):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "./data",
        train=train,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        ),
    )

    return torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )


training_dataloader = cstorch.utils.data.DataLoader(
    get_torch_dataloader, batch_size=64, train=True
)

The Cerebras PyTorch dataloader takes a callable that returns a PyTorch dataloader, along with the keyword arguments to pass to that callable (here, batch_size and train). This indirection lets each worker create its own PyTorch dataloader instance, which maximizes distributed parallelism.
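
The reason for passing a callable rather than a ready-made dataloader can be seen in a small plain-Python analogy. Here make_loader and Worker are hypothetical stand-ins, and no Cerebras or PyTorch code is involved. Each worker receives the factory plus its keyword arguments and builds a private loader, so no iterator state is shared between workers. How real workers divide the dataset among themselves is outside the scope of this sketch.

```python
from typing import Callable, Iterator, List


def make_loader(batch_size: int, num_samples: int = 10) -> Iterator[List[int]]:
    """Toy 'dataloader': yields lists of sample indices, batch_size at a time."""
    for start in range(0, num_samples, batch_size):
        yield list(range(start, min(start + batch_size, num_samples)))


class Worker:
    """Toy worker that constructs its own loader from a factory."""

    def __init__(self, factory: Callable[..., Iterator[List[int]]], **kwargs):
        # The factory is invoked inside the worker, not by the caller.
        self.loader = factory(**kwargs)


workers = [Worker(make_loader, batch_size=4) for _ in range(2)]

# Advancing one worker's loader does not affect the other's.
first_from_w0 = next(workers[0].loader)
second_from_w0 = next(workers[0].loader)
first_from_w1 = next(workers[1].loader)
print(first_from_w0, second_from_w0, first_from_w1)
```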

Define the Training Step#

To run a single training iteration on the Cerebras Wafer-Scale Cluster, you must first capture everything that is intended to run on the cluster. To do this, define a function that contains everything that happens in a single training iteration, and decorate it with cerebras.pytorch.trace.

For example:

loss_fn = torch.nn.CrossEntropyLoss()

@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss

This function gets traced and sent to the cluster for compilation and execution.
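
For reference, the quantity that loss_fn returns can be written out by hand. The sketch below computes softmax cross-entropy for a single sample in plain Python, following the default definition of torch.nn.CrossEntropyLoss for a class-index target. The helper cross_entropy is a hypothetical name. One useful consequence: equal logits across all 10 digit classes give a loss of ln(10), about 2.30, which is a handy reference point when reading the first few loss values of a run.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Softmax cross-entropy for one sample: -log(softmax(logits)[target])."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[target]


# Equal logits: the model has no preference among the 10 classes.
uniform_loss = cross_entropy([0.0] * 10, target=3)

# A large logit on the correct class drives the loss towards zero.
confident_loss = cross_entropy([0.0, 0.0, 0.0, 8.0] + [0.0] * 6, target=3)

print(f"uniform logits: {uniform_loss:.4f}")
print(f"confident and correct: {confident_loss:.4f}")
```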

Define an Execution#

To program an execution run on the Cerebras Wafer-Scale cluster, you must define an instance of the cerebras.pytorch.utils.data.DataExecutor, e.g.

train_executor = cstorch.utils.data.DataExecutor(
    training_dataloader,
    num_steps=100,
    checkpoint_steps=50,
)

It takes in the Cerebras PyTorch dataloader that will be used during the run, the total number of steps to run for, as well as the interval at which checkpoints will be taken.
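
The arithmetic behind these arguments can be sketched in plain Python. The assumptions are illustrative and not taken from the DataExecutor documentation: steps are counted from 1, a checkpoint is due whenever the step count is a multiple of checkpoint_steps, and the MNIST training split holds 60,000 images. The helper checkpoint_schedule is hypothetical.

```python
import math
from typing import List


def checkpoint_schedule(num_steps: int, checkpoint_steps: int) -> List[int]:
    """1-based steps at which a checkpoint is due, under the stated assumption."""
    return [s for s in range(1, num_steps + 1) if s % checkpoint_steps == 0]


num_steps, checkpoint_steps, batch_size = 100, 50, 64
mnist_train_size = 60_000  # assumed size of the MNIST training split

schedule = checkpoint_schedule(num_steps, checkpoint_steps)
samples_seen = num_steps * batch_size
steps_per_epoch = math.ceil(mnist_train_size / batch_size)

print(f"checkpoints due at steps: {schedule}")
print(f"samples seen: {samples_seen} of {mnist_train_size}")
print(f"steps in one full epoch: {steps_per_epoch}")
```

Under these assumptions, the 100-step run covers only about a tenth of one epoch, which is in keeping with the minimal nature of this example.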

Train your model#

Once the above is defined, you can iterate through the executor to train your model.

@cstorch.step_closure
def print_loss(mode, loss: torch.Tensor, step: int):
    print(f"{mode} Loss {step}: {loss.item()}")


@cstorch.checkpoint_closure
def save_checkpoint(step):
    cstorch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        f"checkpoint_{step}.mdl",
    )

global_step = 0

for inputs, targets in train_executor:
    loss = training_step(inputs, targets)
    print_loss("Training", loss, global_step)
    global_step += 1
    save_checkpoint(global_step)

Note

Notice how the loss was passed into a function decorated by step_closure. This is required to retrieve the loss value from the Cerebras Wafer Scale Cluster before it can be used. Please see the page on step closures for more details.

Note

Also, notice how checkpoints are saved inside a function decorated by checkpoint_closure. This is required to retrieve the model weights and optimizer state from the Cerebras Wafer-Scale Cluster before they can be saved. Please see the page on saving checkpoints for more details.
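
The idea behind both closure decorators, deferring a function until the values it needs are available on the host, can be illustrated with a plain-Python analogy. This is a conceptual sketch only and says nothing about how cstorch implements closures. ToyStepRunner, defer, and finish_step are hypothetical names.

```python
from typing import Any, Callable, List, Tuple


class ToyStepRunner:
    """Toy runner that queues callbacks until a step's values are available."""

    def __init__(self) -> None:
        self._pending: List[Tuple[Callable[..., None], Tuple[Any, ...]]] = []

    def defer(self, fn: Callable[..., None], *args: Any) -> None:
        # Record the call; do not run it yet.
        self._pending.append((fn, args))

    def finish_step(self) -> None:
        # The values are now "back from the device": run queued callbacks in order.
        pending, self._pending = self._pending, []
        for fn, args in pending:
            fn(*args)


log: List[str] = []
runner = ToyStepRunner()

runner.defer(lambda loss, step: log.append(f"step {step}: loss={loss:.2f}"), 2.30, 1)
entries_before_finish = len(log)  # the callback has not run yet
runner.finish_step()              # the callback fires only now
print(entries_before_finish, log)
```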

Putting it all together#

Combining all of the above steps gives a minimal training script for a simple, fully connected model trained on the MNIST dataset:

# Import the Cerebras PyTorch module
import cerebras.pytorch as cstorch

# Define a model
import torch
import torch.nn.functional as F


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 10)

    def forward(self, inputs):
        inputs = torch.flatten(inputs, 1)
        outputs = F.relu(self.fc1(inputs))
        return F.relu(self.fc2(outputs))

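# Configure the cluster and backend.
# NOTE: mgmt_address is not defined in this script. Set it to your
# cluster's management address before running.
backend = cstorch.backend(
    "CSX",
    cluster_config=cstorch.distributed.ClusterConfig(
        mgmt_address=mgmt_address,
        max_wgt_servers=1,
        max_act_per_csx=1,
        num_workers_per_csx=1,
    ),
)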

model = Model()

# Compile the model
compiled_model = cstorch.compile(model, backend=backend)

# Define an optimizer
optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


# Define a data loader
def get_torch_dataloader(batch_size, train):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "./data",
        train=train,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        ),
    )

    return torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )


training_dataloader = cstorch.utils.data.DataLoader(
    get_torch_dataloader, batch_size=64, train=True
)

# Define the training step
loss_fn = torch.nn.CrossEntropyLoss()

@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss


@cstorch.step_closure
def print_loss(loss: torch.Tensor, step: int):
    print(f"Train Loss {step}: {loss.item()}")


@cstorch.checkpoint_closure
def save_checkpoint(step):
    cstorch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        f"checkpoint_{step}.mdl",
    )

global_step = 0

train_executor = cstorch.utils.data.DataExecutor(
    training_dataloader,
    num_steps=100,
    checkpoint_steps=50,
)
model.train()
for inputs, targets in train_executor:
    loss = training_step(inputs, targets)
    print_loss(loss, global_step)
    global_step += 1
    save_checkpoint(global_step)

Further reading#