Sample Training and Eval Scripts for Cerebras PyTorch API#

For a full overview of the Cerebras PyTorch Experimental API and its components, please see How to port your code using Cerebras PyTorch API (Experimental).

Define Dataloader and Input Functions#

The dataloader must be defined in a file separate from the model and the main execution loop, as shown below. This allows the input functions to be imported on their own by the worker processes on the Cerebras cluster; note that both scripts below add the current working directory to mount_dirs and python_paths in their CSConfig for this reason.

dataloader.py#

import os

import torch
from torchvision import datasets, transforms


def get_mnist_dataset(train=True):
    data_dir = os.path.join(os.getcwd(), 'mnist_dataset')
    return datasets.MNIST(
        data_dir,
        train=train,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
                transforms.Lambda(
                    lambda x: torch.as_tensor(x, dtype=torch.float16)
                ),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        ),
    )


def input_fn_train(batch_size=4, drop_last=False):
    train_dataset = get_mnist_dataset(train=True)
    return torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, drop_last=drop_last, shuffle=True,
    )


def input_fn_eval(batch_size=4, drop_last=False):
    eval_dataset = get_mnist_dataset(train=False)
    return torch.utils.data.DataLoader(
        eval_dataset, batch_size=batch_size, drop_last=drop_last, shuffle=False,
    )
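
If you want to sanity-check the input pipeline locally before compiling for CSX, a minimal sketch along the following lines can be used (plain PyTorch, no Cerebras backend required, run from the same directory as dataloader.py). The expected shapes and dtypes follow from the transforms in get_mnist_dataset above.

from dataloader import input_fn_train

if __name__ == "__main__":
    loader = input_fn_train(batch_size=4)
    images, labels = next(iter(loader))
    # Per the transforms above, images should be (4, 1, 28, 28) in float16
    # and labels should be (4,) in int32.
    print(images.shape, images.dtype)
    print(labels.shape, labels.dtype)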

Training Example#

In the same directory as the dataloader, create a training script as follows:

training.py#

""" Example of training script for FC MNIST model on CSX with Weight Streaming. """
import logging
import os

import cerebras_pytorch.experimental as cstorch
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataloader import input_fn_train


class MNISTModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_layers = []
        input_size = 784

        hidden_size = 50
        depth = 10
        hidden_sizes = [hidden_size] * depth

        for hidden_size in hidden_sizes:
            fc_layer = nn.Linear(input_size, hidden_size)
            self.fc_layers.append(fc_layer)
            input_size = hidden_size
        self.fc_layers = nn.ModuleList(self.fc_layers)
        self.last_layer = nn.Linear(input_size, 10)

        self.nonlin = nn.ReLU()
        self.dropout = nn.Dropout(p=0.0)

    def forward(self, inputs):
        x = torch.flatten(inputs, 1)
        for fc_layer in self.fc_layers:
            x = fc_layer(x)
            x = self.nonlin(x)
            x = self.dropout(x)

        pred_logits = self.last_layer(x)
        outputs = F.log_softmax(pred_logits, dim=1)
        return outputs

# CONFIGURABLE VARIABLES FOR THIS SCRIPT
# These values can optionally be moved to a params file and configured from there.
MODEL_DIR = "./"
COMPILE_ONLY = False
VALIDATE_ONLY = False

TRAINING_STEPS = 10
LOG_STEPS = 5

# Checkpoint-related configurations
CHECKPOINT_STEPS = 5
IS_PRETRAINED_CHECKPOINT = False

def main_training_loop(cs_config: cstorch.utils.CSConfig):
    """Main training loop for FC MNIST model"""

    torch.manual_seed(2023)

    backend = cstorch.backend(
        "CSX",
        artifact_dir=MODEL_DIR,
        compile_dir="./compile_dir",
        compile_only=COMPILE_ONLY,
        validate_only=VALIDATE_ONLY,
    )

    with backend.device:
        model = MNISTModel()

    model = cstorch.compile(model, backend)

    # Define loss function for FC MNIST Model
    loss_fn = torch.nn.NLLLoss()

    # Define the optimizer used for training.
    # This example uses SGD from cerebras_pytorch.experimental.optim.
    # For a complete list of optimizers available in the experimental API, please see
    # https://docs.cerebras.net/en/latest/wsc/port/porting-pytorch-to-cs/cstorch-api.html#initializing-the-optimizer
    optimizer = cstorch.optim.configure_optimizer(
        optimizer_type="SGD", params=model.parameters(), lr=0.01, momentum=0.0,
    )

    # Optionally define the learning rate scheduler
    # This example uses LinearLR from cerebras_pytorch.experimental.optim.lr_scheduler.
    # For a complete list of lr schedulers available in the experimental API, please see
    # https://docs.cerebras.net/en/latest/wsc/port/porting-pytorch-to-cs/cstorch-api.html#initializing-the-learning-rate-scheduler
    lr_params = {
        "scheduler": "Linear",
        "initial_learning_rate": 0.01,
        "end_learning_rate": 0.001,
        "total_iters": 5,
    }
    lr_scheduler = cstorch.optim.configure_lr_scheduler(optimizer, lr_params)

    # Define gradient scaling parameters.
    grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")

    loss_values = []
    total_steps = 0

    @cstorch.step_closure
    def accumulate_loss(loss):
        nonlocal loss_values
        nonlocal total_steps

        loss_values.append(loss.item())
        total_steps += 1

    lr_values = []

    @cstorch.step_closure
    def save_learning_rate():
        lr_values.append(lr_scheduler.get_last_lr())

    # Define method for saving ckpts
    @cstorch.checkpoint_closure
    def save_checkpoint(step):
        logging.info(f"Saving checkpoint at step {step}")

        checkpoint_file = os.path.join(MODEL_DIR, f"checkpoint_{step}.mdl")

        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "grad_scalar": grad_scaler.state_dict(),
        }

        state_dict["global_step"] = step

        cstorch.save(state_dict, checkpoint_file)
        logging.info(f"Saved checkpoint {checkpoint_file}")

    global_step = 0

    # Define the actual training loop
    @cstorch.compile_step
    def training_step(batch):
        inputs, targets = batch
        outputs = model(inputs)

        loss = loss_fn(outputs, targets)

        cstorch.amp.optimizer_step(
            loss, optimizer, grad_scaler,
        )

        lr_scheduler.step()

        save_learning_rate()

        accumulate_loss(loss)

        # Save the loss value to be able to plot the loss curve
        cstorch.summarize_scalar("loss", loss)

        return loss

    # Define a post-training step closure if you are interested in tracking summaries, etc.
    writer = cstorch.utils.tensorboard.SummaryWriter(log_dir=os.path.join(MODEL_DIR, "train"))

    @cstorch.step_closure
    def post_training_step(loss):
        if LOG_STEPS and global_step % LOG_STEPS == 0:
            # Define the logging any way desired.
            logging.info(
                f"| Train: {model.device} "
                f"Step={global_step}, "
                f"Loss={loss.item():.5f}"
            )

        # Add handling for NaN values
        if torch.isnan(loss).any().item():
            raise ValueError(
                "NaN loss detected. "
                "Please try different hyperparameters "
                "such as the learning rate, batch size, etc."
            )
        if torch.isinf(loss).any().item():
            raise ValueError("inf loss detected.")

        for group, lr in enumerate(lr_scheduler.get_last_lr()):
            writer.add_scalar(f"lr.{group}", lr, global_step)

    # PERFORM TRAINING LOOPS
    batch_size = 4
    dataloader = cstorch.utils.data.DataLoader(input_fn_train, batch_size)
    executor = cstorch.utils.data.DataExecutor(
        dataloader,
        num_steps=TRAINING_STEPS,
        checkpoint_steps=CHECKPOINT_STEPS,
        writer=writer,
        cs_config=cs_config,
    )

    for _, batch in enumerate(executor):
        loss = training_step(batch)

        global_step += 1

        post_training_step(loss)

        if CHECKPOINT_STEPS and global_step % CHECKPOINT_STEPS == 0:
            save_checkpoint(global_step)

if __name__ == "__main__":

    logging.getLogger().setLevel(logging.INFO)

    os.makedirs(os.path.join(os.getcwd(),'mnist_dataset'), exist_ok=True)

    cs_config = cstorch.utils.CSConfig(
        mount_dirs=[os.getcwd()],
        python_paths=[os.getcwd()],
        max_wgt_servers=1,
        num_workers_per_csx=1,
        max_act_per_csx=1,
    )

    main_training_loop(cs_config)
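
To resume training from one of the checkpoints written by save_checkpoint, a sketch along the following lines can be added inside main_training_loop before the executor loop. It assumes the state_dict layout written by save_checkpoint above ("model", "optimizer", "grad_scaler", "global_step") and that cstorch.amp.GradScaler exposes load_state_dict like its PyTorch counterpart; the checkpoint path shown is only illustrative.

    def load_checkpoint(checkpoint_path):
        state_dict = cstorch.load(checkpoint_path)
        model.load_state_dict(state_dict["model"])
        optimizer.load_state_dict(state_dict["optimizer"])
        grad_scaler.load_state_dict(state_dict["grad_scaler"])
        return state_dict.get("global_step", 0)

    # e.g. resume from the step-5 checkpoint written by a previous run:
    # global_step = load_checkpoint(os.path.join(MODEL_DIR, "checkpoint_5.mdl"))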

Eval Example#

In the same directory as the dataloader, create an evaluation script as follows:

eval.py#

""" Example of training script for FC MNIST model on CSX with Weight Streaming. """
import logging
import os

import cerebras_pytorch.experimental as cstorch
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataloader import input_fn_eval


class MNISTModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_layers = []
        input_size = 784

        hidden_size = 50
        depth = 10
        hidden_sizes = [hidden_size] * depth

        for hidden_size in hidden_sizes:
            fc_layer = nn.Linear(input_size, hidden_size)
            self.fc_layers.append(fc_layer)
            input_size = hidden_size
        self.fc_layers = nn.ModuleList(self.fc_layers)
        self.last_layer = nn.Linear(input_size, 10)

        self.nonlin = nn.ReLU()
        self.dropout = nn.Dropout(p=0.0)

    def forward(self, inputs):
        x = torch.flatten(inputs, 1)
        for fc_layer in self.fc_layers:
            x = fc_layer(x)
            x = self.nonlin(x)
            x = self.dropout(x)

        pred_logits = self.last_layer(x)
        outputs = F.log_softmax(pred_logits, dim=1)
        return outputs


# CONFIGURABLE VARIABLES FOR THIS SCRIPT
# These values can optionally be moved to a params file and configured from there.
MODEL_DIR = "./"
COMPILE_ONLY = False
VALIDATE_ONLY = False

EVAL_STEPS = 10

# Checkpoint-related configurations
CHECKPOINT_STEPS = 5
CHECKPOINT_PATH_EVAL = None

def main_eval_loop(cs_config: cstorch.utils.CSConfig):
    """Main evaluation loop for the MNIST model."""

    backend = cstorch.backend(
        "CSX",
        artifact_dir=MODEL_DIR,
        compile_dir="./compile_dir",
        compile_only=COMPILE_ONLY,
        validate_only=VALIDATE_ONLY,
    )

    with backend.device:
        model = MNISTModel()

    model = cstorch.compile(model, backend)

    def load_checkpoint(checkpoint_path):
        state_dict = cstorch.load(checkpoint_path)
        model.load_state_dict(state_dict["model"])

        global_step = state_dict.get("global_step", 0)
        return global_step

    global_step = 0

    if CHECKPOINT_PATH_EVAL is not None:
        global_step = load_checkpoint(CHECKPOINT_PATH_EVAL)
    else:
        logging.info(
            f"No checkpoint was provided, model parameters will be "
            f"initialized randomly"
        )

    writer = cstorch.utils.tensorboard.SummaryWriter(log_dir=os.path.join(MODEL_DIR, "eval"))

    # Define the eval metrics used by the model for evaluation.
    # This example shows two different eval metrics being used,
    # accuracy and perplexity. NOTE: For a complete list of eval metrics
    # available in the experimental API, please see
    # https://docs.cerebras.net/en/1.8.0/wsc/port/porting-pytorch-to-cs/cstorch-api.html#evaluation-metrics
    accuracy = cstorch.metrics.AccuracyMetric("accuracy",)
    perplexity = cstorch.metrics.PerplexityMetric("perplexity",)

    # Define loss function for FC MNIST Model
    loss_fn = torch.nn.NLLLoss()

    @cstorch.compile_step
    def eval_step(batch):
        inputs, targets = batch
        outputs = model(inputs).to(torch.float16)
        loss = loss_fn(outputs, targets)

        accuracy(
            labels=targets.clone(), predictions=outputs.argmax(-1).int(),
        )
        perplexity(labels=targets.clone(), loss=loss)

        return loss

    total_loss = 0
    total_steps = 0

    @cstorch.step_closure
    def post_eval_step(loss: torch.Tensor):
        nonlocal total_loss
        nonlocal total_steps

        logging.info(
            f"| Eval: {model.device} "
            f"Step={global_step}, "
            f"Loss={loss.item():.5f}"
        )

        if torch.isnan(loss).any().item():
            raise ValueError("NaN loss detected.")
        if torch.isinf(loss).any().item():
            raise ValueError("inf loss detected.")

        total_loss += loss.item()
        total_steps += 1

        cstorch.summarize_scalar("loss", loss)

    # Perform evaluation loops
    batch_size = 4
    dataloader = cstorch.utils.data.DataLoader(input_fn_eval, batch_size)
    executor = cstorch.utils.data.DataExecutor(
        dataloader,
        num_steps=EVAL_STEPS,
        checkpoint_steps=CHECKPOINT_STEPS,
        writer=writer,
        cs_config=cs_config,
    )

    for _, batch in enumerate(executor):
        loss = eval_step(batch)

        global_step += 1

        post_eval_step(loss)

    writer.add_scalar(f"Eval Accuracy", float(accuracy), global_step)
    writer.add_scalar(f"Eval Perplexity", float(perplexity), global_step)

if __name__ == "__main__":

    logging.getLogger().setLevel(logging.INFO)

    os.makedirs(os.path.join(os.getcwd(),'mnist_dataset'), exist_ok=True)

    cs_config = cstorch.utils.CSConfig(
        mount_dirs=[os.getcwd()],
        python_paths=[os.getcwd()],
        max_wgt_servers=1,
        num_workers_per_csx=1,
        max_act_per_csx=1,
    )

    main_eval_loop(cs_config)
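
To evaluate a checkpoint produced by training.py rather than randomly initialized weights, set CHECKPOINT_PATH_EVAL at the top of eval.py to the checkpoint file path. The path below is only an example and follows the naming used by save_checkpoint in training.py.

# e.g. evaluate the checkpoint written at step 10
CHECKPOINT_PATH_EVAL = os.path.join(MODEL_DIR, "checkpoint_10.mdl")

Both scripts are intended to be launched from the directory that contains dataloader.py, since the CSConfig above mounts that directory and adds it to python_paths so the Cerebras cluster workers can import the input functions.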