Source code for common.pytorch.sparsity.finalizer

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/env python
"""
Sideband sparse training on CS2 uses an inband representation for pruned weights
which needs to be finalized before running training or inference on another
device. This module exposes a helper script for finalizing the sparsity.
"""

from typing import Optional

import torch


@torch.no_grad()
def finalize_cs2_sparsity(
    model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None,
) -> dict:
    """
    Given a module loaded from a checkpoint trained on CS2 with sideband
    sparsity, finalize the sparsity into the model's parameters as zeros and
    return the mask representing the sparsity pattern for each sparse parameter.

    Args:
        model: The model whose parameters should be updated to freeze sparsity
        optimizer: If given, the corresponding optimizer states sparsity pattern
        is also frozen

    Returns:
        Dict mapping the parameter names to the sparsity pattern as a bool torch
        tensor, where True values indicates present weights and False represents
        pruned weights.
    """

    masks = {}
    for name, param in model.named_parameters():
        mask = param.isfinite()
        if mask.any():
            masks[name] = mask
        param[~mask] = 0
        if optimizer:
            for state in optimizer.state[param].values():
                if (
                    isinstance(state, torch.Tensor)
                    and state.shape == param.shape
                ):
                    state[~mask] = 0
    return masks


[docs]def finalize_cs2_sparsity_checkpoint(state_dict: dict): """ Given a state_dict trained on CS2 with sideband sparsity, finalize the sparsity of all tensors, both weights and optimizer state by replacing the inband pruned weight representation with zeros for use in dense training or evaluation. Args: state_dict: state_dict to finalize the sparsity pattern in. Modified. """ from collections.abc import Iterable def apply(obj): if isinstance(obj, torch.Tensor): obj[obj.isnan()] = 0 elif isinstance(obj, dict): for v in obj.values(): apply(v) elif isinstance(obj, Iterable): for v in obj: apply(v) apply(state_dict)
if __name__ == "__main__": import argparse parser = argparse.ArgumentParser( description=""" Given a checkpoint trained on CS2 with sideband sparsity, finalize the sparsity of all tensors, both weights and optimizer state by replacing the inband pruned weight representation with zeros for use in dense training or evaluation. """ ) parser.add_argument( "-i", "--input", required=True, help="Path to input checkpoint" ) parser.add_argument( "-o", "--output", required=True, help="Path to output checkpoint" ) args = parser.parse_args() from modelzoo.common.pytorch import cbtorch state_dict = cbtorch.load(args.input) finalize_cs2_sparsity_checkpoint(state_dict) cbtorch.save(state_dict, args.output)