# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""
Base class for all dynamic sparsity algorithms, plus dynamic schedule helpers.
"""
import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Optional, Union
import torch
import cerebras.pytorch as cstorch
from cerebras.pytorch.utils.weak import DefaultWeakIdKeyDictionary
from .base import SparsityAlgorithm
from .utils import UpdateScheduleType, make_update_schedule
class DynamicSparsityAlgorithm(SparsityAlgorithm, ABC):
def __init__(
self,
        sparsity: Optional[Union[float, dict]] = None,
update: Optional[UpdateScheduleType] = None,
add_summaries: bool = False,
**kwargs,
):
"""
Args:
sparsity: A float specifying the level of sparsity to apply to each
parameter or a dictionary specifying the schedule to use for
sparsity. The dictionary must have a "type" key, which specifies
the type of schedule to use. The remaining keys are
schedule-specific. The following schedule types
are supported:
- ":py:class:`constant <cerebras.pytorch.sparse.utils.Constant>`"
- ":py:class:`linear <cerebras.pytorch.sparse.utils.Linear>`"
- ":py:class:`exp <cerebras.pytorch.sparse.utils.Exp>`"
- ":py:class:`power <cerebras.pytorch.sparse.utils.Power>`"
- ":py:class:`cosine <cerebras.pytorch.sparse.utils.Cosine>`"
- ":py:class:`cycling <cerebras.pytorch.sparse.utils.Cycling>`"
update: A dictionary specifying the schedule to use for updating the sparsity pattern.
The dictionary must contain keys that can be used to construct either a
:py:class:`~cerebras.pytorch.sparse.utils.FreqSchedule` or a
:py:class:`~cerebras.pytorch.sparse.utils.ListSchedule`.
If not provided, the sparsity pattern will be updated every step.
add_summaries: Whether to add summaries for the sparsity patterns
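
        Example:
            A minimal sketch of typical arguments. The ``start`` and ``stop``
            keys below mirror the attributes this class reads from the update
            schedule; the ``freq`` key is an assumption based on
            :py:class:`~cerebras.pytorch.sparse.utils.FreqSchedule`::

                sparsity = 0.75  # prune 75% of each sparsified parameter
                # update the mask every 100 steps between steps 1000 and 20000
                update = {"start": 1000, "freq": 100, "stop": 20000}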
"""
self.add_summaries = add_summaries
super().__init__(sparsity=sparsity, **kwargs)
self.update_schedule = make_update_schedule(update)
self.starts_sparse = self.update_schedule(
torch.tensor(0, dtype=torch.int64)
)
if not self.starts_sparse:
self.init_method = lambda p, sparsity, **kwargs: cstorch.ones_like(
p, dtype=torch.bool
)
self.step = torch.tensor(0, dtype=torch.int64)
def csx_annotate_sparsity(self, param: "SparseParameter"):
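        """
        Annotate ``param`` with the minimum, maximum, and final sparsity
        levels implied by its schedule over the update window when running
        on a Cerebras system.
        """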
if cstorch.use_cs():
begin_step = getattr(self.update_schedule, "start", None) or 0
# If the schedule has a `stop` step use that, otherwise pick
# 100,000 arbitrarily.
end_step = getattr(self.update_schedule, "stop", None) or 100000
# This simple scalar computation does not need to be traced
with torch.device("cpu"):
min_max_end = self.sparsity[param.data].get_min_max_end(
begin_step, end_step
)
if min_max_end and not self.starts_sparse:
                    # If we don't start sparse, there is a period of dense
                    # training, i.e. 0% sparsity.
_, max_v, end_v = min_max_end
min_max_end = (0.0, max_v, end_v)
min_v, max_v, ending_v = min_max_end
param.annotate("min_sparsity", min_v)
param.annotate("max_sparsity", max_v)
param.annotate("sparsity", ending_v)
def sparsify_parameter(
self, module: torch.nn.Module, name: str, param: torch.Tensor
) -> None:
super().sparsify_parameter(module, name, param)
self.sparsity[param].update(self.starts_sparse)
@cached_property
def is_update_step(self):
"""
Returns True if the current step is an update step according to the
update schedule.
"""
return self.update_schedule(self.step)
@torch.no_grad()
def update(self, optimizer: Optional[torch.optim.Optimizer] = None):
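        """
        Prune the weights, advance the internal step counter, and, if the
        current step is an update step according to the update schedule,
        recompute the sparsity level and mask of every sparsified parameter.
        """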
# Ensure we've called apply_sparsity before step
self._ensure_sparsity_applied()
# The weights and optimizer state were just updated. In case we
# _decrease_ sparsity here instead of increasing it, prune the weights
# using the current weight masks
self.prune_weights()
cstorch.amp.update_if_finite(optimizer, self.step)
self.step += 1
if self.add_summaries:
# Collect all sparse params that a given sparsity schedule is assigned to
unique_schedules = DefaultWeakIdKeyDictionary(list)
for sparse_param in self.sparse_params.values():
unique_schedules[self.sparsity[sparse_param.param]].append(
sparse_param.name
)
if optimizer:
if not isinstance(optimizer, cstorch.optim.Optimizer):
raise TypeError(
f"Expected a Cerebras Optimizer. Got: {type(optimizer)}"
)
            # Only update the mask if no optimizer gradients are NaN/Inf
isfinite = cstorch.amp.isfinite(optimizer)
if isinstance(isfinite, torch.Tensor):
self.is_update_step &= isfinite
for sparse_param in self.sparse_params.values():
p = sparse_param.param
mask = sparse_param.mask
if p.grad is None:
# If the gradient is None, then the parameter was not updated
# so there is no need to update the mask
continue
schedule = self.sparsity[p]
# Compute sparsity level for the parameter at the current step
sparsity = schedule(self.step).to(p.device)
            # Clamp the sparsity level to the valid range [0, 1]
sparsity = torch.clamp(sparsity, min=0.0, max=1.0)
# update the sparsity schedule if it is an update step
# so that we get the latest sparsity value
# This can technically cause issues if there are multiple optimizers
# sparsified by the same sparsity algorithm
schedule.update(self.is_update_step)
if self.add_summaries and schedule in unique_schedules:
# We only want to summarize this once per unique schedule
names = unique_schedules.pop(schedule)
# Create a "glob" using the common prefix, e.g.
# [fc1.weight, fc2.weight] would yield "fc*"
name_glob = os.path.commonprefix(names) + "*"
cstorch.summarize_scalar(
f"sparsity/{name_glob}/target", sparsity
)
new_mask = self.update_mask(p, mask, sparsity)
            # Only adopt the new mask on update steps; otherwise keep the
            # existing mask so the stored state is unchanged
            new_mask = torch.where(self.is_update_step, new_mask, mask)
sparse_param.mask = new_mask
if self.add_summaries:
cstorch.summarize_scalar(
f"sparsity/{sparse_param.name}/actual",
1 - new_mask.sum() / new_mask.numel(),
)
# Clear the scheduler's cache. We don't want to carry this over
# to the next iteration.
# Technically this can be handled by not caching the values
# and relying on common subexpression elimination, but for now
# keep as is
for sparsity in self.sparsity.values():
sparsity.cache_clear()
# Clear update step cache. We don't want to carry this over
# to the next iteration.
self.__dict__.pop("is_update_step", None)
# We need to reapply the masks here one more time in order for
# the compiler to pick up that the masks were updated.
self.prune_weights()
@abstractmethod
@torch.no_grad()
def update_mask(self, p, mask, sparsity) -> torch.Tensor:
"""
Compute an updated sparsity pattern.
Args:
p (torch.Tensor): the parameter to sparsify
            mask (torch.Tensor): the current boolean mask of param ``p``
                (dtype ``torch.bool``)
            sparsity (torch.Tensor): the desired sparsity level as a
                scalar ``torch.float32`` tensor
Returns:
The updated sparsity pattern on parameter p
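
        Example:
            A minimal eager-mode magnitude-pruning sketch of what a subclass
            might implement (illustrative only; this is not the logic of any
            particular built-in algorithm)::

                def update_mask(self, p, mask, sparsity):
                    # Keep the largest-magnitude values so that roughly
                    # `sparsity` of the entries are pruned
                    num_keep = max(int((1 - sparsity.item()) * p.numel()), 1)
                    threshold = torch.topk(p.abs().flatten(), num_keep).values[-1]
                    return p.abs() >= threshold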
"""
def visit_state(self, f):
super().visit_state(f)
out = f(self.step)
if out is not None:
self.step = out
# Iterate a unique list of sparsity hyperparam objects
for sparsity in torch.utils.weak.WeakIdKeyDictionary(
{
self.sparsity[sparse_param.param]: None
for sparse_param in self.sparse_params.values()
}
):
sparsity.visit_state(f)
def state_dict(self):
state_dict = super().state_dict()
state_dict["step"] = self.step
state_dict["sparsity"] = {
name: s
# Only need to save unique sparsity schedules
for sparsity, name in torch.utils.weak.WeakIdKeyDictionary(
{
self.sparsity[sparse_param.param]: sparse_param.name
for sparse_param in self.sparse_params.values()
}
).items()
if (s := sparsity.state_dict())
}
return state_dict
def load_state_dict(self, state_dict):
self.step = state_dict.pop("step")
super().load_state_dict(state_dict)
state_dict["sparsity"] = {}
for sparse_param in self.sparse_params.values():
sparsity = self.sparsity[sparse_param.param]
if s := state_dict["sparsity"].get(sparse_param.name):
sparsity.load_state_dict(s)
with self._backend.device:
self.visit_state(lambda x: x.to(self._backend.torch_device))